


Train BART with PyTorch

This is an alpha release of the training script. Please try it and let me know about any problems you run into!

An example script for training BART, an encoder-decoder model trained with a denoising objective over tokens and spans.

This code builds on the Hugging Face Flax example. The data collator was ported from fairseq into a collator format (i.e., batched) rather than a dataset-level transform.

Usage

Inspired by the process here.

1. Train a tokenizer on a dataset from the Hugging Face Hub

python prepare_tokenizer.py \
    oscar \
    --dataset_config_name unshuffled_deduplicated_nl \
    --dataset_split train \
    --dout ./my-bart-model

2. Prepare a model config file based on an existing model

python prepare_config.py \
    --pretrained_model_name facebook/bart-base \
    --dout ./my-bart-model

3. Train the model with the prepared tokenizer and config

python run_bart_dlm.py \
    --config_name ./my-bart-model \
    --tokenizer_name ./my-bart-model \
    --dataset_name oscar \
    --dataset_config_name unshuffled_deduplicated_nl \
    --output_dir ./my-bart-model \
    --do_train \
    --do_eval
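
If run_bart_dlm.py follows the standard Hugging Face example scripts (an assumption on my part), the usual TrainingArguments flags can be added on top of the command above, for example:

python run_bart_dlm.py \
    --config_name ./my-bart-model \
    --tokenizer_name ./my-bart-model \
    --dataset_name oscar \
    --dataset_config_name unshuffled_deduplicated_nl \
    --output_dir ./my-bart-model \
    --per_device_train_batch_size 8 \
    --num_train_epochs 3 \
    --learning_rate 5e-5 \
    --do_train \
    --do_eval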

Some notes

Sentence splitting

As part of BART pretraining, the sentences in a sample may be permuted (reordered). To detect the sentences in each sample, we need sentence splitting. By default, we use NLTK's English punkt sentence splitter, but by passing a spaCy model name to --spacy_model (e.g. en_core_web_sm) you can instead rely on spaCy for better (but slower) sentence splitting. You can also disable sentence splitting completely with --no_sentence_splitting. In that case, make sure the sentences are already split, with a padding token (<pad>) between them.
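
For illustration, this is roughly what that pre-splitting looks like if you do it yourself before disabling the built-in splitting (the helper below is a sketch of mine, not part of the repository):

import nltk

nltk.download("punkt", quiet=True)  # NLTK's default English sentence splitter

def presplit(text: str, pad_token: str = "<pad>") -> str:
    """Split a raw sample into sentences and re-join them with the padding
    token, the format expected when --no_sentence_splitting is used."""
    return pad_token.join(nltk.sent_tokenize(text, language="english"))

print(presplit("First sentence. Second sentence."))
# First sentence.<pad>Second sentence.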

Default values

The defaults are set to the original BART arguments. This differs from the Flax defaults in one respect: poisson_lambda is now set to 3.5 instead of 3.0.
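
For context, poisson_lambda controls the span lengths sampled for text infilling. A minimal NumPy sketch of what that parameter does (not the script's actual code):

import numpy as np

rng = np.random.default_rng(seed=0)
poisson_lambda = 3.5  # default here; the Flax example uses 3.0

# Span lengths for text infilling are drawn from a Poisson distribution,
# so masked spans will on average be roughly 3.5 tokens long.
span_lengths = rng.poisson(lam=poisson_lambda, size=8)
print(span_lengths)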

HF (Flax), fairseq, and current implementation

There are some differences in implementation between fairseq, the HF Flax example, and this PyTorch implementation.

  • The argwhere in the Flax example does not do the same thing as the corresponding code in fairseq. In fairseq, we explicitly check that the previous token was not a "full stop" (padding token), whereas the HF example only checks whether the current token is a full stop. In the current implementation I also explicitly check that the next token is not a full stop, to account for padding. (In practice this should be a non-issue, since all batches/samples should have the same sequence length and there should be no padding.)
  • I found that the result of sentence permutation was not consistent in terms of where the separating pad token ended up (bug report), so I reimplemented that method so that sentences in a sequence remain separated by a padding token even after permutation (see the sketch after this list).
  • In the HF Flax example, the token_mask is restricted to non-special and non-padding tokens. In fairseq, by default, only the first and last tokens are excluded and all other tokens can be masked. The HF approach seems sensible, so I follow it. Note that get_special_tokens_mask already includes the padding token, so there is no need to add it separately.
  • The Flax example does not include fairseq's methods for adding extra noise; I have ported those as well.
  • However, I did not adapt add_insertion_noise to work well with padded sequences, so the inserted noise may occur anywhere in the sequence. It is unclear whether this is intended behavior.
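
To make the pad-preserving permutation concrete, here is a minimal sketch of the idea (illustrative only; the collator's actual implementation differs in details such as handling BOS/EOS):

import torch

def permute_sentences(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Illustrative: split a 1D sequence on the padding token, shuffle the
    sentences, and re-join them so a pad separator stays between sentences."""
    sentences, current = [], []
    for tok in input_ids.tolist():
        if tok == pad_token_id:
            if current:
                sentences.append(current)
            current = []
        else:
            current.append(tok)
    if current:
        sentences.append(current)
    order = torch.randperm(len(sentences)).tolist()
    permuted = []
    for i, idx in enumerate(order):
        if i > 0:
            permuted.append(pad_token_id)  # keep the separator between sentences
        permuted.extend(sentences[idx])
    return torch.tensor(permuted, dtype=input_ids.dtype)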

Alternatively, we could implement all this processing at the dataset level and use Dataset.map (a minimal sketch follows the lists below). This has some advantages:

  • more true to the fairseq implementation (sample level rather than batch level);
  • cached.

... and disadvantages:

  • potentially slower (not batched), although a batched approach could be integrated; as discussed above, though, that would be less true to the original fairseq implementation of add_insertion_noise;
  • every sample is always processed in the same way, so for small datasets that the model sees multiple times, it will encounter the exact same processed sample every epoch. With a data collator that is not the case, because the processing happens on every iteration rather than once before training.
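
A rough sketch of what the dataset-level alternative could look like with datasets.Dataset.map; the noising function here is a stand-in of my own, not the collator's real logic:

import random
from datasets import Dataset

def add_noise(examples, mask_token_id=4, mask_prob=0.15):
    # Stand-in for the real noising (permutation, masking, insertion noise):
    # here we just randomly replace tokens with a mask id.
    examples["input_ids"] = [
        [mask_token_id if random.random() < mask_prob else tok for tok in ids]
        for ids in examples["input_ids"]
    ]
    return examples

# Toy pre-tokenized dataset.
dataset = Dataset.from_dict({"input_ids": [[0, 11, 12, 13, 2], [0, 21, 22, 2]]})

# The mapping runs once up front and is cached: every epoch then sees the
# exact same noised version of each sample, unlike a collator, which
# re-noises the batch on every iteration.
noised = dataset.map(add_noise, batched=True)
print(noised[0]["input_ids"])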

Questions/Uncertainties

  • Do the padding tokens still serve a purpose after permutation (teaching the model to detect sentence boundaries)?
  • It seems that add_insertion_noise can insert noise anywhere, which means it may also overwrite special tokens, and sequences then do not necessarily end with an EOS token. Is that a problem?