maze_navigation_MLMU

Code for Paper "Transformers can navigate mazes with Multi-Step Prediction"

Installation

  1. Install PyTorch
  2. pip install pytest submitit hydra-core hydra-submitit-launcher loguru tqdm gitpython transformers lightning matplotlib datasets sortedcontainers maze-dataset pymongo numpy
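To confirm the installation, a quick check like the following can report which packages Python cannot find. This is a minimal sketch, not part of the repository; the import names are assumptions based on the usual mapping from pip names (for example, gitpython imports as `git`, hydra-core as `hydra`, and maze-dataset as `maze_dataset`), and the hydra-submitit-launcher plugin is not checked.

```python
import importlib.util

# Assumed import names for the pip packages listed above (plus PyTorch).
# hydra-submitit-launcher is a Hydra plugin and is not checked here.
REQUIRED_MODULES = [
    "torch", "pytest", "submitit", "hydra", "loguru", "tqdm", "git",
    "transformers", "lightning", "matplotlib", "datasets",
    "sortedcontainers", "maze_dataset", "pymongo", "numpy",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing:", ", ".join(missing) if missing else "none")
```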

To run the A* mazes (from https://github.com/facebookresearch/searchformer/):

  1. Install MongoDB
  2. Download maze.gz and maze.vocabulary.gz from https://github.com/facebookresearch/searchformer/blob/main/doc/mongodb.md
  3. Restore both archives into your MongoDB instance:
    mongorestore --gzip --archive=maze.gz
    mongorestore --gzip --archive=maze.vocabulary.gz

Adjust the local paths; each one is marked with "TODO" in the code:

  1. main.py --> code snapshot dir
  2. train_defaults.yaml --> logs dir
  3. train_defaults.yaml --> data dir
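To list every place that needs editing, searching the checkout for "TODO" is enough (from the shell, `grep -rn TODO .` does the same). The sketch below is a hypothetical helper, not part of the repository; it assumes the markers live in `.py` and `.yaml` files, as the list above suggests.

```python
from pathlib import Path


def find_todos(root=".", suffixes=(".py", ".yaml")):
    """Yield (path, line_number, line) for every line containing 'TODO'."""
    for path in sorted(Path(root).rglob("*")):
        # Only scan regular files with one of the given extensions.
        if path.is_file() and path.suffix in suffixes:
            with path.open(encoding="utf-8", errors="replace") as f:
                for lineno, line in enumerate(f, start=1):
                    if "TODO" in line:
                        yield path, lineno, line.rstrip()


if __name__ == "__main__":
    for path, lineno, line in find_todos():
        print(f"{path}:{lineno}: {line}")
```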

Run the Next-Token (AR) Baseline

Locally

python main.py -m mode=local model=gpt dataset=maze datamodule.grid_n=4

  • Pass use_wandb=False to turn off Weights & Biases logging (useful when debugging), or use_wandb=True to turn it on.
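The commands above are Hydra `key=value` overrides, and `-m` enables Hydra multirun, which sweeps over comma-separated values. The hypothetical helper below only assembles such a command line so that sweeps are easy to script. It assumes `use_wandb` is a top-level key (as the bullet suggests); whether grid sizes other than 4 are supported depends on the dataset config.

```python
import shlex


def build_command(model, grid_sizes, use_wandb=False, extra=None):
    """Assemble a `python main.py -m ...` command from Hydra-style overrides."""
    overrides = {
        "mode": "local",
        "model": model,
        "dataset": "maze",
        # With -m (multirun), Hydra launches one run per comma-separated value.
        "datamodule.grid_n": ",".join(str(n) for n in grid_sizes),
        "use_wandb": str(use_wandb),
    }
    overrides.update(extra or {})
    return ["python", "main.py", "-m"] + [f"{k}={v}" for k, v in overrides.items()]


if __name__ == "__main__":
    print(shlex.join(build_command("gpt", [4])))
```

For example, `build_command("past", [4, 6], extra={"model.train_mode": "absorbing"})` would produce one multirun command covering two grid sizes.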

Run MLM-U

python main.py -m mode=local model=past dataset=maze datamodule.grid_n=4

PAST is an encoder-decoder model and works best with MLM-U training (model.train_mode=absorbing). GPT works best for AR training (left-to-right next-token prediction).
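The difference between the two objectives can be sketched at the level of training targets. This is an illustrative sketch, not the repository's implementation: it assumes MLM-U masks each sequence at a rate drawn uniformly at random and predicts only the masked tokens (in the spirit of absorbing-state diffusion, which the `absorbing` mode name suggests), while AR predicts each next token. `MASK_ID` and `IGNORE` are hypothetical values.

```python
import random

MASK_ID = -1   # hypothetical id of the [MASK] token
IGNORE = -100  # hypothetical marker for positions excluded from the loss


def ar_example(tokens):
    """Next-token (AR): the input is tokens[:-1], the target at each step is the following token."""
    return tokens[:-1], tokens[1:]


def mlm_u_example(tokens, rng=random):
    """MLM-U sketch: draw a masking rate uniformly from [0, 1], mask each token
    with that probability, and keep the original token as the target only at masked positions."""
    rate = rng.uniform(0.0, 1.0)
    inputs, targets = [], []
    for tok in tokens:
        if rng.random() < rate:
            inputs.append(MASK_ID)
            targets.append(tok)
        else:
            inputs.append(tok)
            targets.append(IGNORE)
    return inputs, targets
```

Because the rate varies from sequence to sequence, the model is trained to fill in anything from a single token to almost the whole path at once, rather than only the next token.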

Contributing

See the CONTRIBUTING file for how to help out.

License

This project is Apache 2.0 licensed, as found in the LICENSE file.

The stargraph dataset has been adapted from https://github.com/gregorbachmann/Next-Token-Failures/