VQ-BeT: Behavior Generation with Latent Actions

Official implementation of VQ-BeT: Behavior Generation with Latent Actions.

project website: https://sjlee.cc/vq-bet

Installation

  • Create a conda environment (we tested on Python 3.7 and 3.9) and activate it:

    conda create -n vq-bet python=3.9
    conda activate vq-bet
  • Clone this repo

    git clone https://github.com/jayLEE0301/vq_bet_official.git
    export PROJ_ROOT=$(pwd)
  • Install PyTorch (we tested on PyTorch 1.12.1 and 2.1.0).

  • Install VQ-BeT

    cd $PROJ_ROOT/vq_bet_official
    pip install -r requirements.txt
    pip install -e .

    Alternatively, you can run sh install.sh instead of pip install -r requirements.txt.

  • Install MuJoCo and D4RL

    D4RL can be installed by cloning the repository as follows:

    cd $PROJ_ROOT
    git clone https://github.com/Farama-Foundation/d4rl.git
    cd $PROJ_ROOT/d4rl
    pip install -e .
    cd $PROJ_ROOT/vq_bet_official

    To run the UR3 env, also install the UR3 env package:

    cd $PROJ_ROOT/vq_bet_official/envs/ur3
    pip install -e .
    cd $PROJ_ROOT/vq_bet_official
  • To enable logging, log in with a wandb account:

    wandb login

    Alternatively, to disable logging altogether, set the environment variable WANDB_MODE:

    export WANDB_MODE=disabled
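After installation, a quick check can confirm that the main dependencies are visible to your Python interpreter. The sketch below is our own helper, not part of the repository. The module names in REQUIRED_MODULES are assumptions about what the steps above install (mujoco_py is the MuJoCo binding that D4RL's MuJoCo tasks typically use), so adjust the list to your setup. It only locates modules without importing them, which means build problems such as the Cython or gcc errors listed under "Common errors and solutions" appear only on an actual import.

```python
import importlib.util

# Assumed module names for the dependencies installed above; edit as needed.
REQUIRED_MODULES = ["torch", "d4rl", "mujoco_py", "wandb"]


def check_modules(names):
    """Map each top-level module name to whether it can be located (not imported)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED_MODULES).items():
        print(f"{name:10s} {'ok' if found else 'MISSING'}")
```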

Usage

Step 0: Download dataset and set dataset path / saving path

  • Download datasets here.

    • Optionally, use gdown to do that: gdown --fuzzy https://drive.google.com/file/d/1aHb4kV0mpMvuuApBpVGYjAPs6MCNVTNb/view?usp=sharing.
  • Add the path to your dataset directory and your save path to ./examples/configs/env_vars/env_vars.yaml.

    # TODO fill these out
    dataset_path: PATH_TO_YOUR_[env_name]_DATASET
    save_path: YOUR_SAVE_PATH
    wandb_entity: YOUR_WANDB_ENTITY
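A common mistake is leaving one of the template values unchanged. The sketch below is a hypothetical stdlib-only check, not part of the repository. It flags values that still start with the placeholder prefixes shown above (PATH_TO_YOUR_ or YOUR_). It reads only simple "key: value" lines and is not a full YAML parser.

```python
import re
import sys

# Placeholder prefixes copied from the env_vars.yaml template above.
PLACEHOLDER = re.compile(r"^(PATH_TO_YOUR_|YOUR_)")


def find_placeholders(yaml_text):
    """Return (key, value) pairs whose value still looks like a template placeholder."""
    leftovers = []
    for line in yaml_text.splitlines():
        stripped = line.strip()
        # Skip blank lines, comments, and lines without a key.
        if not stripped or stripped.startswith("#") or ":" not in stripped:
            continue
        key, _, value = stripped.partition(":")
        value = value.strip()
        if PLACEHOLDER.match(value):
            leftovers.append((key.strip(), value))
    return leftovers


# Only runs when a file path is passed, e.g.
#   python check_env_vars.py examples/configs/env_vars/env_vars.yaml
if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1]) as f:
        found = find_placeholders(f.read())
    for key, value in found:
        print(f"{key} is still set to placeholder {value}")
    sys.exit(1 if found else 0)
```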

Step 1: pretrain vq-vae

  • To pretrain the Residual VQ, set config_name="pretrain_[env name]" in ./examples/pretrain_vqvae.py and run pretrain_vqvae.py (e.g., for both the goal-conditioned and non-goal-conditioned Kitchen envs, config_name="pretrain_kitchen").

    python examples/pretrain_vqvae.py
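Residual VQ represents a vector with a stack of codebooks. Each layer quantizes the residual left over by the previous layers, and the reconstruction is the sum of the selected codes. The sketch below shows only the greedy quantization step, using small fixed toy codebooks. It is not the repository's code: the repo uses learned codebooks from the Vector Quantization - Pytorch residual VQ, and it quantizes an encoder latent of an action chunk rather than a raw action. We assume that vqvae_groups corresponds to the number of layers and vqvae_n_embed to the entries per codebook.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(codebook, vec):
    """Return (index, code) of the codebook entry closest to vec."""
    best = min(range(len(codebook)), key=lambda i: sq_dist(codebook[i], vec))
    return best, codebook[best]


def residual_quantize(vec, codebooks):
    """Greedy residual VQ: each layer quantizes what earlier layers left over.

    Returns one code index per layer and the reconstruction (sum of chosen codes).
    """
    residual = list(vec)
    recon = [0.0] * len(vec)
    indices = []
    for codebook in codebooks:
        idx, code = nearest(codebook, residual)
        indices.append(idx)
        recon = [r + c for r, c in zip(recon, code)]
        residual = [r - c for r, c in zip(residual, code)]
    return indices, recon


# Toy example: a coarse first layer and a finer second layer.
coarse = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
fine = [[0.0, 0.0], [0.25, 0.0], [0.0, -0.25]]
indices, recon = residual_quantize([1.2, 0.95], [coarse, fine])
```

Here the first layer picks [1.0, 1.0] and the second refines it with [0.25, 0.0], giving codes [1, 1] and a reconstruction of [1.25, 1.0]. The second code adds a finer correction on top of the coarse one, which matches how this README refers to a primary and a secondary code.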

Step 2: train and evaluate vq-bet

  • Add the path to your pretrained Residual VQ to ./examples/configs/train_[env name].yaml so that it is loaded:

    vqvae_load_dir: YOUR_PATH_TO_PRETRAINED_VQVAE/trained_vqvae.pt
  • Then, set config_name="train_[env name]" in ./examples/train.py and run train.py (e.g., for the non-goal-conditioned Kitchen env, config_name="train_kitchen_nongoalcond").

    python examples/train.py

Training visual observation envs:

In this repo, we provide pre-processed ResNet18 embedding vectors for the PushT and Kitchen environments. To train VQ-BeT with visual observations, set visual_input: true in ./examples/configs/train_[env name].yaml. Please note that using frozen embeddings may give lower performance than fine-tuning ResNet18, although it is much faster (we will release additional modules for fine-tuning ResNet with VQ-BeT soon).

(Optional) quick start: evaluating VQ-BeT with pretrained weights (on goal-cond Kitchen env)

If you want to quickly see the performance of VQ-BeT on the goal-conditioned Kitchen env without training from scratch, follow the steps below.

  • Download the pretrained Residual VQ and VQ-BeT weights here.

    • Optionally, use gdown to do that: gdown --fuzzy https://drive.google.com/file/d/1iGRyxwPHMsSVDFGojTiPteU3NVNNXMfP/view?usp=sharing.
  • Add the paths to the downloaded weights to ./examples/configs/train_kitchen_goalcond.yaml so that they are loaded:

    vqvae_load_dir: YOUR_PATH_TO_DOWNLOADED_WEIGHTS/rvq/trained_vqvae.pt
    load_path: YOUR_PATH_TO_DOWNLOADED_WEIGHTS/vq-bet
  • Then, set config_name="train_kitchen_goalcond" in ./examples/train.py and run train.py.

    python examples/train.py

How can I train VQ-BeT using my own Env?

NOTE: You need to create your own ./examples/configs/train_[env name].yaml and ./examples/configs/pretrain_[env name].yaml.

  • First, copy the train_your_env.yaml and pretrain_your_env.yaml files from ./examples/configs/template to ./examples/configs.

  • Then, add the path to your dataset directory and your save path to ./examples/configs/env_vars/env_vars.yaml.

    env_vars:
      # TODO fill these out
      dataset_path: PATH_TO_YOUR_[env_name]_DATASET
      save_path: YOUR_SAVE_PATH
      wandb_entity: YOUR_WANDB_ENTITY
  • Also, add the following line, containing your environment name, under "datasets:" in ./examples/configs/env_vars/env_vars.yaml.

    [env_name]: ${env_vars.dataset_path}/[env_name]
  • Then, add your own env file at examples/[env name]_env.py. Note that it should follow the OpenAI Gym interface and define get_goal_fn if you are training on a goal-conditioned task.

  • Finally, follow "Step 1: pretrain vq-vae" and "Step 2: train and evaluate vq-bet" in the Usage section to pretrain the Residual VQ and train VQ-BeT.
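To make the Gym-style requirement concrete, here is a minimal, self-contained skeleton of a custom env. It uses a toy 1-D goal-reaching task and follows the classic Gym API (reset() returns an observation; step() returns obs, reward, done, info, as in gym < 0.26). The class name PointReachEnv and the get_goal_fn signature are hypothetical. Before adapting this sketch, check the existing examples/*_env.py files to see what the training script actually expects from get_goal_fn and how envs are registered with Gym.

```python
import random


class PointReachEnv:
    """Hypothetical toy env: move a point on a line toward a fixed goal."""

    def __init__(self, goal=1.0, max_steps=50, seed=None):
        self.goal = goal
        self.max_steps = max_steps
        self.rng = random.Random(seed)
        self.pos = 0.0
        self.t = 0

    def reset(self):
        # Start at a random position in [-1, 1].
        self.pos = self.rng.uniform(-1.0, 1.0)
        self.t = 0
        return [self.pos]

    def step(self, action):
        # Clip the 1-D action to [-0.1, 0.1] and move the point.
        a = max(-0.1, min(0.1, float(action[0])))
        self.pos += a
        self.t += 1
        dist = abs(self.pos - self.goal)
        reward = 1.0 if dist < 0.05 else 0.0
        done = reward > 0 or self.t >= self.max_steps
        return [self.pos], reward, done, {"distance": dist}


def get_goal_fn(env):
    """Hypothetical goal provider: returns a callable yielding the goal observation.

    The signature the repository expects may differ.
    """
    def goal_fn(*args, **kwargs):
        return [env.goal]
    return goal_fn
```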

Tips for hyperparameter tuning on your own env

Hyperparameters to be determined during Residual VQ pretraining (in order of importance, with the most important at the top):

  1. action_window_size:

    • 1 (single-step prediction): Generally sufficient for most environments.

    • 3~5 (multi-step prediction): Can be helpful in environments where the correlation between consecutive actions is important, such as PushT.

  2. encoder_loss_multiplier: Adjust this value when the action scale is not between -1 and 1. For example, if the action scale is -100 to 100, a value of 0.01 could be used. If action data is normalized, the default value can be used without adjustment.

  3. vqvae_n_embed: (10~16 or more) The size of each Residual VQ dictionary. The total number of possible modes is vqvae_n_embed^vqvae_groups. VQ-BeT's performance is robust to the dictionary size as long as it is large enough to capture the major modes in the dataset (this depends on the task, but is usually >= 10). Please refer to Section B.1 of the manuscript for the performance of VQ-BeT with various Residual VQ dictionary sizes.
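The arithmetic behind these three knobs can be sketched as follows. num_modes is the formula given above. action_windows is our hypothetical illustration of multi-step chunks, and it assumes overlapping windows; the repo's data loader may build chunks differently. suggested_encoder_loss_multiplier is only a rule of thumb inferred from the example above (actions in -100 to 100 suggest 0.01). It assumes the default value is tuned for actions in [-1, 1], and the default of 1.0 used here is our assumption, so pass your config's actual default.

```python
def num_modes(vqvae_n_embed, vqvae_groups):
    """Total number of discrete codes: vqvae_n_embed ** vqvae_groups."""
    return vqvae_n_embed ** vqvae_groups


def action_windows(actions, action_window_size):
    """Overlapping chunks of consecutive actions (assumed chunking scheme).

    A trajectory of length T yields T - action_window_size + 1 chunks.
    """
    return [actions[t:t + action_window_size]
            for t in range(len(actions) - action_window_size + 1)]


def suggested_encoder_loss_multiplier(action_abs_max, default=1.0):
    """Rule of thumb: divide the [-1, 1] default by the action magnitude."""
    return default / action_abs_max


# Example: 16 codes per layer with 2 layers gives 16 ** 2 = 256 modes.
modes = num_modes(16, 2)
```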

Hyperparameters to be determined during the VQ-BeT training (in order of importance, with the most important at the top):

  1. window_size (10~100): 10 is suitable in most cases; consider increasing it if a longer observation history is likely to help.

  2. offset_loss_multiplier: If the action scale is around -1 to 1, the default value of 100 is the most common choice. Adjust this value if the action scale is not between -1 and 1. For example, if the action scale is -100 to 100, a value of 1 could be used.

  3. secondary_code_multiplier: The default value is 0.5. Experimenting with values between 0.5 and 3 is recommended. A larger value emphasizes predictions for the secondary code more than offset predictions.
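The sketch below shows how these hyperparameters might enter training. It is a simplified model and does not reproduce the repository's exact loss. observation_window builds a history of window_size observations, padding by repeating the first observation; the padding choice is our assumption. vqbet_loss is an assumed linear combination of a primary-code loss, a secondary-code loss weighted by secondary_code_multiplier, and an offset loss weighted by offset_loss_multiplier. The actual implementation may compute or weight these terms differently. suggested_offset_loss_multiplier reproduces the example above, where the default of 100 for actions in [-1, 1] becomes 1 for actions in [-100, 100].

```python
def observation_window(observations, t, window_size):
    """Last window_size observations up to step t, padding with the first one."""
    start = t - window_size + 1
    return [observations[max(i, 0)] for i in range(start, t + 1)]


def vqbet_loss(primary_code_loss, secondary_code_loss, offset_loss,
               secondary_code_multiplier=0.5, offset_loss_multiplier=100.0):
    """Assumed weighted sum of the three training terms (simplified)."""
    return (primary_code_loss
            + secondary_code_multiplier * secondary_code_loss
            + offset_loss_multiplier * offset_loss)


def suggested_offset_loss_multiplier(action_abs_max, default=100.0):
    """Rescale the [-1, 1] default to a larger action range."""
    return default / action_abs_max
```

Raising secondary_code_multiplier in this form increases the weight of the secondary-code term relative to the offset term, matching the description of secondary_code_multiplier above.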

Common errors and solutions

  • Cython compile error

    Cython.Compiler.Errors.CompileError

    Try pip install "cython<3" (https://github.com/openai/mujoco-py/issues/773)

  • MuJoCo gcc error

    fatal error: GL/glew.h: No such file or directory
    distutils.errors.CompileError: command '/usr/bin/gcc' failed with exit code 1

    Try the following:

    conda install -c conda-forge glew
    conda install -c conda-forge mesalib
    conda install -c menpo glfw3

    Then add your conda environment include to CPATH (put this in your .bashrc to make it permanent):

    export CPATH=$CONDA_PREFIX/include
    

    Finally, install patchelf with pip install patchelf.

  • MuJoCo missing error:

    Error: You appear to be missing MuJoCo.  We expected to find the file here: /home/usr_name/.mujoco/mujoco210 .

    This can be solved by following the instructions here.

  • gladLoadGL error

    Error in call to target 'gym.envs.registration.make':
    FatalError('gladLoadGL error')

    Try prefixing your command with MUJOCO_GL=egl:

      MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python examples/train.py 
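If you launch runs from Python, for example from a sweep script, the same environment variables used in the fixes above (MUJOCO_GL, CUDA_VISIBLE_DEVICES, and optionally WANDB_MODE=disabled) can be passed to a subprocess. The sketch below is a hypothetical convenience wrapper, not part of the repository. It copies the current environment instead of modifying it, so the caller's process is unaffected. The default script path examples/train.py matches the commands above.

```python
import os
import subprocess
import sys


def build_env(mujoco_gl="egl", cuda_visible_devices="0", wandb_disabled=False):
    """Return a copy of os.environ with the variables from the fixes above set."""
    env = dict(os.environ)
    env["MUJOCO_GL"] = mujoco_gl
    env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    if wandb_disabled:
        env["WANDB_MODE"] = "disabled"
    return env


def run_training(script="examples/train.py", **env_kwargs):
    """Run the training script with the patched environment; raises on failure."""
    return subprocess.run([sys.executable, script], env=build_env(**env_kwargs), check=True)
```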

Our code is sourced and modified from the miniBeT implementation of the conditional and unconditional Behavior Transformer algorithms. We also use the residual VQ-VAE code from the Vector Quantization - Pytorch repo, the PushT env from Diffusion Policy, the Ant env base from DHRL, and the UR3 env from here.