원하는 이미지를 얻을 때까지 많은 프롬프트로 실험을 해야하기 때문에, 추론 속도도 상당히 중요합니다. 기존에는 3장 생성하는 데, 15초 정도가 걸렸는데, 이제 허깅페이스의 diffusers가 JAX와 TPU 조합으로 이제 8초에 8장씩 생성할 수 있게 되었습니다. 각 TPU 서버는 8개의 TPU 가속기를 가지고 있는데, 이를 병렬로 처리하여 8장을 동시에 처리할 수 있습니다. 단 JAX 코드와 Flax 파이프라인을 사용하기 위한 몇가지 설정이 필요합니다.

위 이미지는 아래 프롬프트로 생성하였습니다.
Digital art, trending on artstation, a illustration of an astronaut riding a unicorn horse which has white horns, red hair and red tail
코랩 따라하기
부연설명
코드에 포함된 몇가지 용어에 대해서 정리해봤습니다.
- JIT : Just In Time Compilation으로 런타임 시에 코드를 컴파일 하는 방식을 말합니다. 컴파일을 한 번 해두면, 매번 인터프리터를 동작할 필요가 없어 속도가 개선됩니다.
- JAX : 자동 미분이 가능한 CPU, GPU, TPU에서 동작하는 numpy 프레임워크입니다. 게다가 JIT로 런타임 시에 코드 컴파일이 가능합니다. 가속기에 최적화된 (게다가 신경망에 핵심 연산인 자동 미분도 되니) 속도가 빠르겠죠?
- Flax : JAX용 고성능 신경망 라이브러리 및 에코시스템 입니다.