
Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals

This repository contains the official code of the NeurIPS 2020 paper Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals.

Installation

conda env create -f environment.yml

Experiment pipeline for FallingDigit

Notes

  • Our method is a two-stage framework. The first stage trains an RL teacher policy and uses it to collect a demonstration dataset; the second stage trains a self-supervised object detector and a GNN-based student policy on this demonstration dataset. Please follow the steps below in order.
  • Example configuration files are provided in the configs directory, and most experiment outputs are saved in the outputs directory.
  • Game environments are named FallingDigit${bg}_${n}-v1, where ${bg} is Black or CIFAR and ${n} is the number of target digits (from 3 to 9). We also provide test environments FallingDigit${bg}_${n}_test-v1, whose game levels differ from those of the training environments (each level is generated by a unique random seed). Because the environments come with different backgrounds, make sure you use environments with the same background throughout the whole experiment pipeline.
  • Since training a teacher policy with RL takes some time, we provide two trained teacher policies (one trained on FallingDigitBlack_3-v1 and one on FallingDigitCIFAR_3-v1). With them, you can skip RL training and start from the step of collecting the demonstration dataset.
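The naming scheme in the notes above can be captured in a small helper. The following is a minimal sketch, not part of the repository: `make_env_id`, `parse_env_id` and `check_same_background` are hypothetical names, and the only facts it encodes are the ones stated above (background Black or CIFAR, 3 to 9 target digits, an optional `_test` suffix, version `-v1`). It can be used to check that all environments used in one pipeline share a background.

```python
import re
from typing import Tuple

# Facts taken from the notes above; the helper names below are hypothetical.
BACKGROUNDS = ("Black", "CIFAR")
MIN_DIGITS, MAX_DIGITS = 3, 9

# Matches e.g. "FallingDigitCIFAR_3-v1" and "FallingDigitBlack_9_test-v1".
_ENV_RE = re.compile(r"^FallingDigit(Black|CIFAR)_(\d+)(_test)?-v1$")


def make_env_id(bg: str, n: int, test: bool = False) -> str:
    """Build FallingDigit${bg}_${n}-v1, or FallingDigit${bg}_${n}_test-v1 if test."""
    if bg not in BACKGROUNDS:
        raise ValueError(f"unknown background {bg!r}; expected one of {BACKGROUNDS}")
    if not MIN_DIGITS <= n <= MAX_DIGITS:
        raise ValueError(f"n must be in [{MIN_DIGITS}, {MAX_DIGITS}], got {n}")
    suffix = "_test" if test else ""
    return f"FallingDigit{bg}_{n}{suffix}-v1"


def parse_env_id(env_id: str) -> Tuple[str, int, bool]:
    """Inverse of make_env_id: return (background, number of digits, is_test)."""
    m = _ENV_RE.match(env_id)
    if m is None:
        raise ValueError(f"not a FallingDigit environment id: {env_id!r}")
    bg, n, is_test = m.group(1), int(m.group(2)), m.group(3) is not None
    if not MIN_DIGITS <= n <= MAX_DIGITS:
        raise ValueError(f"n must be in [{MIN_DIGITS}, {MAX_DIGITS}], got {n}")
    return bg, n, is_test


def check_same_background(*env_ids: str) -> str:
    """Raise ValueError if the ids mix backgrounds; otherwise return the shared one."""
    if not env_ids:
        raise ValueError("no environment ids given")
    backgrounds = {parse_env_id(e)[0] for e in env_ids}
    if len(backgrounds) != 1:
        raise ValueError(f"mixed backgrounds in one pipeline: {sorted(backgrounds)}")
    return backgrounds.pop()
```

For example, `check_same_background("FallingDigitCIFAR_3-v1", make_env_id("CIFAR", 9, test=True))` returns `"CIFAR"`, while mixing a Black and a CIFAR environment raises a `ValueError`.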

Train a teacher policy with DQN

python dqn/main.py --cfg configs/falling_digit_rl/dqn_relation_net.yml env FallingDigitCIFAR_3-v1

The checkpoint directory will look like ./outputs/falling_digit_rl/dqn_relation_net/11-09_22-16-09_FallingDigitCIFAR_3-v1.
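To find `${YOUR_RL_OUTPUT_DIR}` for the next step, a sketch like the following picks the most recent run directory for a given environment. It assumes, based only on the single example path above, that run directories are named `<MM-DD>_<HH-MM-SS>_<env id>`; `latest_run_dir` is a hypothetical helper, not a script shipped with the repository.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed run-directory naming, inferred from the example
# "11-09_22-16-09_FallingDigitCIFAR_3-v1": "<MM-DD>_<HH-MM-SS>_<env id>".
_RUN_RE = re.compile(r"^(\d{2}-\d{2}_\d{2}-\d{2}-\d{2})_(.+)$")


def latest_run_dir(root: str, env_id: str) -> Optional[Path]:
    """Return the most recent run directory under `root` for `env_id`, or None."""
    candidates = []
    for path in Path(root).iterdir():
        m = _RUN_RE.match(path.name)
        if m and path.is_dir() and m.group(2) == env_id:
            candidates.append((m.group(1), path))
    if not candidates:
        return None
    # Zero-padded "MM-DD_HH-MM-SS" strings sort chronologically,
    # but only within one calendar year (the year is not encoded).
    return max(candidates, key=lambda c: c[0])[1]
```

For example, `latest_run_dir("outputs/falling_digit_rl/dqn_relation_net", "FallingDigitCIFAR_3-v1")` returns a `Path` or `None`.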

Select a teacher checkpoint and collect a demonstration dataset

python tools/select_teacher_checkpoint.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt-dir ${YOUR_RL_OUTPUT_DIR}
python tools/collect_demo_dataset_for_falling_digit.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt ${THE_SELECTED_GOOD_TEACHER_CHECKPOINT_PATH}

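The collected demonstration dataset will be saved in the data directory. Here ${YOUR_RL_OUTPUT_DIR} is the checkpoint directory produced by the DQN training step, and ${THE_SELECTED_GOOD_TEACHER_CHECKPOINT_PATH} is the path of the teacher checkpoint you selected with the first command.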
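If you prefer to drive both commands from Python, the sketch below builds them as argument lists (so paths with spaces need no shell quoting) from a single environment id, which keeps teacher selection and data collection on the same environment. The script paths and flags are copied from the commands above; the wrapper functions `select_cmd`, `collect_cmd` and `run` are hypothetical and assume they are called from the repository root.

```python
import shlex
import subprocess
from typing import List

# Config path copied from the commands above.
EVAL_CFG = "configs/falling_digit_rl/dqn_relation_net_eval.yml"


def select_cmd(env_id: str, ckpt_dir: str, cfg: str = EVAL_CFG) -> List[str]:
    """Argument list for tools/select_teacher_checkpoint.py (hypothetical wrapper)."""
    return ["python", "tools/select_teacher_checkpoint.py",
            "--env", env_id, "--cfg", cfg, "--ckpt-dir", ckpt_dir]


def collect_cmd(env_id: str, ckpt: str, cfg: str = EVAL_CFG) -> List[str]:
    """Argument list for tools/collect_demo_dataset_for_falling_digit.py."""
    return ["python", "tools/collect_demo_dataset_for_falling_digit.py",
            "--env", env_id, "--cfg", cfg, "--ckpt", ckpt]


def run(cmd: List[str], dry_run: bool = True) -> None:
    """Print the command as a shell line; execute it only when dry_run is False."""
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
```

A typical session calls `run(select_cmd(env_id, ckpt_dir), dry_run=False)`, inspects the result, and then calls `run(collect_cmd(env_id, chosen_ckpt), dry_run=False)` with the same `env_id`.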

Train a self-supervised object detector and generate object proposals for the demonstration dataset

Paste the path of the collected demonstration dataset into configs/falling_digit_space/cifar_space_v1.yaml, specifically into both DATASET.TRAIN.path and DATASET.VAL.path: different splits of the same dataset are used as the training set and the validation set. Then run the following commands.

python space/train_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
python space/predict_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
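Editing the dataset paths can also be scripted. The sketch below assumes that the YAML file stores `DATASET.TRAIN.path` and `DATASET.VAL.path` as nested mappings (as the dotted names suggest) and that PyYAML is installed; `set_by_dotted_key` and `point_config_at_dataset` are hypothetical helpers. Rewriting the file with `yaml.safe_dump` discards any comments it contains.

```python
from typing import Any, Dict, Iterable


def set_by_dotted_key(cfg: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set cfg["A"]["B"]["c"] = value for "A.B.c", creating missing levels."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        child = node.get(key)
        if child is None:
            child = node[key] = {}
        elif not isinstance(child, dict):
            raise TypeError(f"{key!r} is not a mapping in the config")
        node = child
    node[leaf] = value


def point_config_at_dataset(
    cfg_path: str,
    dataset_path: str,
    keys: Iterable[str] = ("DATASET.TRAIN.path", "DATASET.VAL.path"),
) -> None:
    """Write dataset_path into every dotted key of the YAML file at cfg_path."""
    import yaml  # PyYAML, assumed to be installed in the conda environment

    with open(cfg_path) as f:
        cfg = yaml.safe_load(f) or {}
    for key in keys:
        set_by_dotted_key(cfg, key, dataset_path)
    with open(cfg_path, "w") as f:
        # sort_keys=False keeps the original key order; comments are lost.
        yaml.safe_dump(cfg, f, sort_keys=False)
```

For example, `point_config_at_dataset("configs/falling_digit_space/cifar_space_v1.yaml", dataset_path)`. The key used in configs/falling_digit_refactor/cifar_gnn.yaml is not named here, so check that file before passing its keys via `keys=`.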

Train a GNN-based student policy

Similarly, paste the path of the collected demonstration dataset into configs/falling_digit_refactor/cifar_gnn.yaml. Then run the following command.

python refactorization/train_gnn.py --cfg configs/falling_digit_refactor/cifar_gnn.yaml

Test a GNN-based student policy

python tools/eval_student_policy.py \
    --env FallingDigitCIFAR_9_test-v1 \
    --n-episode 100 \
    gnn \
    --detector-model SPACE_v1 \
    --detector-checkpoint ${YOUR_DETECTOR_CHECKPOINT_PATH} \
    --gnn-model EdgeConvNet \
    --gnn-checkpoint ${YOUR_GNN_POLICY_CHECKPOINT_PATH}

${YOUR_DETECTOR_CHECKPOINT_PATH} should look like outputs/falling_digit_space/cifar_space_v1/model_060000.pth, and ${YOUR_GNN_POLICY_CHECKPOINT_PATH} like outputs/falling_digit_refactor/cifar_gnn/model_best.pth. Note that a test environment FallingDigit${bg}_${n}_test-v1 should be used here, with the same background as in the previous steps.
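The example detector checkpoint is numbered by training iteration (model_060000.pth). Assuming that naming holds in your output directory, the sketch below picks the highest-numbered checkpoint; `latest_numbered_checkpoint` is a hypothetical helper. The latest checkpoint is not necessarily the best one, and files such as model_best.pth are skipped because they carry no iteration number.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming, inferred from the example "model_060000.pth".
_CKPT_RE = re.compile(r"^model_(\d+)\.pth$")


def latest_numbered_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the model_<iteration>.pth file with the largest iteration, or None."""
    best_iter, best_path = -1, None
    for path in Path(output_dir).glob("model_*.pth"):
        m = _CKPT_RE.match(path.name)
        # Compare iterations numerically, so unpadded names still order correctly.
        if m and int(m.group(1)) > best_iter:
            best_iter, best_path = int(m.group(1)), path
    return best_path
```

For example, the result of `latest_numbered_checkpoint("outputs/falling_digit_space/cifar_space_v1")` could be passed as `--detector-checkpoint`.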

Citation

If you find our paper useful in an academic setting, please cite:

@article{mu2020refactoring,
  title={Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals},
  author={Mu, Tongzhou and Gu, Jiayuan and Jia, Zhiwei and Tang, Hao and Su, Hao},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

Acknowledgments

The self-supervised object detector part of this implementation follows some details of Zhixuan Lin's original implementation. The reinforcement learning part is adapted from Shangtong Zhang's DeepRL code base.