minREV

Inspired by minGPT

A PyTorch reimplementation of the Reversible Vision Transformer architecture that prefers simplicity over tricks, hackability over tedious organization, and interpretability over generality.

It is meant to serve as an educational guide for newcomers who are not familiar with the reversible backpropagation algorithm and the reversible vision transformer.

The entire Reversible Vision Transformer is implemented from scratch in under 300 lines of PyTorch code, including the memory-efficient reversible backpropagation algorithm (under 100 lines). Even the driver code is under 150 lines. The repo supports both memory-efficient training and testing on CIFAR-10.
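To give a feel for why reversible backpropagation is called memory-efficient, here is a deliberately coarse sketch that counts how many activation elements must be kept for the backward pass. It is not taken from the repo. It assumes vanilla backpropagation keeps at least the input of every block (real blocks cache more, so this understates the vanilla cost), and that a reversible model keeps only its final two-stream output. The function name and the token count of 64 are hypothetical; the batch size and embedding dimension come from the ViT command in this README.

```python
def cached_elements(batch, tokens, dim, depth, reversible):
    """Coarse count of activation elements held for the backward pass.

    Vanilla: at least one block input is kept per block, so the count
    grows linearly with depth (a lower bound, real blocks cache more).
    Reversible: only the final output of the two streams is kept, so the
    count does not depend on depth.
    """
    per_block = batch * tokens * dim
    return 2 * per_block if reversible else depth * per_block


# batch=128 and dim=128 follow the README's ViT command; tokens=64 is an assumption.
for depth in (6, 12):
    vanilla = cached_elements(128, 64, 128, depth, reversible=False)
    rev = cached_elements(128, 64, 128, depth, reversible=True)
    print(f"depth={depth}: vanilla>={vanilla:,} reversible={rev:,}")
# depth=6: vanilla>=6,291,456 reversible=2,097,152
# depth=12: vanilla>=12,582,912 reversible=2,097,152
```

Doubling the depth doubles the vanilla count while the reversible count stays flat, which is the property the rest of this README relies on.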

💥 The CVPR 2022 oral talk for a 5-minute introduction to RevViT.

💥 A gentle and in-depth 15 minute introduction to RevViT.

💥 Additional implementations of reversible MViT and Swin for examples of hierarchical transformers.

💥 New implementations of fast, parallelized reversible backpropagation (PaReprop), featured as a spotlight at the workshop on Transformers for Vision @ CVPR 2023.

Setting Up

Simple! 🌟

(If you use conda for the environment; otherwise use pip.)

conda create -n revvit python=3.8
conda activate revvit
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia

If you wish to also use RevSwin and RevMViT, also install timm.

conda install timm=0.9.2
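After installing, a quick check that the packages are visible to the interpreter can save a confusing failure later. The sketch below is not part of the repo; `check_packages` is a hypothetical helper that only uses the standard library, so it runs even when a package is missing.

```python
import importlib.util


def check_packages(names=("torch", "torchvision", "timm")):
    """Return {package_name: bool} telling whether each package can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

timm is only needed for RevSwin and RevMViT, so a missing timm is fine if you only plan to run the ViT model.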

Code Organization

The code organization is also minimal 💫:

  • rev.py defines the reversible vision model that supports:
    • The vanilla backpropagation
    • The memory-efficient reversible backpropagation
  • main.py has the driver code for training on CIFAR-10. By default, --model vit is enabled.
  • fast_rev.py contains a simplified implementation of fast, parallelized reversible backpropagation (PaReprop). Use --pareprop True to enable.
  • rev_swin.py contains the reversible Swin Transformer with PaReprop functionality. Use --model swin to enable.
  • rev_mvit.py and utils.py contain the reversible MViTv2 with PaReprop functionality. Use --model mvit to enable.
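The idea behind the reversible model can be shown without any framework. The sketch below is not the code in rev.py; it is a minimal, pure-Python illustration of the two-stream coupling used by reversible transformers, where `f` and `g` are hypothetical stand-ins for the attention and MLP sub-blocks. The point is that the block's inputs can be recomputed exactly from its outputs, so they do not need to be stored for the backward pass.

```python
import math


def f(x):
    # Hypothetical stand-in for the attention sub-block.
    return [math.tanh(v) for v in x]


def g(x):
    # Hypothetical stand-in for the MLP sub-block.
    return [0.5 * v * v for v in x]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def sub(a, b):
    return [p - q for p, q in zip(a, b)]


def forward(x1, x2):
    """Two-stream coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = add(x1, f(x2))
    y2 = add(x2, g(y1))
    return y1, y2


def inverse(y1, y2):
    """Recover the inputs from the outputs by undoing the two steps in reverse order."""
    x2 = sub(y2, g(y1))
    x1 = sub(y1, f(x2))
    return x1, x2


x1, x2 = [0.1, -0.2, 0.3], [1.0, 2.0, -3.0]
y1, y2 = forward(x1, x2)
rx1, rx2 = inverse(y1, y2)
max_err = max(abs(a - b) for a, b in zip(x1 + x2, rx1 + rx2))
print(f"max reconstruction error: {max_err:.2e}")
```

Note that `f` and `g` themselves do not need to be invertible; only the additive coupling is inverted. During reversible backpropagation the same reconstruction is performed block by block, from the last block to the first, and gradients are computed on the recomputed activations.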

Running CIFAR-10

We currently provide three model options for reversible training: ViT, Swin, and MViT. The architectures in some cases have been simplified to run on CIFAR-10.

Reversible ViT 🍦

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit

By default, the --model flag is set to vit. This will achieve 80%+ validation accuracy on CIFAR-10 from scratch training!

Here are the Training/Validation Logs 💯

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit --vanilla_bp True

This trains the same network without memory-efficient backpropagation and reaches the same accuracy as above. Hence, there is no accuracy drop from the memory-efficient reversible backpropagation.
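The commands in this README pass booleans as strings (for example --vanilla_bp True). If you write your own driver in the same style, note that argparse's `type=bool` treats any non-empty string as true, so "False" would silently enable the flag. The sketch below shows one way to handle this. It is not the parser in main.py: `str2bool` and `build_parser` are hypothetical names, and every default except --model vit is an assumption copied from the ViT command above.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False' style strings into real booleans."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical sketch mirroring the flags used in this README.
    parser = argparse.ArgumentParser(description="Sketch of a CIFAR-10 driver CLI")
    parser.add_argument("--lr", type=float, default=1e-3)      # assumed default
    parser.add_argument("--bs", type=int, default=128)         # assumed default
    parser.add_argument("--embed_dim", type=int, default=128)  # assumed default
    parser.add_argument("--depth", type=int, default=6)        # assumed default
    parser.add_argument("--n_head", type=int, default=8)       # assumed default
    parser.add_argument("--epochs", type=int, default=100)     # assumed default
    parser.add_argument("--model", choices=["vit", "swin", "mvit"], default="vit")
    parser.add_argument("--vanilla_bp", type=str2bool, default=False)
    parser.add_argument("--amp", type=str2bool, default=False)
    parser.add_argument("--pareprop", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(["--model", "vit", "--vanilla_bp", "True"])
print(args.model, args.vanilla_bp, args.amp)  # vit True False
```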

Here are the Training/Validation Logs 💯

Reversible Swin 🐬

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 4 --n_head 4 --epochs 100 --model swin

This will achieve 80%+ validation accuracy on CIFAR-10 from scratch training for a Reversible Swin!

Reversible MViT 🏰

python main.py --lr 1e-3 --bs 128 --embed_dim 64 --depth 4 --n_head 1 --epochs 100 --model mvit

This will achieve 90%+ validation accuracy on CIFAR-10 from scratch training for a Reversible MViT!

You can find Training/Validation Logs for both Swin and MViT here 💯

👁️ Note: The relatively low accuracy is due to the difficulty of training vision transformers (reversible or vanilla) from scratch on small datasets like CIFAR-10. It is also likely that a much higher accuracy can be achieved with the same code by using a better-chosen model design and optimization parameters. The authors have done no tuning, since this repository is meant for understanding the code, not pushing performance.

Mixed precision training

Mixed precision training is also supported and can be enabled by adding the --amp True flag to the above commands. Training progresses smoothly and achieves 80%+ validation accuracy on CIFAR-10, similar to training without AMP.

📝 Note: PyTorch's vanilla AMP maintains full precision (fp32) for weights and only uses half precision (fp16) for intermediate activations. Since reversible training already avoids storing almost all intermediate activations (see the video for an explanation), using AMP (i.e., half precision on activations) brings little additional memory savings. For example, on a 16 GB V100 setup, AMP increases the maximum reversible CIFAR-10 batch size from 12000 to 14500 (~20%). At the usual training batch size (128), there is only a small gain in GPU training memory (about 4%).
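The ~20% figure follows directly from the two batch sizes quoted above. A minimal check, with `relative_gain` as a hypothetical helper name:

```python
def relative_gain(before, after):
    """Fractional increase going from `before` to `after`."""
    return (after - before) / before


gain = relative_gain(12000, 14500)
print(f"{gain:.1%}")  # 20.8%
```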

Distributed Data Parallel Training

DDP training adds no overhead specific to reversible models; it proceeds the same way as vanilla training. All results in the paper (also see below) were obtained in DDP setups (>64 GPUs per run). However, implementing distributed training is outside the scope of this repo; it can instead be found in the PySlowFast distributed training setup.

Running ImageNet, Kinetics-400 and more

For more use cases, such as reproducing numbers from the original paper, see the full code in PySlowFast, which supports

  • ImageNet
  • Kinetics-400/600/700
  • RevViT, all sizes with configs
  • RevMViT, all sizes with configs

to state-of-the-art accuracies.