This repository contains my homework for a university Deep Learning course. For this assignment I implemented DCGAN, CVAE, and DDPM. All models were trained successfully except DDPM.
The original task statement.
├── README.md <- Top-level README
├── report-ru.pdf <- report in Russian
├── requirements.txt <- project requirements
├── train.py <- training script
├── congigs/ <- model configs
│ ├── cvae_mnist_train.json
│ ├── cvae32x32_train.json
│ ├── cvae64x64_train.json
│ ├── cvae256x256_train.json
│ └── dcgan_train.json
│
├── results/... <- images and gifs with results
│
├── scripts/
│ └── create_gif.ipynb <- download images from wandb, convert to gif
│
└── src/
    ├── datasets/
    │   ├── artbench10_32x32.py
    │   ├── artbench10_256x256.py
    │   ├── base_dataset.py
    │   ├── cats_faces.py
    │   └── mnist_wrapper.py
    │
    ├── models/
    │   ├── cvae/
    │   │   └── cvae.py
    │   │
    │   ├── dcgan/
    │   │   └── dcgan.py
    │   │
    │   ├── ddpm/
    │   │   ├── diffusion.py
    │   │   └── unet.py
    │   ├── __init__.py
    │   └── base_model.py
    │
    ├── trainers/
    │   ├── __init__.py
    │   ├── trainer.py <- cvae and ddpm trainer
    │   └── gan_trainer.py <- gan trainer
    │
    ├── utils/
    │   ├── __init__.py
    │   └── utils.py <- inf_loop, WandbWriter, LocalWriter, images conversion functions
    │
    ├── __init__.py
    └── collate.py <- collate_fn, collate_w_target_fn
- To install the required libraries, run the following from the root directory:
pip3 install -r requirements.txt
- I used 3 datasets:
- CatsFaces on Kaggle for DCGAN.
- MNIST and ArtBench for CVAE. MNIST and ArtBench-32x32 are downloaded automatically. Download the ArtBench-256x256 version here: ArtBench-10 on GitHub, or the extended version on Kaggle (256x256 + AI-generated). ArtBench-64x64 is obtained by resizing the 256x256 version to 64x64.
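
The 256x256 to 64x64 resizing step could be done along the following lines. This is only a minimal sketch, not necessarily how the 64x64 data in this repository was produced: the function name `resize_dataset`, the directory arguments, and the choice of bicubic interpolation are all assumptions made for illustration, and it relies on Pillow being installed.

```python
from pathlib import Path

from PIL import Image

# File extensions treated as images (an assumption; adjust to the actual dataset).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def resize_dataset(src_dir, dst_dir, size=64):
    """Resize every image under src_dir to size x size.

    The sub-folder layout of src_dir is mirrored in dst_dir.
    Returns the number of images written.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    count = 0
    for src_path in sorted(src_dir.rglob("*")):
        # Skip directories and non-image files.
        if not src_path.is_file() or src_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue
        dst_path = dst_dir / src_path.relative_to(src_dir)
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        with Image.open(src_path) as img:
            # Bicubic interpolation is an assumed choice, not taken from the repository.
            img.convert("RGB").resize((size, size), Image.BICUBIC).save(dst_path)
        count += 1
    return count
```

For example, `resize_dataset("artbench-256", "artbench-64")` would write a 64x64 copy of every image, where both folder names are hypothetical placeholders.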
python3 train.py -c config_path -w True
Generated samples during training | Final generated samples |
---|---|
Reconstructed images during training (real in the second row) | Generated samples during training | Final generated samples |
---|---|---|
CVAE has the potential to train on ArtBench-10, but 12 hours on Kaggle's free resources are not enough for such a complex dataset.
- report-ru.pdf (Russian only).
- GAN wandb project.
- CVAE wandb project.
- Timofei Gritsaev
The DCGAN implementation was taken from the PyTorch DCGAN tutorial.
The DDPM implementation was taken from the DL-2 in HSE homework, which is based on openai/guided-diffusion. The UNet implementation was taken from Pytorch-UNet.