This is an alpha release of the training script. Please try it and let me know about any problems that you run into!
This is an example script for training BART, an encoder-decoder model that is pretrained with a denoising objective over tokens and spans.
The code builds on the HF Flax example; the data collator was ported from fairseq into a collator (i.e., batched) format rather than a dataset-level transform.
Inspired by the process here.
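To illustrate the objective (a made-up example, not drawn from the actual data or this collator's exact noise functions): spans of tokens are masked out and sentences may be reordered in the input, and the model learns to reconstruct the original text.

```python
# Made-up illustration of the BART denoising objective: a span is replaced by a
# mask token and the sentence order is permuted; the target is the clean text.
noised_input = "It was raining. I went to the <mask> yesterday."
target = "I went to the market yesterday. It was raining."
```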
First, train a tokenizer on the dataset:

```bash
python prepare_tokenizer.py \
    oscar \
    --dataset_config_name unshuffled_deduplicated_nl \
    --dataset_split train \
    --dout ./my-bart-model
```

Then, prepare a model config based on an existing checkpoint:

```bash
python prepare_config.py \
    --pretrained_model_name facebook/bart-base \
    --dout ./my-bart-model
```

Finally, run the training script:

```bash
python run_bart_dlm.py \
    --config_name ./my-bart-model \
    --tokenizer_name ./my-bart-model \
    --dataset_name oscar \
    --dataset_config_name unshuffled_deduplicated_nl \
    --output_dir ./my-bart-model \
    --do_train \
    --do_eval
```
As part of BART pretraining, the sentences in a sample may be permuted (reordered). To detect the sentences in each sample, we need sentence splitting. By default, we use NLTK's English Punkt sentence splitter, but by passing a spaCy model name to `spacy_model` (e.g. `en_core_web_sm`) you can also rely on spaCy for better (but slower) sentence splitting. You can also disable sentence splitting completely with `--no_sentence_splitting`. In that case, make sure the sentences are already split with a padding token (`<pad>`) between them.
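For illustration, this is roughly what the two splitting backends look like (the helper below is a hypothetical sketch, not the script's actual code):

```python
# Rough sketch of the two sentence-splitting backends (hypothetical helper,
# not the script's actual implementation).
from typing import List, Optional

import nltk

nltk.download("punkt", quiet=True)  # Punkt models are needed for sent_tokenize


def split_sentences(text: str, spacy_model: Optional[str] = None) -> List[str]:
    if spacy_model is not None:
        # Slower but generally more accurate sentence boundaries.
        import spacy

        nlp = spacy.load(spacy_model)
        return [sent.text.strip() for sent in nlp(text).sents]
    # Default: NLTK's English Punkt splitter.
    return nltk.sent_tokenize(text)


# When sentence splitting is disabled, the input is expected to already contain
# sentences separated by the padding token, e.g. "<pad>".join(sentences).
```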
The defaults are set to the arguments used by the original (fairseq) BART implementation. This differs from the Flax defaults in one respect, namely `poisson_lambda`, which is now set to `3.5` instead of `3.0`.
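For context, `poisson_lambda` controls the length distribution of masked spans in BART-style text infilling. A rough sketch of how such span lengths are drawn (illustrative only, not the collator's exact code):

```python
# Illustrative only: span lengths for text infilling are sampled from a Poisson
# distribution with mean `poisson_lambda` until a masking budget is reached.
import numpy as np

rng = np.random.default_rng(seed=0)
poisson_lambda = 3.5  # default here; the Flax example uses 3.0

seq_len, mask_ratio = 512, 0.3      # assumed example values
budget = int(seq_len * mask_ratio)  # number of tokens to mask in total

span_lengths = []
while sum(span_lengths) < budget:
    span_lengths.append(int(rng.poisson(poisson_lambda)))

print(span_lengths[:10])  # mix of short and longer spans (length-0 spans insert a single mask)
```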
There are some differences in implementation between fairseq, the HF FLAX example, and this PyTorch implementation.
- The `argwhere` at this position in the Flax example is not the same as what happens in fairseq. In fairseq, we explicitly check that the *previous* token was not a "full stop" (padding token), whereas the HF example only checks whether the *current* token is a full stop. In the current example I also explicitly check that the *next* token is not a full stop, in case of padding. (In practice that should be a non-issue, since all batches/samples should have the same sequence length and there should not be any padding.)
- I found that the result of sentence permutation was not consistent in terms of where the separating pad token ended up (bug report), so I have reimplemented that method so that sentences in a sequence are still separated by a padding token, even after permutation (see the sketch after this list).
- In the HF Flax example, the `token_mask` is restricted to non-special and non-padding tokens. In fairseq, by default, only the first and last tokens are excluded and all other tokens are eligible for masking. The HF implementation seems sensible, so I follow it. `get_special_tokens_mask` already includes the padding token, though, so there is no need to add that separately.
- The Flax example does not include the methods to add extra noise, so I have ported those from fairseq as well.
- However, I did not adapt `add_insertion_noise` to work well with padded sequences, so the inserted noise may occur anywhere in the sequence. It is unclear whether this is intended behavior.
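The reimplemented sentence permutation follows roughly this idea (a minimal sketch with made-up token IDs, not the collator's exact code):

```python
# Minimal sketch of pad-separated sentence permutation (made-up token IDs,
# not the collator's exact implementation).
import numpy as np


def permute_sentences(token_ids, pad_token_id, rng):
    # Split the flat sequence into sentences on the padding token.
    sentences, current = [], []
    for tok in token_ids:
        if tok == pad_token_id:
            sentences.append(current)
            current = []
        else:
            current.append(tok)
    if current:
        sentences.append(current)

    # Shuffle the sentence order.
    order = rng.permutation(len(sentences))

    # Re-join the sentences, keeping exactly one pad token between them.
    permuted = []
    for i, idx in enumerate(order):
        if i > 0:
            permuted.append(pad_token_id)
        permuted.extend(sentences[idx])
    return permuted


rng = np.random.default_rng(seed=0)
# Three "sentences" separated by the pad id 1 (BART's pad token id).
print(permute_sentences([0, 7, 8, 1, 9, 10, 1, 11, 2], pad_token_id=1, rng=rng))
```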
Alternatively, we could implement all of this processing at the dataset level and use `Dataset.map` (see the sketch after this list). This has some advantages:

- more true to the fairseq implementation (sample level rather than batch level);
- cached.

... and disadvantages:

- potentially slower (not batched), although we could integrate a batched approach, but as discussed above that would be less true to the original fairseq implementation of `add_insertion_noise`;
- every sample is always processed in the same way, so in small datasets that are seen multiple times by the model, each sample will always be noised identically. With a data collator that is not the case, because the processing happens on every iteration rather than once before training.
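A minimal sketch of that alternative (the `add_noise` function is a hypothetical placeholder; the actual noising logic lives in the data collator):

```python
# Hypothetical sketch of dataset-level noising with `Dataset.map`; `add_noise`
# is a placeholder for the text infilling / permutation / insertion logic.
from datasets import load_dataset


def add_noise(example):
    # Apply noise to a single (already tokenized) sample and return the
    # noised fields alongside the original labels.
    return example


dataset = load_dataset("oscar", "unshuffled_deduplicated_nl", split="train")
# Sample-level (non-batched) processing, closer to fairseq, and cached by `datasets`.
# The trade-off: every epoch then sees the identical noised copy of each sample.
noised_dataset = dataset.map(add_noise, batched=False)
```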
Some open questions remain:

- Do the padding tokens still serve a purpose after permutation? (Teaching the model to detect sentence boundaries?)
- It seems that `add_insertion_noise` can insert noise anywhere, which means that it may also overwrite special tokens and that sequences do not necessarily end with an EOS token. Is that a problem?