Simulation-Free Training of Neural ODEs on Paired Data

This repository contains the official implementation of the paper "Simulation-Free Training of Neural ODEs on Paired Data" (NeurIPS 2024).

Semin Kim*, Jaehoon Yoo*, Jinwoo Kim, Yeonwoo Cha, Saehoon Kim, Seunghoon Hong

Paper Link: https://arxiv.org/abs/2410.22918

Setup

To set up the environment, start by installing dependencies listed in requirements.txt. You can also use Docker to streamline the setup process.

  1. Docker Setup:
docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
docker run -it pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash
  2. Clone the Repository:
git clone https://github.com/seminkim/simulation-free-node.git
  3. Install Requirements:
pip install -r requirements.txt

Datasets

Place all datasets in the .data directory. By default, this code automatically downloads the MNIST, CIFAR-10, and SVHN datasets into the .data directory.

The UCI dataset, composed of 10 tasks (bostonHousing, concrete, energy, kin8nm, naval-propulsion-plant, power-plant, protein-tertiary-structure, wine-quality-red, yacht, and YearPredictionMSD), can be manually downloaded from the Usage part of the following repository: CARD.

Training

Scripts for training are available for both classification and regression tasks.

Classification

To train a model for a classification task, run:

python main.py fit --config configs/{dataset_name}.yaml --name {exp_name}

Regression

For regression tasks (only supported with UCI datasets), use the following command:

python main.py fit --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num}

In this command, specify the UCI task name and the data split number accordingly.
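To cover several UCI tasks and splits, the command above can be generated in a loop. The following is a minimal sketch, not part of the repository: the helper name build_fit_commands, the experiment naming scheme, and the split range are assumptions made for illustration. Only the task names and the CLI flags are taken from this README.

```python
import shlex

# Task names as listed in the Datasets section of this README.
UCI_TASKS = [
    "bostonHousing", "concrete", "energy", "kin8nm",
    "naval-propulsion-plant", "power-plant",
    "protein-tertiary-structure", "wine-quality-red",
    "yacht", "YearPredictionMSD",
]


def build_fit_commands(tasks, split_nums, prefix="uci"):
    """Return one `main.py fit` argument list per (task, split) pair.

    `prefix` and the resulting experiment names are illustrative
    assumptions; use whatever naming scheme suits your runs.
    """
    commands = []
    for task in tasks:
        for split in split_nums:
            commands.append([
                "python", "main.py", "fit",
                "--config", "configs/uci.yaml",
                "--name", f"{prefix}_{task}_s{split}",
                "--data.task", task,
                "--data.split_num", str(split),
            ])
    return commands


if __name__ == "__main__":
    # range(2) is a placeholder; check which split numbers your data provides.
    # The commands are only printed here; pass each list to
    # subprocess.run(cmd, check=True) to actually launch training.
    for cmd in build_fit_commands(UCI_TASKS, split_nums=range(2)):
        print(shlex.join(cmd))
```

Building argument lists rather than shell strings avoids quoting problems if an experiment name ever contains spaces.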

Inference

Use the following commands for model evaluation.

Classification

python main.py validate --config configs/{dataset_name}.yaml --name {exp_name} --ckpt_path {ckpt_path}

Regression

For UCI regression tasks:

python main.py validate --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} --ckpt_path {ckpt_path}

Checkpoints

Trained checkpoints are available in the Releases tab of this repository.

| Dataset | Dopri Acc. | Link     |
|---------|------------|----------|
| MNIST   | 99.30%     | Download |
| SVHN    | 96.12%     | Download |
| CIFAR10 | 88.89%     | Download |

Additional Notes

Logging

We use wandb to monitor training progress and inference results. The wandb run name will match the argument provided for --name. You can also change the project name by modifying trainer.logger.init_args.project in the configuration file (default value is SFNO_exp).

Running Your Own Experiment

Our code is implemented with LightningCLI, so you can override any config value via command-line arguments to experiment with various settings.
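Dotted arguments such as --data.batch_size address nested keys of the YAML config. The sketch below only illustrates that mapping on a plain dictionary; it is not how LightningCLI is implemented, and the base dictionary is a hypothetical fragment (the batch size of 64 is made up, while SFNO_exp is the default project name mentioned in the Logging section).

```python
import copy


def apply_override(config, dotted_key, value):
    """Return a copy of `config` with the nested `dotted_key` set to `value`."""
    updated = copy.deepcopy(config)
    node = updated
    *parents, leaf = dotted_key.split(".")
    for key in parents:
        # Walk down the hierarchy, creating missing levels as empty dicts.
        node = node.setdefault(key, {})
    node[leaf] = value
    return updated


# Hypothetical config fragment; the real config files may look different.
base = {
    "data": {"batch_size": 64},
    "trainer": {"logger": {"init_args": {"project": "SFNO_exp"}}},
}

cfg = apply_override(base, "data.batch_size", 128)
cfg = apply_override(cfg, "trainer.logger.init_args.project", "my_project")

print(cfg["data"]["batch_size"])                          # 128
print(cfg["trainer"]["logger"]["init_args"]["project"])   # my_project
```

The original dictionary is left untouched, which mirrors the fact that command-line overrides do not modify the config file on disk.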

Examples:

# Run MNIST experiment with batch size 128
python main.py fit --config configs/mnist.yaml --name mnist_b128 --data.batch_size 128

# Run SVHN experiment with explicit sampling of t=0 with probability 0.01
python main.py fit --config configs/svhn.yaml --name svhn_zero_001 --model.init_args.force_zero_prob 0.01

# Run CIFAR10 experiment with 'concave' dynamics 
python main.py fit --config configs/cifar10.yaml --name cifar10_concave --model.init_args.dynamics concave

Refer to Lightning Trainer documentation for controlling trainer-related configurations (e.g., training steps or logging frequency).
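The force_zero_prob option in the SVHN example above is described as explicitly sampling t=0 with a given probability. One way to read that description is sketched below in plain Python. This is an assumption about the intended behavior, not the repository's code: the function name sample_time, the uniform distribution for the remaining draws, and the per-sample decision are all illustrative choices.

```python
import random


def sample_time(force_zero_prob, rng=random):
    """Draw a time step t in [0, 1).

    With probability `force_zero_prob` the result is exactly 0.0;
    otherwise t is drawn uniformly (the uniform choice is an assumption).
    """
    if rng.random() < force_zero_prob:
        return 0.0
    return rng.random()


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed for a reproducible demo
    draws = [sample_time(0.01, rng) for _ in range(10_000)]
    zero_fraction = sum(t == 0.0 for t in draws) / len(draws)
    print(f"fraction of draws with t == 0: {zero_fraction:.4f}")
```

Under this reading, the printed fraction should be close to the configured probability, and setting the probability to 0 recovers plain uniform sampling.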

Acknowledgements

This implementation is based on the following repositories: NeuralODE, ANODE, and CARD.

Citation

@article{kim2024simfreenode,
         title={Simulation-Free Training of Neural ODEs on Paired Data}, 
         author={Semin Kim and
                 Jaehoon Yoo and
                 Jinwoo Kim and
                 Yeonwoo Cha and
                 Saehoon Kim and
                 Seunghoon Hong},
         journal={arXiv preprint arXiv:2410.22918},
         year={2024},
         url={https://arxiv.org/abs/2410.22918}, 
}