simple-diffusion

Public · 110 stars · 11 forks · 1 issue

Commits

List of commits on branch main (all by filipbasara0):

  • c123f983b111937976bf7c484ef0178f9f9ce8cd (verified, 2 months ago): Fix double SiLU in ResidualBlock
  • bb23e2571aeef309084d05832af86e180b05830e (unverified, 7 months ago): Simplified ddim scheduler
  • 688cde8642553e19195db2b6bf722d668a8dfcf1 (verified, 9 months ago): Add flash attention
  • 2589973e06c69c95b7aa2c1e6bc654bb0b49cd65 (verified, 9 months ago): Merge pull request #3 from filipbasara0/pip-pkg
  • f5d0337e7388f8a3859e341990267587e413b9a2 (unverified, 9 months ago): Add flash attention
  • 4007be9d71d5d52b3fcda4f8bf6e74dfecd11166 (verified, a year ago): Merge pull request #2 from filipbasara0/pip-pkg

README


Simple Denoising Diffusion

A minimal PyTorch implementation of an unconditional denoising diffusion model for image generation. The idea was to test how well a very small model performs on the Oxford Flowers dataset.

Includes a DDIM scheduler and a UNet architecture with residual connections and attention layers.
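
For reference, below is a minimal sketch of a single deterministic DDIM update step in PyTorch. The function name, signature and the precomputed alphas_cumprod table are illustrative assumptions, not this repo's actual API.

    import torch

    @torch.no_grad()
    def ddim_step(model, x_t, t, t_prev, alphas_cumprod):
        # Deterministic DDIM update (eta = 0) from timestep t to t_prev.
        # Assumes `model(x_t, t)` predicts the noise added at timestep t and
        # `alphas_cumprod` is a 1-D tensor of cumulative alpha products.
        alpha_bar_t = alphas_cumprod[t]
        alpha_bar_prev = alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)

        eps = model(x_t, t)

        # Recover the predicted clean image x_0 from x_t and the noise estimate.
        x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()

        # Move to the previous timestep along the predicted noise direction.
        return alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * eps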

Oxford Flowers

(image: flower samples generated by the model)

So far, the model has been tested on the Oxford Flowers dataset; the results can be seen in the image above. Images were generated with 50 DDIM steps.

The results were surprisingly decent and training was unexpectedly smooth, considering the model size.

Training was done for 40k steps with a batch size of 64. The learning rate was 1e-3 and the weight decay was 5e-2. Training took ~6 hours on a GTX 1070 Ti.
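
These hyperparameters roughly translate to a training loop like the one below. This is only a sketch: the optimizer choice (AdamW) and the `model`, `data_iter`, `alphas_cumprod` and `num_train_timesteps` names are assumptions, not taken from the repo.

    import torch
    import torch.nn.functional as F

    # Assumed setup: `model`, `data_iter`, `alphas_cumprod`, `num_train_timesteps`
    # (alphas_cumprod is assumed to live on the same device as the data).
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-2)

    for step in range(40_000):
        x0 = next(data_iter)  # batch of 64 images, already normalized to [-1, 1]
        t = torch.randint(0, num_train_timesteps, (x0.shape[0],), device=x0.device)
        noise = torch.randn_like(x0)

        # Forward diffusion: mix the clean image with noise according to alpha_bar_t.
        a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
        x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

        # Standard epsilon-prediction objective.
        loss = F.mse_loss(model(x_t, t), noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()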

Hidden dims of [16, 32, 64, 128] were used, resulting in a total of 2,346,835 parameters (~2.3M).
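
The parameter count can be verified with a one-liner; the `UNet(hidden_dims=...)` constructor below is a hypothetical stand-in for the repo's model class.

    # Hypothetical constructor; the real class and signature may differ.
    model = UNet(hidden_dims=[16, 32, 64, 128])
    num_params = sum(p.numel() for p in model.parameters())
    print(f"{num_params:,} parameters")  # expected: 2,346,835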

To train the model, run the following command:

 python train.py \
   --dataset_name="huggan/flowers-102-categories" \
   --resolution=64 \
   --output_dir="trained_models/ddpm-ema-64.pth" \
   --train_batch_size=16 \
   --num_epochs=121 \
   --gradient_accumulation_steps=1 \
   --learning_rate=1e-4 \
   --lr_warmup_steps=300
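
Once trained, sampling with 50 DDIM steps can look roughly like the loop below, reusing the `ddim_step` helper sketched earlier. The checkpoint path matches the training command above, but the loading code and model interface are assumptions rather than the repo's exact API.

    import torch

    model.load_state_dict(torch.load("trained_models/ddpm-ema-64.pth"))
    model.eval()

    # 50 evenly spaced timesteps, from most to least noisy.
    timesteps = torch.linspace(num_train_timesteps - 1, 0, 50).long()

    x = torch.randn(16, 3, 64, 64)  # start from pure Gaussian noise
    for i, t in enumerate(timesteps):
        t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else torch.tensor(-1)
        x = ddim_step(model, x, t, t_prev, alphas_cumprod)

    images = (x.clamp(-1, 1) + 1) / 2  # map back to [0, 1] for saving/plotting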

Conclusions

  • Skip and residual connections are a must - training doesn't converge without them
  • Attention speeds up convergence and improves the quality of generated samples
  • Normalizing images to N(0,1) didn't yield improvements compared to the standard -1 to 1 normalization (see the preprocessing sketch after this list)
  • A learning rate of 1e-3 resulted in faster convergence for the smaller models, compared to the 1e-4 that is usually used in the literature
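
For reference, the standard -1 to 1 normalization mentioned above usually amounts to a torchvision transform like the following; the exact preprocessing pipeline in the repo may differ.

    from torchvision import transforms

    preprocess = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),                                    # scales pixels to [0, 1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),   # maps [0, 1] to [-1, 1]
    ])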

Improvements

  • Training longer - these models require a lot of iterations. For example, in Diffusion Models Beat GANs on Image Synthesis, iteration counts ranged between 300K and 4,360K!
  • Using bigger models
  • Exploring the impact of more diverse augmentations

Future steps

  • Training on the huggan/pokemons dataset with a bigger model - this dataset proved to be too difficult for the 2M-parameter model
  • Training a model on a custom task