Inspired by minGPT
A PyTorch reimplementation of the Reversible Vision Transformer architecture that prefers simplicity over tricks, hackability over tedious organization, and interpretability over generality.
It is meant to serve as an educational guide for newcomers who are not familiar with the reversible backpropagation algorithm and the Reversible Vision Transformer.
The entire Reversible Vision Transformer is implemented from scratch in under 300 lines of PyTorch code, including the memory-efficient reversible backpropagation algorithm (under 100 lines). Even the driver code is under 150 lines. The repo supports both memory-efficient training and testing on CIFAR-10.
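The reversible backpropagation mentioned above relies on every block being invertible, so that block inputs can be recomputed from block outputs instead of being stored. A minimal, framework-free sketch of the two-stream coupling follows. The toy functions F and G are hypothetical stand-ins for the attention and MLP sub-blocks, and plain Python lists replace tensors, so treat this as an illustration under those assumptions rather than the code in rev.py.

```python
import math

def F(x):
    # Hypothetical stand-in for the attention sub-block.
    return [math.tanh(0.5 * v) for v in x]

def G(x):
    # Hypothetical stand-in for the MLP sub-block.
    return [math.sin(v) for v in x]

def rev_forward(x1, x2):
    """Two-stream coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, F(x2))]
    y2 = [a + b for a, b in zip(x2, G(y1))]
    return y1, y2

def rev_inverse(y1, y2):
    """Recover the inputs from the outputs alone, with no stored activations."""
    x2 = [a - b for a, b in zip(y2, G(y1))]
    x1 = [a - b for a, b in zip(y1, F(x2))]
    return x1, x2

x1, x2 = [0.1, -0.4, 2.0], [1.5, 0.3, -0.7]
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)

# The reconstruction matches the original inputs up to floating-point rounding.
max_err = max(abs(a - b) for a, b in zip(x1 + x2, r1 + r2))
print(f"max reconstruction error: {max_err:.2e}")
```

Note that F and G themselves never need to be inverted; only the additive coupling is undone, which is why arbitrary sub-blocks can be used.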
💥 See the CVPR 2022 oral talk for a 5-minute introduction to RevViT.
💥 A gentle and in-depth 15-minute introduction to RevViT.
💥 Additional implementations of reversible MViT and Swin as examples of hierarchical transformers.
💥 New implementations of fast, parallelized reversible backpropagation (PaReprop), featured as a spotlight at the workshop on Transformers for Vision @ CVPR 2023.
Simple! 🌟
(if using conda for env, otherwise use pip)
conda create -n revvit python=3.8
conda activate revvit
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia
If you also wish to use RevSwin and RevMViT, install timm as well:
conda install timm=0.9.2
The code organization is also minimal 💫:
- rev.py defines the reversible vision model, which supports:
  - The vanilla backpropagation
  - The memory-efficient reversible backpropagation
- main.py has the driver code for training on CIFAR-10. By default, --model vit is enabled.
- fast_rev.py contains a simplified implementation of fast, parallelized reversible backpropagation (PaReprop). Use --pareprop True to enable.
- rev_swin.py contains the reversible Swin Transformer with PaReprop functionality. Use --model swin to enable.
- rev_mvit.py and utils.py contain the reversible MViTv2 with PaReprop functionality. Use --model mvit to enable.
We currently provide three model options for reversible training: ViT, Swin, and MViT. The architectures in some cases have been simplified to run on CIFAR-10.
python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit
By default, the --model flag is set to vit. This will achieve 80%+ validation accuracy on CIFAR-10 when training from scratch!
Here are the Training/Validation Logs 💯
python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit --vanilla_bp True
This trains the same network without memory-efficient backpropagation and reaches the same accuracy as above. Hence, there is no accuracy drop from the memory-efficient reversible backpropagation.
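The matching accuracy is expected because both modes compute the same gradients; they differ only in where the activations come from. The sketch below illustrates this on a scalar toy model. The tanh sub-functions, the weights, and the quadratic loss are all hypothetical choices for illustration, and the manual chain rule stands in for autograd, so this is not the implementation in rev.py.

```python
import math

def block_forward(x1, x2, wf, wg):
    y1 = x1 + math.tanh(wf * x2)
    y2 = x2 + math.tanh(wg * y1)
    return y1, y2

def block_inverse(y1, y2, wf, wg):
    x2 = y2 - math.tanh(wg * y1)
    x1 = y1 - math.tanh(wf * x2)
    return x1, x2

def block_backward(x2, y1, dy1, dy2, wf, wg):
    """Chain rule through one block; only x2 and y1 are needed as activations."""
    sg = 1.0 - math.tanh(wg * y1) ** 2
    dy1 = dy1 + dy2 * wg * sg          # total gradient reaching y1
    dwg = dy2 * y1 * sg
    sf = 1.0 - math.tanh(wf * x2) ** 2
    dx2 = dy2 + dy1 * wf * sf
    dwf = dy1 * x2 * sf
    return dy1, dx2, dwf, dwg          # the gradient for x1 equals dy1

def grads(x1, x2, weights, reversible):
    cache = []
    for wf, wg in weights:
        x2_in = x2
        x1, x2 = block_forward(x1, x2, wf, wg)
        if not reversible:
            cache.append((x2_in, x1))  # vanilla: store activations per block
    d1, d2 = x1, x2                    # loss = 0.5 * (y1**2 + y2**2)
    out = []
    for wf, wg in reversed(weights):
        if reversible:
            y1 = x1                    # reversible: recompute inputs from outputs
            x1, x2 = block_inverse(x1, x2, wf, wg)
            x2_in = x2
        else:
            x2_in, y1 = cache.pop()
        d1, d2, dwf, dwg = block_backward(x2_in, y1, d1, d2, wf, wg)
        out.append((dwf, dwg))
    return out[::-1]

weights = [(0.7, -0.3), (0.2, 0.9), (-0.5, 0.4)]
g_van = grads(0.3, -1.2, weights, reversible=False)
g_rev = grads(0.3, -1.2, weights, reversible=True)
max_diff = max(abs(a - b) for gv, gr in zip(g_van, g_rev) for a, b in zip(gv, gr))
print(f"max gradient difference: {max_diff:.2e}")
```

The vanilla path grows its cache with depth, while the reversible path keeps only the final outputs and pays for it with one extra inverse pass per block.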
Here are the Training/Validation Logs 💯
python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 4 --n_head 4 --epochs 100 --model swin
This will achieve 80%+ validation accuracy on CIFAR-10 when training a Reversible Swin from scratch!
python main.py --lr 1e-3 --bs 128 --embed_dim 64 --depth 4 --n_head 1 --epochs 100 --model mvit
This will achieve 90%+ validation accuracy on CIFAR-10 when training a Reversible MViT from scratch!
You can find Training/Validation Logs for both Swin and MViT here 💯
👁️ Note: The relatively low accuracy is due to the difficulty of training vision transformers (reversible or vanilla) from scratch on small datasets like CIFAR-10. It is also likely that a much higher accuracy can be achieved with the same code by using a better-chosen model design and optimization parameters. The authors have done no tuning, since this repository is meant for understanding code, not pushing performance.
Mixed precision training is also supported and can be enabled by adding the --amp True flag to the above commands. Training progresses smoothly and achieves 80%+ validation accuracy on CIFAR-10, similar to training without AMP.
📝 Note: PyTorch vanilla AMP maintains full precision (fp32) on weights and only uses half precision (fp16) on intermediate activations. Since reversible training already avoids storing almost all intermediate activations (see the video for an explanation), using AMP (i.e., half precision on activations) brings little additional memory savings. For example, on a 16G V100 setup, AMP raises the maximum CIFAR-10 batch size for reversible training from 12000 to 14500 (~20%). At the usual training batch size (128), there is a small gain in GPU training memory (about 4%).
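As a quick check of the ~20% figure, the relative gain can be worked out from the two batch sizes quoted in the note. The helper name relative_gain is a hypothetical one introduced only for this sketch.

```python
def relative_gain(before, after):
    """Fractional increase when going from `before` to `after`."""
    return (after - before) / before

# Maximum CIFAR-10 batch sizes on a 16G V100, as quoted in the note.
gain = relative_gain(12000, 14500)
print(f"AMP raises the maximum batch size by {gain:.1%}")  # prints 20.8%
```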
Reversible training adds no overhead under DDP and progresses the same as vanilla training. All results in the paper (also see below) are obtained in DDP setups (>64 GPUs per run). However, implementing distributed training is beyond the purpose of this repo; it can instead be found in the PySlowFast distributed training setup.
For more use cases, such as reproducing numbers from the original paper, see the full code in PySlowFast, which supports training to state-of-the-art accuracies with:
- ImageNet
- Kinetics-400/600/700
- RevViT, all sizes with configs
- RevMViT, all sizes with configs