Skip to content

Code and source for paper ``How to Fine-Tune BERT for Text Classification?``

License

Notifications You must be signed in to change notification settings

fotstrt/BERT4doc-Classification

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

7 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

How to Fine-Tune BERT for Text Classification?

This is the code and source for the paper How to Fine-Tune BERT for Text Classification?

In this paper, we conduct exhaustive experiments to investigate different fine-tuning methods of BERT on text classification task and provide a general solution for BERT fine-tuning.

*********** update at Mar 14, 2020 *************

Our checkpoint can be loaded in BertEmbedding from the latest fastNLP package.

Link to fastNLP.embeddings.BertEmbedding

Requirements

For further pre-training, we borrow some code from Google BERT. Thus, we need:

  • tensorflow==1.1x
  • spacy
  • pandas
  • numpy

For fine-tuning, we borrow some codes from pytorch-pretrained-bert package (now well known as transformers). Thus, we need:

  • torch>=0.4.1,<=1.2.0

Run the code

1) Prepare the data set:

Sogou News

We determine the category of the news based on the URL, such as “sports” corresponding to “http://sports.sohu.com”. We choose 6 categories – “sports”, “house”, “business”, “entertainment”, “women” and “technology”. The number of training samples selected for each class is 9,000 and testing 1,000.

Data is available at here.

The rest data sets

The rest data sets were built by Zhang et al. (2015). We download from URL created by Xiang Zhang.

2) Prepare Google BERT:

BERT-Base, Uncased

BERT-Base, Chinese

3) Further Pre-Training:

Generate Further Pre-Training Corpus

Here we use AG's News as example:

python generate_corpus_agnews.py

File agnews_corpus_test.txt can be found in directory ./data.

Run Further Pre-Training

python create_pretraining_data.py \
  --input_file=./AGnews_corpus.txt \
  --output_file=tmp/tf_AGnews.tfrecord \
  --vocab_file=./uncased_L-12_H-768_A-12/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5
  
python run_pretraining.py \
  --input_file=./tmp/tf_AGnews.tfrecord \
  --output_dir=./uncased_L-12_H-768_A-12_AGnews_pretrain \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=./uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=./uncased_L-12_H-768_A-12/bert_model.ckpt \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=100000 \
  --num_warmup_steps=10000 \
  --save_checkpoints_steps=10000 \
  --learning_rate=5e-5

4) Fine-Tuning

Convert Tensorflow checkpoint to PyTorch checkpoint

python convert_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/model.ckpt-100000 \
  --bert_config_file ./uncased_L-12_H-768_A-12_AGnews_pretrain/bert_config.json \
  --pytorch_dump_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/pytorch_model.bin

Fine-Tuning on downstream tasks

While fine-tuning on downstream tasks, we notice that different GPU (e.g.: 1080Ti and Titan Xp) may cause slight differences in experimental results even though we fix the initial random seed. Here we use 1080Ti * 4 as example.

Take Exp-I (See Section 5.3) as example,

export CUDA_VISIBLE_DEVICES=0,1,2,3
python run_classifier_single_layer.py \
  --task_name imdb \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir ./IMDB_data/ \
  --vocab_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/vocab.txt \
  --bert_config_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/bert_config.json \
  --init_checkpoint ./uncased_L-12_H-768_A-12_IMDB_pretrain/pytorch_model.bin \
  --max_seq_length 512 \
  --train_batch_size 24 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir ./imdb \
  --seed 42 \
  --layers 11 10 \
  --trunc_medium -1

where num_train_epochs can be 3.0, 4.0, or 6.0.

layers indicates list of layers which will be taken as feature for classification. -2 means use pooled output, -1 means concat all layer, the command above means concat layer-10 and layer-11 (last two layers).

trunc_medium indicates dealing with long texts. -2 means head-only, -1 means tail-only, 0 means head-half + tail-half (e.g.: head256+tail256), other natural number k means head-k + tail-rest (e.g.: head-k + tail-(512-k)).

There also other arguments for fine-tuning:

pooling_type indicates which feature will be used for classification. mean means mean-pooling for hidden state of the whole sequence, max means max-pooling, default means taking hidden state of [CLS] token as features.

layer_learning_rate and layer_learning_rate_decay in run_classifier_discriminative.py indicates layer-wise decreasing layer rate (See Section 5.3.4).

Further Pre-Trained Checkpoints

We upload IMDb-based further pre-trained checkpoints at here.

For other checkpoints, please contact us by e-mail.

How to cite our paper

@inproceedings{sun2019fine,
  title={How to fine-tune {BERT} for text classification?},
  author={Sun, Chi and Qiu, Xipeng and Xu, Yige and Huang, Xuanjing},
  booktitle={China National Conference on Chinese Computational Linguistics},
  pages={194--206},
  year={2019},
  organization={Springer}
}

About

Code and source for paper ``How to Fine-Tune BERT for Text Classification?``

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%