1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
Follow the instructions here to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in this issue or check out the code here.
Follow the instructions here to download the original Extreme Summarization (XSum) datasets, or check out the code here. Keep the raw dataset, and make sure that neither tokenization nor BPE has been applied to it.
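Before BPE encoding, it can help to confirm that each split's source and target files are still line-aligned. The following is a minimal sketch, not part of fairseq: the helper name `check_parallel_split` is hypothetical, and it assumes one example per line in `<split>.source` and `<split>.target`, which is the layout the preprocessing commands in this README read from.

```python
from pathlib import Path


def check_parallel_split(data_dir, split):
    """Return the number of examples in a split.

    Hypothetical helper: raises ValueError if the source and target
    files of the split do not have the same number of lines.
    """
    src = Path(data_dir) / f"{split}.source"
    tgt = Path(data_dir) / f"{split}.target"
    with src.open(encoding="utf-8") as f:
        n_src = sum(1 for _ in f)
    with tgt.open(encoding="utf-8") as f:
        n_tgt = sum(1 for _ in f)
    if n_src != n_tgt:
        raise ValueError(f"{split}: {n_src} source lines but {n_tgt} target lines")
    return n_src
```

For example, `check_parallel_split("cnn_dm", "train")` would return the number of training pairs, or raise if the two files have drifted out of alignment.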
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
TASK=cnn_dm
for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "$TASK/$SPLIT.$LANG" \
      --outputs "$TASK/$SPLIT.bpe.$LANG" \
      --workers 60 \
      --keep-empty;
  done
done
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "${TASK}/train.bpe" \
  --validpref "${TASK}/val.bpe" \
  --destdir "${TASK}-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
Example: fine-tuning on CNN-DM
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
The command above is expected to run on 1 node with 8 32GB V100 GPUs. Expected training time is about 5 hours. Training time can be reduced with distributed training on 4 nodes and --update-freq 1.
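The trade-off between node count and --update-freq can be checked with a little arithmetic. The sketch below is illustrative only: the function name is hypothetical, it treats --max-tokens as a per-GPU upper bound on tokens per forward pass, and it assumes each of the 4 nodes also has 8 GPUs, which the text does not state.

```python
def tokens_per_update(max_tokens, update_freq, gpus_per_node, nodes=1):
    """Upper bound on the tokens that contribute to one optimizer step.

    Each GPU processes at most `max_tokens` tokens per forward pass, and
    gradients are accumulated over `update_freq` passes before stepping.
    """
    return max_tokens * update_freq * gpus_per_node * nodes


# Single-node setup from the command above: 8 GPUs, --update-freq 4.
single_node = tokens_per_update(max_tokens=2048, update_freq=4, gpus_per_node=8, nodes=1)

# Assumed 4-node setup: 8 GPUs per node, --update-freq 1.
four_nodes = tokens_per_update(max_tokens=2048, update_freq=1, gpus_per_node=8, nodes=4)

print(single_node, four_nodes)  # 65536 65536
```

Under these assumptions both configurations reach the same effective batch size per update, so the multi-node run reduces wall-clock time without changing the optimization setup.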
For the XSum task, use TOTAL_NUM_UPDATES=15000 and UPDATE_FREQ=2.
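To see how WARMUP_UPDATES and TOTAL_NUM_UPDATES shape the learning rate, here is a rough sketch of a polynomial-decay schedule with linear warmup. It is not fairseq's implementation: the function name is hypothetical, and it assumes a decay power of 1 and an end learning rate of 0, neither of which is set in the training command above.

```python
def polynomial_decay_lr(step, peak_lr, warmup_updates, total_updates,
                        end_lr=0.0, power=1.0):
    """Learning rate at a given update (sketch; power and end_lr are assumed)."""
    # Linear warmup from 0 to peak_lr.
    if warmup_updates > 0 and step <= warmup_updates:
        return peak_lr * step / warmup_updates
    # After the last scheduled update, stay at the end learning rate.
    if step >= total_updates:
        return end_lr
    # Polynomial decay from peak_lr to end_lr over the remaining updates.
    remaining = 1 - (step - warmup_updates) / (total_updates - warmup_updates)
    return (peak_lr - end_lr) * remaining ** power + end_lr


# Compare the CNN-DM and XSum settings at a few update counts.
for total in (20000, 15000):
    for step in (250, 500, total // 2, total):
        lr = polynomial_decay_lr(step, peak_lr=3e-05, warmup_updates=500, total_updates=total)
        print(total, step, lr)
```

With a shorter TOTAL_NUM_UPDATES the peak is the same, but the decay to the end learning rate is steeper.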
After training the model as described in the previous step, you can run inference with the checkpoints in the checkpoints/ directory using the following Python snippet:
import torch
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='cnn_dm-bin'
)
bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32  # number of source documents per generation batch
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Generate summaries once a full batch has been collected.
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []
        slines.append(sline.strip())
        count += 1
    # Generate summaries for the final, possibly partial batch.
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
For XSum generation, use beam=6, lenpen=1.0, max_len_b=60, min_len=10.
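One way to keep the two sets of generation settings apart is a small preset table plus a batching helper. This is only a sketch: `GENERATION_PRESETS` and `batched_lines` are hypothetical names, and carrying `no_repeat_ngram_size=3` over to XSum is an assumption, since only the other four values are given for that task.

```python
from itertools import islice

# Keyword arguments for bart.sample, keyed by task.
GENERATION_PRESETS = {
    "cnn_dm": dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3),
    # no_repeat_ngram_size for XSum is assumed to match the CNN-DM setting.
    "xsum": dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3),
}


def batched_lines(lines, bsz):
    """Yield lists of up to `bsz` stripped lines from an iterable of lines."""
    stripped = (line.strip() for line in lines)
    while True:
        batch = list(islice(stripped, bsz))
        if not batch:
            return
        yield batch


# Possible usage with the model loaded as in the snippet above:
# for batch in batched_lines(source, 32):
#     with torch.no_grad():
#         hypotheses_batch = bart.sample(batch, **GENERATION_PRESETS["xsum"])
```

Unlike the loop above, this helper yields nothing for an empty input file, so no empty string is passed to the model in that case.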