This repository contains the code and data sources for the paper "How to Fine-Tune BERT for Text Classification?".
In this paper, we conduct exhaustive experiments to investigate different fine-tuning methods of BERT on text classification tasks and provide a general solution for BERT fine-tuning.
*********** update on Mar 14, 2020 *************
Our checkpoints can be loaded with BertEmbedding from the latest fastNLP package.
Link to fastNLP.embeddings.BertEmbedding
For further pre-training, we borrow some code from Google BERT. Thus, we need:
- tensorflow==1.1x
- spacy
- pandas
- numpy
For fine-tuning, we borrow some code from the pytorch-pretrained-bert package (now known as transformers). Thus, we need:
- torch>=0.4.1,<=1.2.0
For the Sogou News dataset, we determine the category of each news article from its URL, e.g. "sports" corresponds to "http://sports.sohu.com". We choose 6 categories: "sports", "house", "business", "entertainment", "women" and "technology". For each class, we select 9,000 training samples and 1,000 test samples.
The data is available here.
The remaining datasets were built by Zhang et al. (2015); we downloaded them from the URL provided by Xiang Zhang.
Here we use AG's News as an example:
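The URL-based labelling can be sketched as follows. This is a minimal illustration, not the script used for the paper: only the "sports" → sports.sohu.com mapping is stated here, so the assumption that every other category's subdomain equals its category name is hypothetical, as is the function name category_from_url.

```python
from urllib.parse import urlparse

# The 6 categories kept for Sogou News.
CATEGORIES = {"sports", "house", "business", "entertainment", "women", "technology"}

def category_from_url(url):
    """Return the category implied by a sohu.com subdomain, or None.

    ASSUMPTION: the subdomain equals the category name for every class;
    only "sports" -> sports.sohu.com is confirmed by the README.
    """
    host = urlparse(url).hostname or ""
    if not host.endswith(".sohu.com"):
        return None
    # "sports.sohu.com" -> "sports"; keep the label closest to sohu.com
    sub = host[: -len(".sohu.com")].split(".")[-1]
    return sub if sub in CATEGORIES else None
```

With 9,000 training and 1,000 test samples per class, the 6 categories give 54,000 training and 6,000 test samples in total.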
python generate_corpus_agnews.py
The file agnews_corpus_test.txt can be found in the ./data directory.
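Google BERT's create_pretraining_data.py reads plain text with one sentence per line and a blank line between documents. The sketch below shows that target format; it is not generate_corpus_agnews.py itself, and the naive regex sentence splitter is an assumption made to keep the example self-contained (spacy, listed in the requirements, would give better segmentation). The function name to_pretraining_corpus is hypothetical.

```python
import re

def to_pretraining_corpus(documents):
    """Format documents for BERT's create_pretraining_data.py:
    one sentence per line, documents separated by a blank line."""
    blocks = []
    for doc in documents:
        # Naive split after ., ! or ? followed by whitespace (illustrative only).
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", doc.strip())]
        sentences = [s for s in sentences if s]
        if sentences:  # skip empty documents
            blocks.append("\n".join(sentences))
    return "\n\n".join(blocks) + "\n"
```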
python create_pretraining_data.py \
--input_file=./AGnews_corpus.txt \
--output_file=tmp/tf_AGnews.tfrecord \
--vocab_file=./uncased_L-12_H-768_A-12/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
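To see how masked_lm_prob and max_predictions_per_seq interact, here is a simplified sketch of the masking step, modelled on the formula in Google BERT's create_pretraining_data.py (whole-word masking and other details are omitted). For a full 128-token sequence, round(128 × 0.15) = 19 positions are masked, which fits under the cap of 20; dupe_factor=5 repeats this masking 5 times per document with different random masks. Function names here are illustrative.

```python
import random

def num_masked_predictions(num_tokens, masked_lm_prob=0.15, max_predictions_per_seq=20):
    # Number of positions to mask: ~15% of the tokens, at least 1, at most the cap.
    return min(max_predictions_per_seq, max(1, int(round(num_tokens * masked_lm_prob))))

def mask_tokens(tokens, vocab, rng, masked_lm_prob=0.15, max_predictions_per_seq=20):
    """Return (masked tokens, masked positions, original labels)."""
    candidates = [i for i, t in enumerate(tokens) if t not in ("[CLS]", "[SEP]")]
    rng.shuffle(candidates)
    n = num_masked_predictions(len(tokens), masked_lm_prob, max_predictions_per_seq)
    positions = sorted(candidates[:n])
    output = list(tokens)
    for i in positions:
        if rng.random() < 0.8:        # 80%: replace with [MASK]
            output[i] = "[MASK]"
        elif rng.random() < 0.5:      # 10%: keep the original token
            pass
        else:                         # 10%: replace with a random vocabulary token
            output[i] = rng.choice(vocab)
    return output, positions, [tokens[i] for i in positions]
```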
python run_pretraining.py \
--input_file=./tmp/tf_AGnews.tfrecord \
--output_dir=./uncased_L-12_H-768_A-12_AGnews_pretrain \
--do_train=True \
--do_eval=True \
--bert_config_file=./uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=./uncased_L-12_H-768_A-12/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=100000 \
--num_warmup_steps=10000 \
--save_checkpoints_steps=10000 \
--learning_rate=5e-5
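Google BERT's optimizer applies linear warmup over num_warmup_steps, then linear decay to zero at num_train_steps. The sketch below reproduces that schedule as a plain function (the function name is illustrative). Note that the decay is measured from step 0, so at the end of warmup (step 10,000) the rate is 5e-5 × (1 − 10,000/100,000) = 4.5e-5 rather than the full 5e-5.

```python
def bert_learning_rate(step, init_lr=5e-5, num_train_steps=100000, num_warmup_steps=10000):
    """Learning rate at a given global step for the run_pretraining.py settings above."""
    if num_warmup_steps and step < num_warmup_steps:
        # Linear warmup from 0 to init_lr.
        return init_lr * step / num_warmup_steps
    # Linear (power-1 polynomial) decay from init_lr at step 0 to 0 at num_train_steps.
    step = min(step, num_train_steps)
    return init_lr * (1.0 - step / num_train_steps)
```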
python convert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/model.ckpt-100000 \
--bert_config_file ./uncased_L-12_H-768_A-12_AGnews_pretrain/bert_config.json \
--pytorch_dump_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/pytorch_model.bin
While fine-tuning on downstream tasks, we noticed that different GPUs (e.g., 1080Ti and Titan Xp) may produce slightly different results even with a fixed random seed. Here we use four 1080Ti GPUs as an example.
Taking Exp-I (see Section 5.3) as an example:
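Fixing the seed in a PyTorch fine-tuning script typically looks like the sketch below, which follows the pattern used in pytorch-pretrained-bert's example scripts. It is an illustration rather than this repository's exact code; the torch import is guarded only so the sketch runs without PyTorch. Seeding does not make every GPU kernel deterministic, which is why results can still differ across hardware.

```python
import random

import numpy as np

def set_seed(seed, n_gpu=0):
    """Seed Python, NumPy and (if installed) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; only Python and NumPy are seeded
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)  # seed every visible GPU
```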
export CUDA_VISIBLE_DEVICES=0,1,2,3
python run_classifier_single_layer.py \
--task_name imdb \
--do_train \
--do_eval \
--do_lower_case \
--data_dir ./IMDB_data/ \
--vocab_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/vocab.txt \
--bert_config_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/bert_config.json \
--init_checkpoint ./uncased_L-12_H-768_A-12_IMDB_pretrain/pytorch_model.bin \
--max_seq_length 512 \
--train_batch_size 24 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir ./imdb \
--seed 42 \
--layers 11 10 \
--trunc_medium -1
where num_train_epochs can be 3.0, 4.0, or 6.0.
layers indicates the list of layers whose outputs are used as features for classification. -2 means using the pooled output and -1 means concatenating all layers; the command above concatenates layer-10 and layer-11 (the last two layers).
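The layers convention can be illustrated with NumPy arrays standing in for encoder outputs. This is a sketch under stated assumptions: each layer's feature is taken at the [CLS] position (index 0), features are concatenated in the order listed, and -2 / -1 are treated as sole values of the list. The function name select_features is hypothetical and not the script's API.

```python
import numpy as np

def select_features(layer_outputs, pooled_output, layers):
    """layer_outputs: list of (seq_len, hidden) arrays, one per encoder layer.
    ASSUMPTION: per-layer features come from the [CLS] position (index 0)."""
    if layers == [-2]:
        return pooled_output                       # use the pooled output
    if layers == [-1]:
        layers = list(range(len(layer_outputs)))   # concatenate all layers
    return np.concatenate([layer_outputs[l][0] for l in layers], axis=-1)
```

With 12 layers of hidden size 768, layers 11 10 yields a 1,536-dimensional feature and -1 yields 9,216 dimensions.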
trunc_medium controls how long texts are truncated: -2 means head-only, -1 means tail-only, 0 means head-half + tail-half (e.g., head-256 + tail-256), and any other natural number k means head-k + tail-rest (i.e., head-k + tail-(512-k)).
There are also other arguments for fine-tuning:
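The four truncation modes can be written as one function over a token list. This sketch takes the token budget as a parameter; in BERT, [CLS] and [SEP] occupy two of the 512 positions, so an actual implementation may use a budget of 510 instead. The name truncate_tokens is illustrative.

```python
def truncate_tokens(tokens, max_tokens, trunc_medium):
    """Shorten tokens to at most max_tokens following the trunc_medium convention."""
    if len(tokens) <= max_tokens:
        return list(tokens)
    if trunc_medium == -2:          # head-only
        return tokens[:max_tokens]
    if trunc_medium == -1:          # tail-only
        return tokens[-max_tokens:]
    # 0 -> head-half + tail-half; k > 0 -> head-k + tail-(max_tokens - k)
    head = max_tokens // 2 if trunc_medium == 0 else min(trunc_medium, max_tokens)
    tail = max_tokens - head
    return tokens[:head] + tokens[len(tokens) - tail:]
```

For a 1,000-token text and a budget of 512, trunc_medium=128 keeps tokens 0 to 127 and the last 384 tokens.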
pooling_type indicates which feature is used for classification: mean means mean-pooling over the hidden states of the whole sequence, max means max-pooling, and default means taking the hidden state of the [CLS] token as the feature.
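The three pooling_type options can be sketched for a single sequence as follows. The attention mask (1 for real tokens, 0 for padding) is an assumption of this sketch: it excludes padding from the mean and max, and the script may handle padding differently. The function name pool is illustrative.

```python
import numpy as np

def pool(hidden, mask, pooling_type="default"):
    """hidden: (seq_len, hidden_size) array; mask: (seq_len,) of 1 (token) / 0 (padding)."""
    real = hidden[np.asarray(mask, dtype=bool)]  # keep only non-padding positions
    if pooling_type == "mean":
        return real.mean(axis=0)                 # mean-pooling over the sequence
    if pooling_type == "max":
        return real.max(axis=0)                  # max-pooling over the sequence
    return hidden[0]                             # default: hidden state of [CLS]
```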
layer_learning_rate and layer_learning_rate_decay in run_classifier_discriminative.py control the layer-wise decreasing learning rate (see Section 5.3.4).
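Layer-wise decreasing learning rates follow the rule η^(k-1) = ξ · η^k: each layer uses the rate of the layer above multiplied by a decay factor ξ. The sketch below computes one rate per encoder layer. It assumes that layer_learning_rate is the top layer's rate and layer_learning_rate_decay is ξ; how the script handles the embedding layer and the classifier head is not covered. The example values 2e-5 and 0.95 are illustrative only.

```python
def layer_learning_rates(top_lr, decay, num_layers=12):
    """Per-layer learning rates, index 0 = lowest encoder layer.

    The top layer gets top_lr; each lower layer gets decay times the rate above it.
    """
    return [top_lr * decay ** (num_layers - 1 - k) for k in range(num_layers)]

lrs = layer_learning_rates(2e-5, 0.95)
print(f"top: {lrs[-1]:.3e}, bottom: {lrs[0]:.3e}")
```

In PyTorch, such rates would usually be applied through optimizer parameter groups, one group per encoder layer, each with its own "lr".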
We have uploaded the IMDb-based further pre-trained checkpoints here.
For other checkpoints, please contact us by e-mail.
@inproceedings{sun2019fine,
title={How to fine-tune {BERT} for text classification?},
author={Sun, Chi and Qiu, Xipeng and Xu, Yige and Huang, Xuanjing},
booktitle={China National Conference on Chinese Computational Linguistics},
pages={194--206},
year={2019},
organization={Springer}
}