1

8장에 8초! 텍스트입력 이미지 생성모델 - Stable Diffusion+JAX+TPU

AF 김태영
2022.10.14 18:37
2640

원하는 이미지를 얻을 때까지 많은 프롬프트로 실험을 해야하기 때문에, 추론 속도도 상당히 중요합니다. 기존에는 3장 생성하는 데, 15초 정도가 걸렸는데, 이제 허깅페이스의 diffusers가 JAX와 TPU 조합으로 이제 8초에 8장씩 생성할 수 있게 되었습니다. 각 TPU 서버는 8개의 TPU 가속기를 가지고 있는데, 이를 병렬로 처리하여 8장을 동시에 처리할 수 있습니다. 단 JAX 코드와 Flax 파이프라인을 사용하기 위한 몇가지 설정이 필요합니다.

위 이미지는 아래 프롬프트로 생성하였습니다.

Digital art, trending on artstation, a illustration of an astronaut riding a unicorn horse which has white horns, red hair and red tail


코랩 따라하기

허깅페이스 깃헙 : https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_jax_how_to.ipynb

 

부연설명

코드에 포함된 몇가지 용어에 대해서 정리해봤습니다.

  • JIT : Just In Time Compilation으로 런타임 시에 코드를 컴파일 하는 방식을 말합니다. 컴파일을 한 번 해두면, 매번 인터프리터를 동작할 필요가 없어 속도가 개선됩니다.
  • JAX : 자동 미분이 가능한 CPU, GPU, TPU에서 동작하는 numpy 프레임워크입니다. 게다가 JIT로 런타임 시에 코드 컴파일이 가능합니다. 가속기에 최적화된 (게다가 신경망에 핵심 연산인 자동 미분도 되니)  속도가 빠르겠죠?
  • Flax : JAX용 고성능 신경망 라이브러리 및 에코시스템 입니다. 
1
0개의 댓글
로그인 후 이용해주세요!