Transformer

Implementation of the Transformer model in the paper:

Ashish Vaswani, et al. "Attention is all you need." NIPS 2017.

[Figure: Transformer model architecture]

Check out my blog post on attention and the Transformer:

Implementations that helped me:

Setup

$ git clone https://github.com/lilianweng/transformer-tensorflow.git
$ cd transformer-tensorflow
$ pip install -r requirements.txt

Train a Model

# Check the help message:

$ python train.py --help

Usage: train.py [OPTIONS]

Options:
  --seq-len INTEGER               Input sequence length.  [default: 20]
  --d-model INTEGER               d_model  [default: 512]
  --d-ff INTEGER                  d_ff  [default: 2048]
  --n-head INTEGER                n_head  [default: 8]
  --batch-size INTEGER            Batch size  [default: 128]
  --max-steps INTEGER             Max train steps.  [default: 300000]
  --dataset [iwslt15|wmt14|wmt15]
                                  Which translation dataset to use.  [default:
                                  iwslt15]
  --help                          Show this message and exit.

# Train a model on the WMT14 dataset:

$ python train.py --dataset wmt14
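
# For example, a smaller model on the default IWSLT15 dataset with a shorter
# schedule (an illustrative combination of the flags documented above):

$ python train.py --dataset iwslt15 --seq-len 30 --d-model 256 --d-ff 1024 \
    --n-head 4 --batch-size 64 --max-steps 100000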

Evaluate a Trained Model

Let's say the model is saved in the folder transformer-wmt14-seq20-d512-head8-1541573730 under the checkpoints directory.

$ python eval.py transformer-wmt14-seq20-d512-head8-1541573730

With the default config, this implementation gets BLEU ~ 20 on the WMT14 test set.

Implementation Notes

[WIP] A couple of tricky points in the implementation (illustrative sketches follow the list):

  • How to construct the mask correctly?
  • How to correctly shift decoder input (as training input) and decoder target (as ground truth in the loss function)?
  • How to make the prediction in an autoregressive way?
  • Keeping the embedding of <pad> as a constant zero vector is sorta important.
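
For the mask construction, the decoder input/target shift, and the zero <pad> embedding, here is a minimal sketch of the idea in NumPy. It is not the repository's actual TensorFlow code; PAD_ID = 0 and START_ID = 1 are assumptions about the vocabulary, and the convention that 1.0 means "visible" in a mask follows the usual scaled dot-product attention formulation.

import numpy as np

PAD_ID = 0    # assumed id of <pad>
START_ID = 1  # assumed id of the start-of-sequence symbol

def padding_mask(seq):
    # 1.0 for real tokens, 0.0 for <pad>; shaped [batch, 1, 1, seq_len] so it
    # broadcasts over attention heads and query positions.
    return (seq != PAD_ID).astype(np.float32)[:, None, None, :]

def look_ahead_mask(seq_len):
    # Lower-triangular matrix of ones: position i may only attend to j <= i.
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))

def shift_decoder_io(target):
    # The decoder *input* is the target shifted right (prepend <s>, drop the
    # last token); the *target* used in the loss stays unshifted, so every
    # position predicts the next token.
    start = np.full((target.shape[0], 1), START_ID, dtype=target.dtype)
    dec_input = np.concatenate([start, target[:, :-1]], axis=1)
    return dec_input, target

def zero_pad_embedding(embedding_matrix):
    # Pin the <pad> row of the embedding table to a zero vector so padded
    # positions contribute nothing downstream.
    embedding_matrix = embedding_matrix.copy()
    embedding_matrix[PAD_ID] = 0.0
    return embedding_matrix

# Tiny example: one target sentence [7, 8, 9] padded to length 5.
target = np.array([[7, 8, 9, PAD_ID, PAD_ID]])
dec_in, dec_out = shift_decoder_io(target)
print(dec_in)    # [[1 7 8 9 0]]
print(dec_out)   # [[7 8 9 0 0]]

# Decoder self-attention mask: a position is visible only if it is a real
# token AND not in the future.
dec_self_mask = padding_mask(dec_in) * look_ahead_mask(dec_in.shape[1])
print(dec_self_mask.shape)  # (1, 1, 5, 5)

For autoregressive prediction, a greedy decoding loop looks roughly like the sketch below; model_step and end_id are hypothetical stand-ins for the model's forward pass and the end-of-sequence token id.

def greedy_decode(model_step, src, max_len, end_id):
    # model_step(src, dec_in) returns logits of shape [batch, len(dec_in), vocab].
    dec_in = np.full((src.shape[0], 1), START_ID, dtype=np.int64)
    for _ in range(max_len):
        logits = model_step(src, dec_in)
        next_token = logits[:, -1, :].argmax(axis=-1)[:, None]
        dec_in = np.concatenate([dec_in, next_token], axis=1)
        if np.all(next_token == end_id):
            break
    return dec_in[:, 1:]  # drop the leading <s>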