
pytorch-checkpoint

public · 3 stars · 2 forks · 0 issues

Commits

List of commits on branch master (all committed by bomri, 6 years ago):

4372e7034f6f97aeb13a5d1c5f016afbb927ba49 (unverified): Merge branch 'master' of github.com:bomri/pytorch-checkpoint
436ba2dcdd9dd31c2e526deb54d230813338f5ba (unverified): adding the medium blog post to README
5a346c05a91b39c74aadc7be57eaa2a898701206 (verified): update readme for optimizer saving
dc64b58fa72c614246a83b32cdc75780f2c44fd8 (unverified): converting a DataParallel model to be able load later without DataParallel
41cdaea208ad182bd6127099e618e210ee369ae6 (unverified): storing and getting a general value
019223145becca028eb87d8f82c53acb4de0fa19 (unverified): map model to cpu when loading

README


pytorch-checkpoint

This package supports saving and loading PyTorch training checkpoints. It is useful when trying to resume model training from a previous step, and comes in handy when working with spot instances or when trying to reproduce results.

A model is saved not only with its weights, as one might do for later inference, but with the entire training state, including the optimizer state and its parameters.

In addition, it allows saving metrics and other values generated during training, such as accuracy and loss values. This makes it possible to recreate the learning curves from past values and to continue updating them as training proceeds.

See the accompanying blog post: Where did I put my loss values?


Prerequisites

Developed with Python 3.7.3, but should be compatible with earlier Python versions.

pip install torch==1.1.0 torchvision==0.3.0

Installation

pip install pytorchcheckpoint

Usage

from pytorchcheckpoint.checkpoint import CheckpointHandler
checkpoint_handler = CheckpointHandler()

Storing a general value

checkpoint_handler.store_var(var_name='num_of_classes', value=1000)

Reading a general value

num_of_classes = checkpoint_handler.get_var(var_name='num_of_classes')

Storing values and metrics for each epoch/iteration. For example, the loss value:

checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
checkpoint_handler.store_running_var(var_name='loss', iteration=1, value=0.9)
checkpoint_handler.store_running_var(var_name='loss', iteration=2, value=0.8)

Reading stored values for a given epoch/iteration

loss = checkpoint_handler.get_running_var(var_name='loss', iteration=0)

Storing values and metrics per set (train/valid/test) for each epoch/iteration. For example, the top1 value of the train and valid sets:

checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=0, value=80)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=1, value=85)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=2, value=90)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=3, value=91)

checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=0, value=70)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=1, value=75)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=2, value=80)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=3, value=85)

Reading stored values per set (train/valid/test) for a given epoch/iteration. For example, the top1 value of the train set:

top1 = checkpoint_handler.get_running_var_with_header(header='train', var_name='top1', iteration=0)
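
These stored values can be read back to recreate the learning curves mentioned above. A minimal sketch, assuming the train/valid top1 values stored in the previous example; matplotlib is not part of this package and is used here only for illustration:

import matplotlib.pyplot as plt

iterations = list(range(4))  # the iterations stored in the example above
train_top1 = [checkpoint_handler.get_running_var_with_header(header='train', var_name='top1', iteration=i) for i in iterations]
valid_top1 = [checkpoint_handler.get_running_var_with_header(header='valid', var_name='top1', iteration=i) for i in iterations]

plt.plot(iterations, train_top1, label='train')  # train curve
plt.plot(iterations, valid_top1, label='valid')  # valid curve
plt.xlabel('iteration')
plt.ylabel('top1')
plt.legend()
plt.show()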

Save checkpoint:

import torchvision.models as models
from torch import optim
checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
model = models.resnet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# attach the optimizer so its state is saved as part of the checkpoint
checkpoint_handler.optimizer = optimizer
path2save = '/tmp'
checkpoint_path = checkpoint_handler.generate_checkpoint_path(path2save=path2save)
checkpoint_handler.save_checkpoint(checkpoint_path=checkpoint_path, iteration=25, model=model)
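
In a real training run, a checkpoint would typically be saved periodically. A minimal sketch, assuming the handler, model, optimizer, and checkpoint_path from the example above; train_one_epoch is a hypothetical stand-in for the actual training step:

for epoch in range(30):  # hypothetical number of epochs
    loss = train_one_epoch(model, optimizer)  # hypothetical: runs one epoch and returns its loss
    # record the loss so the learning curve survives restarts
    checkpoint_handler.store_running_var(var_name='loss', iteration=epoch, value=loss)
    # overwrite the checkpoint with the latest training state
    checkpoint_handler.save_checkpoint(checkpoint_path=checkpoint_path, iteration=epoch, model=model)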

Load checkpoint:

checkpoint_path = '<checkpoint_path>'
checkpoint_handler = checkpoint_handler.load_checkpoint(checkpoint_path)
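
After loading, values stored before the save are available again through the same getters, so metrics can be inspected and learning curves extended from where training stopped. For example, using the values stored earlier:

num_of_classes = checkpoint_handler.get_var(var_name='num_of_classes')
loss = checkpoint_handler.get_running_var(var_name='loss', iteration=0)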