- [x] Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port the weights of any existing LDM UNet from diffusers and verify equivalence. I've added the skeleton of the modules that we need to implement in the file.
- [x] Adapt the `PNDMScheduler` from `diffusers` for JAX: use `jnp` arrays and make it stateless.
- [x] Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) in the `modeling_vae.py` file. For inference we don't really need it, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
- [x] Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
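Porting the UNet weights mostly comes down to remapping parameter names and transposing tensors between layouts. The check below is a minimal sketch, assuming the diffusers checkpoint stores convolution kernels in PyTorch's `(out, in, kH, kW)` layout and the JAX port uses Flax-style NHWC inputs with `(kH, kW, in, out)` kernels. `torch_conv_to_flax`, `conv_nchw`, `conv_nhwc`, and `max_abs_diff` are hypothetical helpers, not part of this repo. The sketch runs the same random convolution in both layouts with `jax.lax.conv_general_dilated` and reports the largest absolute difference, which should be at float32 round-off level.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def torch_conv_to_flax(weight_oihw: np.ndarray) -> np.ndarray:
    """Convert a PyTorch-style kernel (out, in, kH, kW) to Flax layout (kH, kW, in, out)."""
    return np.transpose(weight_oihw, (2, 3, 1, 0))


def conv_nchw(x, w_oihw):
    # Reference: PyTorch-style layout (batch, channels, height, width).
    return lax.conv_general_dilated(
        x, w_oihw, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))


def conv_nhwc(x, w_hwio):
    # Ported: Flax-style layout (batch, height, width, channels).
    return lax.conv_general_dilated(
        x, w_hwio, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


def max_abs_diff(seed: int = 0) -> float:
    rng = np.random.default_rng(seed)
    x_nchw = rng.standard_normal((2, 4, 8, 8)).astype(np.float32)
    w_oihw = rng.standard_normal((6, 4, 3, 3)).astype(np.float32)

    ref = conv_nchw(jnp.asarray(x_nchw), jnp.asarray(w_oihw))
    ported = conv_nhwc(jnp.asarray(np.transpose(x_nchw, (0, 2, 3, 1))),
                       jnp.asarray(torch_conv_to_flax(w_oihw)))
    # Bring the ported output back to NCHW before comparing.
    ported_nchw = jnp.transpose(ported, (0, 3, 1, 2))
    return float(jnp.max(jnp.abs(ref - ported_nchw)))
```

The same pattern (run both implementations on identical inputs, compare with a tolerance) extends to whole blocks once the converted parameters are loaded into the Flax modules.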
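For the stateless scheduler, one option is to keep everything the PyTorch scheduler stores on its object (the history of model outputs and the step counter) in a small pytree that a pure `step` function takes and returns, so the loop can be traced by `jit`. The sketch below covers only the linear multistep (PLMS) part of PNDM, and `PLMSState`, `init_state`, `transfer`, and `plms_step` are hypothetical names. Its warm-up simply falls back to lower-order Adams-Bashforth coefficients; that is a simplification, not the exact warm-up logic of the diffusers `PNDMScheduler`. The cumulative alphas for the current and previous timesteps are passed in as arguments instead of being looked up from scheduler state.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Adams-Bashforth coefficients, newest model output first.
# Row k is used once k + 1 past outputs are available (order k + 1).
_AB_COEFFS = jnp.array([
    [1.0, 0.0, 0.0, 0.0],
    [3.0 / 2.0, -1.0 / 2.0, 0.0, 0.0],
    [23.0 / 12.0, -16.0 / 12.0, 5.0 / 12.0, 0.0],
    [55.0 / 24.0, -59.0 / 24.0, 37.0 / 24.0, -9.0 / 24.0],
])


class PLMSState(NamedTuple):
    ets: jnp.ndarray    # (4, *sample_shape): past model outputs, newest first
    count: jnp.ndarray  # int32 scalar: number of outputs seen so far


def init_state(sample_shape, dtype=jnp.float32) -> PLMSState:
    return PLMSState(ets=jnp.zeros((4, *sample_shape), dtype),
                     count=jnp.array(0, dtype=jnp.int32))


def transfer(sample, eps, alpha_t, alpha_prev):
    """PNDM-style move from x_t to x_prev, given cumulative alphas."""
    sample_coeff = jnp.sqrt(alpha_prev / alpha_t)
    denom = (alpha_t * jnp.sqrt(1.0 - alpha_prev)
             + jnp.sqrt(alpha_t * (1.0 - alpha_t) * alpha_prev))
    return sample_coeff * sample - (alpha_prev - alpha_t) * eps / denom


@jax.jit
def plms_step(state: PLMSState, model_output, sample, alpha_t, alpha_prev):
    """Pure function: returns (prev_sample, new_state); nothing is mutated."""
    # Push the newest output to the front and drop the oldest one.
    ets = jnp.concatenate([model_output[None], state.ets[:-1]], axis=0)
    count = state.count + 1
    # Use the highest order the history allows, capped at 4.
    coeffs = _AB_COEFFS[jnp.minimum(count, 4) - 1]
    eps = jnp.tensordot(coeffs, ets, axes=1)
    return transfer(sample, eps, alpha_t, alpha_prev), PLMSState(ets, count)
```

`transfer` is algebraically identical to the deterministic DDIM update (predict x0 from the noise estimate, then re-noise to the previous timestep), so the first step, which uses the raw model output, can be checked against DDIM directly.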
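On the VAE side, what distinguishes a KL-regularized autoencoder is its diagonal Gaussian posterior over latents. The sketch below is modeled on the `DiagonalGaussianDistribution` class in the CompVis repository, adapted to NHWC arrays, so the mean and log-variance are split along the last (channel) axis; the class name `DiagonalGaussian` and the `axis` parameter are our own choices. Text-to-image inference only exercises the decoder, which matches the note above that the module isn't strictly needed: `sample()`/`mode()` matter when encoding images, and `kl()` only for training.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Posterior q(z | x) of a KL-regularized VAE, built from encoder moments."""

    def __init__(self, moments, axis=-1):
        # The encoder output concatenates mean and log-variance along `axis`.
        self.mean, logvar = jnp.split(moments, 2, axis=axis)
        # Clip for numerical stability, as the CompVis code does.
        self.logvar = jnp.clip(logvar, -30.0, 20.0)
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)

    def sample(self, key):
        # Reparameterization trick: z = mean + std * noise, noise ~ N(0, I).
        noise = jax.random.normal(key, self.mean.shape, self.mean.dtype)
        return self.mean + self.std * noise

    def mode(self):
        return self.mean

    def kl(self):
        # KL(q || N(0, I)), summed over all non-batch axes.
        axes = tuple(range(1, self.mean.ndim))
        return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=axes)
```

For example, an encoder output of shape `(batch, h, w, 8)` yields a posterior over 4-channel latents of shape `(batch, h, w, 4)`.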
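To make the inference loop `jit`/`pmap`-friendly, a common approach is to express the denoising steps with `jax.lax.fori_loop` and keep the number of steps a static Python value. In the sketch below, `toy_denoiser` is a hypothetical stand-in for "UNet forward plus scheduler update" and the parameter dict is made up; only the loop structure and the device replication are the point. `jax.pmap` maps over a leading device axis, so the parameters are replicated and the latents get one slice per local device (on a single-CPU machine that axis has size 1).

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 10  # static: baked into the traced program, not a runtime value


def toy_denoiser(params, x, t):
    # Hypothetical stand-in for "UNet forward + scheduler step".
    return params["scale"] * x * (t / NUM_STEPS)


def sample(params, latents):
    def body(i, x):
        # Timesteps count down from NUM_STEPS to 1.
        t = jnp.asarray(NUM_STEPS - i, dtype=jnp.float32)
        eps = toy_denoiser(params, x, t)
        return x - eps / NUM_STEPS

    return jax.lax.fori_loop(0, NUM_STEPS, body, latents)


jit_sample = jax.jit(sample)  # single device
p_sample = jax.pmap(sample)   # one copy of the program per local device

n_devices = jax.local_device_count()
params = {"scale": jnp.array(0.5, dtype=jnp.float32)}
# Replicate every parameter along a new leading device axis.
replicated = jax.tree_util.tree_map(
    lambda a: jnp.broadcast_to(a, (n_devices, *a.shape)), params)
# One (batch, height, width, channels) slice of latents per device.
latents = jax.random.normal(jax.random.PRNGKey(0), (n_devices, 1, 8, 8, 4))

out = p_sample(replicated, latents)  # shape (n_devices, 1, 8, 8, 4)
```

The same code runs unchanged when more devices are available; each device then processes its own slice of the latents, and each slice matches what `jit_sample` produces for it on a single device.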
stable-diffusion-jax
Commits on branch `main`:

- `36c7f00ecc2ad75e51d5ce2fd1ae3c8fa8d2ea50` Update run.py (patrickvonplaten, 2 years ago, verified)
- `47297f53bb4907f119079654310bfb14134c2714` Merge pull request #7 from patil-suraj/fix-sinusoidal-embeddings (patrickvonplaten, 2 years ago, verified)
- `6168da8dd197523b31402bc726cbcdec7d282f58` Merge pull request #4 from patil-suraj/scheduler-fix (patrickvonplaten, 2 years ago, verified)
- `a4f9b04d814ceca822a02cb9aaf6c7a4b4952c19` Fix sinusoidal embeddings. (pcuenca, 2 years ago, unverified)
- `eb1e6a126e047dd65513aca1223cf9714d2f2786` Set offset to 1 in PNDMScheduler. (pcuenca, 2 years ago, unverified)
- `c5014d49c51a6f279f2d6b4d58ba2903ce050f48` upload all (patrickvonplaten, 2 years ago, unverified)