Reinforcement learning has been the go-to paradigm for trajectory planning, where the model predicts a sequence of (state, action) pairs leading from a start state to a goal state. Recent work, including Diffuser [1], instead models trajectory planning as a generation problem that can be steered with classifier guidance. However, existing methods have three major limitations:
- The backbone used for training the diffusion models is not well suited to sequential prediction tasks.
- There is no open-source, easy-to-use library for training diffusion models for planning.
- They have not been applied to complex planning situations.
We propose DiTP, a transformer-based path-planning algorithm that uses diffusion. DiTP shows a notable increase in performance over a traditional model-free RL algorithm (IQL) and over Diffuser (UNet + diffusion) in the complex maze2D environment.
To get started with the repository, follow these steps:
- Clone the repository and pull the latest changes:

      git clone https://github.com/dsrivastavv/DiffusionBasedRL.git
      cd DiffusionBasedRL
      git pull
- Pull the Docker image from Docker Hub:

      docker pull revenths/diffuser:latest

  or, alternatively, build the image from the Dockerfile:

      docker build -t revenths/diffuser:latest .
- Run the Docker container with the repository directory mounted:

      docker run -it -v <path_to_DiffusionBasedRL_on_local>:/root/DiffusionBasedRL --runtime=nvidia revenths/diffuser:latest
- Inside the container, from the DiffusionBasedRL directory, train the Diffuser models with the following commands. The --dataset placeholder stands for exactly one of the three listed dataset names.
  - To train Diffuser with the UNet backbone:

        nohup python3 -u -m scripts.train --config config.maze2d --dataset <maze2d-large-v1/maze2d-medium-v1/maze2d-umaze-v1> > trainlogs_unet.log &

  - To train Diffuser with the DiT backbone:

        nohup python3 -u -m scripts.train --config config.maze2d_dit --dataset <maze2d-large-v1/maze2d-medium-v1/maze2d-umaze-v1> > trainlogs_dit.log &
- You can download the pretrained model weights from Model Pretrained Weights.
- Unzip the downloaded zip file.
- For inference, run the commands below. This will create a scorelist.json file in the DiffusionBasedRL directory.

      python3 -m scripts.maze2dtable --config config.maze2d --dataset maze2d-umaze-v1 --numepisodes 100
      python3 -m scripts.maze2dtable --config config.maze2d --dataset maze2d-medium-v1 --numepisodes 100
      python3 -m scripts.maze2dtable --config config.maze2d --dataset maze2d-large-v1 --numepisodes 100
      python3 -m scripts.maze2dtable --config config.maze2d_dit --dataset maze2d-umaze-v1 --numepisodes 100
      python3 -m scripts.maze2dtable --config config.maze2d_dit --dataset maze2d-medium-v1 --numepisodes 100
      python3 -m scripts.maze2dtable --config config.maze2d_dit --dataset maze2d-large-v1 --numepisodes 100
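The six inference commands above differ only in their config and dataset, so the sweep can also be driven from a small script. The following is a minimal sketch, not part of the repository: the helper names `inference_commands` and `run_all` are hypothetical, and the only flags used are the ones that appear in the commands above.

```python
import itertools
import subprocess

# Values copied from the inference commands listed above.
CONFIGS = ["config.maze2d", "config.maze2d_dit"]
DATASETS = ["maze2d-umaze-v1", "maze2d-medium-v1", "maze2d-large-v1"]


def inference_commands(num_episodes=100):
    """Return one argument list per (config, dataset) pair, configs outermost."""
    return [
        ["python3", "-m", "scripts.maze2dtable",
         "--config", config,
         "--dataset", dataset,
         "--numepisodes", str(num_episodes)]
        for config, dataset in itertools.product(CONFIGS, DATASETS)
    ]


def run_all(commands, dry_run=True):
    """Print each command; also execute it when dry_run is False."""
    for command in commands:
        print(" ".join(command))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(command, check=True)
```

Calling `run_all(inference_commands())` only prints the six commands, which is a cheap way to check them first. Calling `run_all(inference_commands(), dry_run=False)` from the DiffusionBasedRL directory inside the container would execute them in the same order as the list above.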