This repository contains the official code of the NeurIPS 2020 paper *Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals*.
## Installation

```bash
conda env create -f environment.yml
```
## Instructions

- Our method is a two-stage framework: the first stage trains an RL teacher policy and collects a demonstration dataset, and the second stage trains a self-supervised object detector and a GNN-based student policy on this demonstration dataset. Please follow the steps below one by one.
- We provide example configuration files in the `configs` directory, and most experiment outputs will be saved in the `outputs` directory.
- Game environment names follow the format `FallingDigit${bg}_${n}-v1`, where `${bg}` can be `Black` or `CIFAR`, and `${n}` is the number of target digits (from 3 to 9). We also provide test environments `FallingDigit${bg}_${n}_test-v1`, which contain game levels different from those in the training environments (each level is generated by a unique random seed). Since the game environments differ in background, make sure you use environments with the same background throughout the whole experiment pipeline (see the naming sketch after this list).
- Since training a teacher policy by RL takes some time, we provide two trained teacher policies (one trained on `FallingDigitBlack_3-v1` and one trained on `FallingDigitCIFAR_3-v1`). With them, you can skip the RL training and start from the step of collecting the demonstration dataset.
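For illustration, here is a minimal sketch of how the naming scheme is used, assuming the environments are registered with OpenAI Gym (the exact registration entry point may differ; see the repository code):

```python
import gym
# NOTE: you may first need to import this repo's environment module so
# that the Gym registration code runs; see the repository for details.

# Training environment: CIFAR background, 3 target digits.
train_env = gym.make("FallingDigitCIFAR_3-v1")

# Held-out test environment with the same background but 9 digits;
# its levels are generated from different random seeds.
test_env = gym.make("FallingDigitCIFAR_9_test-v1")

obs = train_env.reset()
print(train_env.action_space, train_env.observation_space)
```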
To train an RL teacher policy, run:

```bash
python dqn/main.py --cfg configs/falling_digit_rl/dqn_relation_net.yml env FallingDigitCIFAR_3-v1
```

The checkpoint directory will be something like `./outputs/falling_digit_rl/dqn_relation_net/11-09_22-16-09_FallingDigitCIFAR_3-v1`.
Then select a good teacher checkpoint for collecting demonstrations, where `${YOUR_RL_OUTPUT_DIR}` is the checkpoint directory from the previous step:

```bash
python tools/select_teacher_checkpoint.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt-dir ${YOUR_RL_OUTPUT_DIR}
```
Collect the demonstration dataset with the selected teacher checkpoint:

```bash
python tools/collect_demo_dataset_for_falling_digit.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt ${THE_SELECTED_GOOD_TEACHER_CHECKPOINT_PATH}
```

The collected demonstration dataset will be saved in the `data` directory.
Paste the path of the collected demonstration dataset into `configs/falling_digit_space/cifar_space_v1.yaml`. Specifically, paste it into `DATASET.TRAIN.path` and `DATASET.VAL.path`; we use different splits of the same dataset as the training set and the validation set.
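The relevant fields should look roughly like the sketch below (the dataset path is a placeholder; use the path produced by the collection script):

```yaml
DATASET:
  TRAIN:
    path: "data/<your_collected_demo_dataset>"   # placeholder path
  VAL:
    path: "data/<your_collected_demo_dataset>"   # same dataset, different split
```

Then run the following commands to train the self-supervised object detector and run its prediction step: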
```bash
python space/train_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
python space/predict_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
```
Similarly, paste the path of the collected demonstration dataset into `configs/falling_digit_refactor/cifar_gnn.yaml`. Then run the following command to train the GNN-based student policy:

```bash
python refactorization/train_gnn.py --cfg configs/falling_digit_refactor/cifar_gnn.yaml
```
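For intuition, here is a minimal sketch of an EdgeConv-style GNN policy over detected object proposals, written with PyTorch Geometric. It is an illustrative toy, not the repository's exact `EdgeConvNet` architecture; the feature dimension and number of actions below are assumptions:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import EdgeConv

class ToyEdgeConvPolicy(nn.Module):
    """Toy Q-network: object proposals in, Q-values over actions out."""
    def __init__(self, in_dim: int = 8, hidden: int = 64, n_actions: int = 3):
        super().__init__()
        # EdgeConv applies an MLP to [x_i, x_j - x_i] for every edge.
        self.conv = EdgeConv(
            nn.Sequential(nn.Linear(2 * in_dim, hidden), nn.ReLU(),
                          nn.Linear(hidden, hidden)),
            aggr="max")
        self.head = nn.Linear(hidden, n_actions)

    def forward(self, x, edge_index):
        h = self.conv(x, edge_index)   # per-proposal features
        g = h.max(dim=0).values        # permutation-invariant readout
        return self.head(g)            # Q-values over discrete actions

# 5 proposals with 8-d features, fully connected graph (no self-loops).
n = 5
x = torch.randn(n, 8)
edge_index = torch.tensor(
    [(i, j) for i in range(n) for j in range(n) if i != j]).t()
q_values = ToyEdgeConvPolicy()(x, edge_index)
```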
To evaluate the student policy, run:

```bash
python tools/eval_student_policy.py \
    --env FallingDigitCIFAR_9_test-v1 \
    --n-episode 100 \
    gnn \
    --detector-model SPACE_v1 \
    --detector-checkpoint ${YOUR_DETECTOR_CHECKPOINT_PATH} \
    --gnn-model EdgeConvNet \
    --gnn-checkpoint ${YOUR_GNN_POLICY_CHECKPOINT_PATH}
```

`${YOUR_DETECTOR_CHECKPOINT_PATH}` should be something like `outputs/falling_digit_space/cifar_space_v1/model_060000.pth`, and `${YOUR_GNN_POLICY_CHECKPOINT_PATH}` should be something like `outputs/falling_digit_refactor/cifar_gnn/model_best.pth`.

Note that the test environment `FallingDigit${bg}_${n}_test-v1` should be used here.
## Citation

If you find our paper useful in an academic setting, please cite:

```
@article{mu2020refactoring,
  title={Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals},
  author={Mu, Tongzhou and Gu, Jiayuan and Jia, Zhiwei and Tang, Hao and Su, Hao},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
```
## Acknowledgements

The self-supervised object detector part of this implementation refers to some details in Zhixuan Lin's original SPACE implementation. The reinforcement learning part is adapted from Shangtong Zhang's DeepRL code base.