Official implementation for the experiments in Vanishing Gradients in Reinforcement Finetuning of Language Models (ICLR 2024), based on the PyTorch, HuggingFace Transformers, and RL4LMs libraries.
Tested with Python 3.7.
- Install PyTorch from the official website (tested with version 1.11.0), including `torchvision`.
- Install the remaining requirements using the `update_env.sh` script.
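For example, a minimal sketch of the installation steps (the exact PyTorch install command depends on your platform and CUDA version, so consult the official PyTorch website; invoking the script with `bash` is an assumption):

```bash
# Install PyTorch 1.11.0 together with the torchvision release paired with it (0.12.0).
pip install torch==1.11.0 torchvision==0.12.0

# Install the remaining requirements.
bash update_env.sh
```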
Alternatively, you can use the Docker image and `requirements.txt` provided in the RL4LMs library.
Our experiments on the GRUE benchmark are based on the RL4LMs library.
For a GRUE dataset, the following command generates multiple output text samples per train example, computes their reward, and saves the results to a file:

```bash
python rl4lms_extract_samples_from_model_runner.py --config_path <config_path>
```

Replace `<config_path>` with one of the paths below according to the desired dataset.
| Dataset | Configuration File Path |
|---|---|
| NarrativeQA | `rl4lms_scripts/extract_samples/t5_narrative_qa_extract_samples.yml` |
| ToTTo | `rl4lms_scripts/extract_samples/t5_totto_extract_samples.yml` |
| CommonGen | `rl4lms_scripts/extract_samples/t5_commongen_extract_samples.yml` |
| IWSLT 2017 | `rl4lms_scripts/extract_samples/t5_iwslt2017_extract_samples.yml` |
| CNN/Daily Mail | `rl4lms_scripts/extract_samples/t5_cnn_extract_samples.yml` |
| DailyDialog | `rl4lms_scripts/extract_samples/gpt2_dialog_extract_samples.yml` |
| IMDB | `rl4lms_scripts/extract_samples/gpt2_imdb_extract_samples.yml` |
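For example, to extract samples for NarrativeQA:

```bash
python rl4lms_extract_samples_from_model_runner.py --config_path rl4lms_scripts/extract_samples/t5_narrative_qa_extract_samples.yml
```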
By default, the outputs will be generated using a pretrained model. To modify the model, replace the `alg.model_name` configuration with the HuggingFace model name of your choice or a directory containing a HuggingFace model checkpoint.
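A minimal sketch of the relevant override, assuming the dotted field name `alg.model_name` corresponds to nested YAML keys as in the provided configuration files:

```yaml
alg:
  # Replace with any HuggingFace model name or a directory
  # containing a HuggingFace model checkpoint.
  model_name: t5-base
```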
Note that the reward mean and standard deviation scatter plots in the paper (e.g. Figure 1) were created based on the results of this command.
Additional Notes:
- The output will be saved under the directory specified by the `output_path` run argument (default is `outputs/rl4lms_stats`).
- To modify the number of output samples generated per train example, change the `evaluation.num_samples_per_input` field in the configuration file (default is 10).
- To modify the number of train examples used, change the `datapool.num_train_samples` field in the configuration file (default is 5000). See the sketch below for both fields.
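A minimal sketch of the two fields above, again assuming the dotted names map onto nested YAML keys (the values shown are the defaults):

```yaml
evaluation:
  num_samples_per_input: 10   # output samples generated per train example
datapool:
  num_train_samples: 5000     # number of train examples used
```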
The following command runs a single finetuning experiment based on the given configuration file:

```bash
python rl4lms_train_text_generation.py --config_path <config_path>
```

Replace `<config_path>` with the path to the desired finetuning configuration, which can be found under the `rl4lms_scripts/training` directory.
By default, finetuning starts from a pretrained model (either GPT-2 or T5-base). To modify the model, replace the `alg.model_name` configuration with the HuggingFace model name of your choice or a directory containing a HuggingFace model checkpoint.
Specifically, to perform an initial supervised finetuning phase before reinforcement finetuning, first finetune the pretrained model using the supervised finetuning configuration (e.g. `rl4lms_scripts/training/narrative_qa/t5_supervised.yml` for the NarrativeQA dataset). Then, set the model name in the "ppo_on_supervised" configuration (e.g. `rl4lms_scripts/training/narrative_qa/t5_ppo_on_supervised.yml` for running reinforcement finetuning using PPO on the NarrativeQA dataset) to the saved model checkpoint from the supervised finetuning run, as in the sketch below.
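For example, the two-phase workflow for NarrativeQA (the checkpoint location mentioned in the comment is hypothetical; use the one produced by your supervised run):

```bash
# Phase 1: supervised finetuning on NarrativeQA.
python rl4lms_train_text_generation.py \
    --config_path rl4lms_scripts/training/narrative_qa/t5_supervised.yml

# Phase 2: point alg.model_name in t5_ppo_on_supervised.yml at the checkpoint
# saved by phase 1 (e.g. somewhere under outputs/rl4lm), then run PPO.
python rl4lms_train_text_generation.py \
    --config_path rl4lms_scripts/training/narrative_qa/t5_ppo_on_supervised.yml
```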
Additional Notes:
- Results and the model checkpoint will be saved under the directory specified by the `base_path_to_store_results` run argument (default is `outputs/rl4lm`).
- Use the `-h` flag for information on the customizable run arguments.
- See the RL4LMs GitHub page for more information on the available configuration options and customization.
Increasing learning rate, temperature, and entropy regularization (Section 5.1):
- The learning rate can be modified via the `alg.args.learning_rate` field in the configuration file.
- The temperature can be modified via the `policy.args.generation_kwargs.temperature` field in the configuration file.
- The entropy regularization coefficient can be modified via the `alg.args.ent_coef` field in the configuration file (see the sketch below).
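A minimal sketch of all three fields, assuming the dotted names correspond to nested YAML keys as in the provided configuration files (the values are illustrative, not the paper's settings):

```yaml
alg:
  args:
    learning_rate: 2.0e-6  # illustrative value
    ent_coef: 0.01         # illustrative entropy regularization coefficient
policy:
  args:
    generation_kwargs:
      temperature: 2.0     # illustrative sampling temperature
```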
Reducing the number of supervised finetuning optimization steps and fraction of labeled samples (Section 5.2):
- The maximal number of supervised finetuning optimization steps can be modified via the `alg.training_args.max_steps` field in the configuration file (default is -1, which means the number of steps is determined by the number of epochs and batch size).
- The fraction of training examples can be modified via the `datapool.frac_train_samples` field in the configuration file (default is -1, which means the whole training set will be used). See the sketch below for both fields.
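A minimal sketch, under the same nested-YAML assumption (the values are illustrative):

```yaml
alg:
  training_args:
    max_steps: 1000        # illustrative cap on supervised optimization steps
datapool:
  frac_train_samples: 0.1  # illustrative: use 10% of the labeled samples
```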
The following command runs the pretraining stage of the controlled experiments:

```bash
python sft_vs_rft_controlled_experiments_plan_runner.py --plan_config_path <path_to_config_file>
```

Replace `<path_to_config_file>` with one of the paths below according to the desired experiment.
| Model and Dataset | Configuration File Path |
|---|---|
| MLP on MNIST | `sft_vs_rft/experiment_plans/sft_vs_rft_mlp_mnist_pretrain_experiments_plan.json` |
| ResNet18 on CIFAR10 | `sft_vs_rft/experiment_plans/sft_vs_rft_resnet_cifar_pretrain_experiments_plan.json` |
| BERT-mini on STS-B | `sft_vs_rft/experiment_plans/sft_vs_rft_bert_stsb_pretrain_experiments_plan.json` |
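For example, to pretrain the MLP on MNIST:

```bash
python sft_vs_rft_controlled_experiments_plan_runner.py --plan_config_path sft_vs_rft/experiment_plans/sft_vs_rft_mlp_mnist_pretrain_experiments_plan.json
```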
Additional Notes:
- A folder with log files, metrics, and the model checkpoint will be automatically created under the directory specified by `outputs_dir` in the configuration file (default is `outputs/controlled`).
- It is recommended to use a GPU by adding an available GPU id to the `gpu_ids` field in the configuration file (see the sketch below).
- The different configuration options are documented in `common/experiment/fit_experiment_base.py` and `sft_vs_rft/experiment/sft_vs_rft_controlled_experiment.py`.
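A minimal sketch of the two fields above, assuming they are top-level keys of the JSON plan file (the default `outputs_dir` and a single GPU id are shown):

```json
{
  "outputs_dir": "outputs/controlled",
  "gpu_ids": [0]
}
```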
After pretraining, the following command runs the finetuning stage of the controlled experiments:

```bash
python sft_vs_rft_controlled_experiments_plan_runner.py --plan_config_path <path_to_config_file>
```

Replace `<path_to_config_file>` with one of the paths below according to the desired experiment.
| Model and Dataset | Configuration File Path |
|---|---|
| MLP on MNIST w/ Adam | `sft_vs_rft/experiment_plans/sft_vs_rft_mlp_mnist_finetune_experiments_plan.json` |
| MLP on MNIST w/ SGD | `sft_vs_rft/experiment_plans/sft_vs_rft_mlp_mnist_finetune_sgd_experiments_plan.json` |
| MLP on MNIST w/ High Initial Reward | `sft_vs_rft/experiment_plans/sft_vs_rft_mlp_mnist_finetune_high_init_reward_experiments_plan.json` |
| ResNet18 on CIFAR10 w/ Adam | `sft_vs_rft/experiment_plans/sft_vs_rft_resnet_cifar_finetune_experiments_plan.json` |
| ResNet18 on CIFAR10 w/ SGD | `sft_vs_rft/experiment_plans/sft_vs_rft_resnet_cifar_finetune_sgd_experiments_plan.json` |
| ResNet18 on CIFAR10 w/ High Initial Reward | `sft_vs_rft/experiment_plans/sft_vs_rft_resnet_cifar_finetune_high_init_reward_experiments_plan.json` |
| BERT-mini on STS-B w/ Adam | `sft_vs_rft/experiment_plans/sft_vs_rft_bert_stsb_finetune_experiments_plan.json` |
| BERT-mini on STS-B w/ SGD | `sft_vs_rft/experiment_plans/sft_vs_rft_bert_stsb_finetune_sgd_experiments_plan.json` |
| BERT-mini on STS-B w/ High Initial Reward | `sft_vs_rft/experiment_plans/sft_vs_rft_bert_stsb_finetune_high_init_reward_experiments_plan.json` |
In the chosen configuration file, set the `load_model_from_checkpoint` field to the path of the checkpoint created during pretraining, as in the sketch below.
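A minimal sketch, assuming a top-level JSON field; the checkpoint path is hypothetical and should be replaced with the one saved by your pretraining run under `outputs_dir`:

```json
{
  "load_model_from_checkpoint": "outputs/controlled/<pretrain_run_dir>/<checkpoint_file>"
}
```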
Additional Notes:
- Make sure that the `samples_rnd_seed` configuration is the same as that used for pretraining, so that the data generation process for finetuning is consistent with that of pretraining.
- The `opt_method` configuration determines which type of finetuning is conducted. By default, the included configurations run two experiments: one with supervised finetuning (`supervised`) and one with reinforcement finetuning using expected gradients (`expected_reinforce`).
- A folder with log files, metrics, and the model checkpoint will be automatically created under the directory specified by `outputs_dir` in the configuration file (default is `outputs/controlled`).
- It is recommended to use a GPU by adding an available GPU id to the `gpu_ids` field in the configuration file.
- The different configuration options are documented in `common/experiment/fit_experiment_base.py` and `sft_vs_rft/experiment/sft_vs_rft_controlled_experiment.py`.
For citing the paper, you can use:

```bibtex
@inproceedings{razin2024vanishing,
  title={Vanishing Gradients in Reinforcement Finetuning of Language Models},
  author={Razin, Noam and Zhou, Hattie and Saremi, Omid and Thilak, Vimal and Bradley, Arwen and Nakkiran, Preetum and Susskind, Joshua and Littwin, Etai},
  booktitle={International Conference on Learning Representations},
  year={2024}
}
```