
binary-wide-resnet

public
124 stars
15 forks
2 issues

Commits

List of commits on branch master.
Unverified
c6a76d6fdd22e69d5adec82c6f5790d63cefaee1

fix typo

sszagoruyko committed 6 years ago
Verified
e0dd2982a51ec2f14a746ab57150de82b235ac2a

Update README.md

sszagoruyko committed 6 years ago
Unverified
02efd3d62a95f578433f250e00165f6d7fa15c13

use newer autograd function api

sszagoruyko committed 6 years ago
Verified
5d1b6d6356f8c7a3a5c12dd62d22b097d9fbdc4e

Update README.md

sszagoruyko committed 6 years ago
Unverified
6b0a8f2de3da117258b0bcbd9210a8935e3dea84

add files

sszagoruyko committed 6 years ago
Unverified
17b8653943999e87714b71416933f4a16402c21e

add model

sszagoruyko committed 6 years ago

README


1-bit Wide ResNet

PyTorch implementation of training 1-bit Wide ResNets from this paper:

Training wide residual networks for deployment using a single bit for each weight by Mark D. McDonnell at ICLR 2018

https://openreview.net/forum?id=rytNfI1AZ

https://arxiv.org/abs/1802.08530

The idea is very simple but surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization as a PyTorch autograd function:

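import math
import torch

class ForwardSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # w: conv weight (out, in, kh, kw); scale sign(w) by sqrt(2 / fan_in)
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        # straight-through: pass the gradient to w unchanged
        return g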

On the forward pass, we take the sign of the weights and scale it by the He-init constant sqrt(2 / fan_in). On the backward pass, we propagate the gradient unchanged. WRN-20-10 trained with this parameterization is only slightly behind its full-precision variant. Here is what I got with this code on CIFAR-100:

network          accuracy, % (mean ± std over 5 runs)   checkpoint size
WRN-20-10        80.5 ± 0.24                             205 MB
WRN-20-10-1bit   80.0 ± 0.26                             3.5 MB
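
As a rough illustration of how the parameterization above might be plugged into a layer, the sketch below applies ForwardSign to a full-precision weight tensor before a convolution. The tensor shapes and the toy loss are assumptions made for the example and are not taken from the repository's main.py.

```python
import math
import torch
import torch.nn.functional as F


class ForwardSign(torch.autograd.Function):
    """Same parameterization as in the README, repeated so the sketch is self-contained."""

    @staticmethod
    def forward(ctx, w):
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        return g


torch.manual_seed(0)
# Hypothetical shapes: 8 output channels, 4 input channels, 3x3 kernel.
w = torch.randn(8, 4, 3, 3, requires_grad=True)  # full-precision weights kept during training
x = torch.randn(2, 4, 16, 16)

w_bin = ForwardSign.apply(w)          # every entry is +scale or -scale
y = F.conv2d(x, w_bin, padding=1)
y.pow(2).mean().backward()            # toy loss, only to produce a gradient

scale = math.sqrt(2. / (4 * 3 * 3))   # He-init constant for fan_in = 4 * 3 * 3
```

The optimizer keeps updating the full-precision w through the pass-through gradient, while the convolution only ever sees two distinct weight values. Because the scale follows from the layer shape, only the signs need to be stored for deployment.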

Details

Here are the differences from the original WRN code (https://github.com/szagoruyko/wide-residual-networks):

  • BatchNorm has no affine weight and bias parameters
  • First layer has 16 * width channels
  • The last fc layer is replaced by a 1x1 conv + F.avg_pool2d
  • Downsampling is done with F.avg_pool2d + torch.cat instead of a strided conv
  • SGD with cosine annealing and warm restarts
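
The last item in the list refers to the schedule often called SGDR. A minimal sketch of that learning-rate rule follows; the function name sgdr_lr and the values of lr_max, lr_min, t0, and t_mult are illustrative assumptions, not the settings used in main.py.

```python
import math


def sgdr_lr(epoch, lr_max=0.1, lr_min=0.0, t0=2, t_mult=2):
    """Learning rate at `epoch` for cosine annealing with warm restarts.

    The first cycle lasts `t0` epochs and each following cycle is
    `t_mult` times longer. All default values are hypothetical.
    """
    cycle_len, start = t0, 0
    # Find the cycle that contains this epoch.
    while epoch >= start + cycle_len:
        start += cycle_len
        cycle_len *= t_mult
    progress = (epoch - start) / cycle_len  # fraction of the cycle elapsed, in [0, 1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


# With these defaults the cycles cover epochs [0, 2), [2, 6), [6, 14).
schedule = [sgdr_lr(e) for e in range(14)]
```

Within each cycle the rate decays along a half cosine from lr_max toward lr_min, then jumps back to lr_max at the start of the next cycle.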

I used PyTorch 0.4.1 and Python 3.6 to run the code.

Reproduce WRN-20-10 with 1-bit training on CIFAR-100:

python main.py --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100

Convergence plot (train error dashed):

[convergence plot image]

I've also uploaded a 3.5 MB checkpoint with the binary weights packed with np.packbits, along with a very short script to evaluate it:

python evaluate_packed.py --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100

S3 URL of the checkpoint: https://s3.amazonaws.com/modelzoo-networks/wrn20-10-1bit-packed.pth.tar
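
To illustrate the packing mentioned above, here is a small sketch of how sign weights can be stored at one bit each with np.packbits and restored with np.unpackbits. The array shape and variable names are made up for the example, and the actual layout used in the published checkpoint and in evaluate_packed.py may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical conv weight of shape (out, in, kh, kw) holding only +1 / -1.
signs = np.where(rng.standard_normal((16, 8, 3, 3)) >= 0, 1.0, -1.0).astype(np.float32)

# Pack: map -1 -> 0 and +1 -> 1, then store 8 weights per byte.
bits = (signs.ravel() > 0).astype(np.uint8)
packed = np.packbits(bits)

# Unpack: restore the bits, drop any padding, and map back to -1 / +1.
restored_bits = np.unpackbits(packed)[:signs.size]
restored = (restored_bits.astype(np.float32) * 2 - 1).reshape(signs.shape)

print(signs.nbytes, "bytes as float32 ->", packed.nbytes, "bytes packed")
```

Since float32 uses 32 bits per weight, packing to one bit per weight shrinks each weight tensor by a factor of 32. The original shape has to be stored alongside the packed bytes so the tensor can be reshaped after unpacking.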