
stable-diffusion-jax

public
86 stars
8 forks
6 issues

Commits

List of commits on branch main.
Verified
36c7f00ecc2ad75e51d5ce2fd1ae3c8fa8d2ea50

Update run.py

patrickvonplaten committed 2 years ago
Verified
47297f53bb4907f119079654310bfb14134c2714

Merge pull request #7 from patil-suraj/fix-sinusoidal-embeddings

patrickvonplaten committed 2 years ago
Verified
6168da8dd197523b31402bc726cbcdec7d282f58

Merge pull request #4 from patil-suraj/scheduler-fix

patrickvonplaten committed 2 years ago
Unverified
a4f9b04d814ceca822a02cb9aaf6c7a4b4952c19

Fix sinusoidal embeddings.

pcuenca committed 2 years ago
Unverified
eb1e6a126e047dd65513aca1223cf9714d2f2786

Set offset to 1 in PNDMScheduler.

pcuenca committed 2 years ago
Unverified
c5014d49c51a6f279f2d6b4d58ba2903ce050f48

upload all

patrickvonplaten committed 2 years ago

README


TODOs:

  • [x] Finish implementing the UNet2D model in modeling_unet2d.py. Port the weights of an existing LDM UNet from diffusers and verify equivalence. I've added a skeleton of the modules we need to implement in that file.
  • [x] Adapt the PNDMScheduler from diffusers for JAX: Use jnp arrays and make it stateless.
  • [x] Add the KL module from CompVis/stable-diffusion (https://github.dev/CompVis/stable-diffusion) to the modeling_vae.py file. We don't strictly need it for inference, but it would be nice to have for completeness. Port the weights of an existing KL VAE and verify equivalence.
  • [x] Add an inference loop in pipeline_stable_diffusion. We should be able to jit/pmap the loop to deploy on TPUs.
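
For the UNet2D weight port, one detail that is easy to get wrong is the convolution kernel layout. Below is a minimal sketch, assuming the source weights are PyTorch-style Conv2d kernels stored as (out, in, kh, kw) and the target is the (kh, kw, in, out) layout that flax.linen.Conv uses. The helper torch_conv_to_flax is hypothetical, not the repo's conversion code; the check simply runs the same random convolution in both layouts with jax.lax.conv_general_dilated and compares the outputs.

```python
import numpy as np
import jax
import jax.numpy as jnp


def torch_conv_to_flax(kernel_oihw):
    # Hypothetical helper: PyTorch Conv2d weights are (out, in, kh, kw);
    # flax.linen.Conv expects (kh, kw, in, out). Both layers compute
    # cross-correlation, so no spatial flip is needed.
    return np.transpose(kernel_oihw, (2, 3, 1, 0))


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((1, 3, 8, 8)).astype(np.float32)
w_oihw = rng.standard_normal((5, 3, 3, 3)).astype(np.float32)

# Reference: the convolution computed in the PyTorch (NCHW / OIHW) layout.
y_ref = jax.lax.conv_general_dilated(
    x_nchw, w_oihw, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    precision=jax.lax.Precision.HIGHEST,
)

# Ported: the same convolution in the Flax (NHWC / HWIO) layout.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
y_port = jax.lax.conv_general_dilated(
    x_nhwc, torch_conv_to_flax(w_oihw), window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    precision=jax.lax.Precision.HIGHEST,
)

# Bring the ported output back to NCHW and measure the largest difference.
max_err = float(jnp.max(jnp.abs(jnp.transpose(y_port, (0, 3, 1, 2)) - y_ref)))
print("max abs error:", max_err)
```

The same compare-against-reference idea scales up to whole blocks: run the original and ported modules on identical inputs and check the maximum absolute difference against a small tolerance.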
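
"Stateless" for the PNDMScheduler port roughly means that the buffers the PyTorch scheduler mutates in place (such as its list of past model outputs) become an immutable pytree that is passed in and returned, so the step can run under jax.jit. The sketch below shows that pattern for the linear multistep (PLMS) combination of the last four model outputs, assuming the coefficients used in diffusers' PNDM step_plms. The names MultistepState, push and combine are hypothetical, and the rest of the PNDM step (warm-up handling and the update of the sample itself) is deliberately left out.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MultistepState(NamedTuple):
    ets: jnp.ndarray    # last 4 model outputs, oldest first, shape (4, *sample_shape)
    count: jnp.ndarray  # number of model outputs pushed so far (int32 scalar)


# PLMS weights applied to ets (oldest -> newest); row k is used when
# k + 1 outputs are available, capped at 4.
PLMS_COEFFS = jnp.array(
    [
        [0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, -1.0 / 2, 3.0 / 2],
        [0.0, 5.0 / 12, -16.0 / 12, 23.0 / 12],
        [-9.0 / 24, 37.0 / 24, -59.0 / 24, 55.0 / 24],
    ],
    dtype=jnp.float32,
)


def init_state(sample_shape):
    return MultistepState(
        ets=jnp.zeros((4,) + tuple(sample_shape), dtype=jnp.float32),
        count=jnp.array(0, dtype=jnp.int32),
    )


def push(state, model_output):
    # Return a NEW state: drop the oldest output, append the newest (no mutation).
    ets = jnp.concatenate([state.ets[1:], model_output[None]], axis=0)
    return MultistepState(ets=ets, count=state.count + 1)


def combine(state):
    # The row is chosen from a traced counter, so this also works under jax.jit.
    row = jnp.clip(state.count, 1, 4) - 1
    return jnp.tensordot(PLMS_COEFFS[row], state.ets, axes=1)


push_jit = jax.jit(push)
combine_jit = jax.jit(combine)
```

Because the state is a NamedTuple of arrays, it is a pytree and can be threaded through jax.lax.fori_loop or jax.lax.scan alongside the latents.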
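
The "KL" in the KL VAE refers to the diagonal Gaussian posterior over latents and its KL penalty against a standard normal. A minimal sketch of those pieces in JAX follows, assuming the encoder emits mean and log-variance concatenated on the last (channel) axis in NHWC layout, and assuming the log-variance clamp range of the CompVis DiagonalGaussianDistribution. The function names are hypothetical and this is not the code in modeling_vae.py.

```python
import jax
import jax.numpy as jnp


def split_moments(moments):
    # Assumption: [mean, logvar] concatenated on the channel axis (NHWC).
    mean, logvar = jnp.split(moments, 2, axis=-1)
    # Clamp log-variance for numerical stability (range assumed from CompVis).
    logvar = jnp.clip(logvar, -30.0, 20.0)
    return mean, logvar


def sample_latent(key, mean, logvar):
    # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I).
    std = jnp.exp(0.5 * logvar)
    return mean + std * jax.random.normal(key, mean.shape, dtype=mean.dtype)


def kl_to_standard_normal(mean, logvar):
    # KL(N(mean, var) || N(0, I)), summed over H, W, C for each batch element.
    return 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - 1.0 - logvar, axis=(1, 2, 3))
```

For inference only sample_latent (or just the mean) is needed; the KL term matters when training the autoencoder.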
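
To make the inference loop jit/pmap-friendly, one common approach is to express the denoising steps with jax.lax.fori_loop so the whole loop compiles into a single XLA computation, then pmap it so each device gets a slice of the batch. The sketch below shows only that structure: dummy_unet and its update rule are hypothetical placeholders (not the UNet and not the PNDM update), and the parameters are broadcast to all devices with in_axes=None rather than replicated by hand.

```python
from functools import partial

import jax
import jax.numpy as jnp


def dummy_unet(params, latents, t):
    # Hypothetical stand-in for the UNet: "predicts noise" as a scaled copy
    # of the latents. The timestep t is accepted but ignored.
    return params["scale"] * latents


def sample_loop(params, latents, num_steps):
    def body(i, x):
        eps = dummy_unet(params, x, i)
        # Hypothetical update rule, NOT the PNDM step.
        return x - 0.1 * eps

    # fori_loop keeps the loop inside one compiled computation.
    return jax.lax.fori_loop(0, num_steps, body, latents)


# params are broadcast to every device (in_axes=None);
# latents are split along their leading device axis (in_axes=0).
p_sample = jax.pmap(partial(sample_loop, num_steps=50), in_axes=(None, 0))

params = {"scale": jnp.array(1.0)}
n_devices = jax.local_device_count()
latents = jnp.ones((n_devices, 2, 4, 8, 8))  # (devices, batch per device, C, H, W)
images = p_sample(params, latents)
```

Binding num_steps with functools.partial keeps it a Python constant, so the loop length is fixed at compile time; on a TPU host jax.local_device_count() would typically report several devices, and the leading axis of latents must match it.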