This repository implements a variety of deep learning models for text classification using the Keras framework, including FastText, TextCNN, TextRNN, TextBiRNN, TextAttBiRNN, HAN, RCNN, and RCNNVariant. In addition to the model implementations, a simplified application is included.
- Python 3.6
- NumPy 1.15.2
- Keras 2.2.0
- TensorFlow 1.8.0
All code is located in the `/model` directory, and each model has its own subdirectory containing the model and its application.
For example, the model and application for FastText are located under `/model/FastText`: the model is in `fast_text.py`, and the application is in `main.py`.
FastText was proposed in the paper Bag of Tricks for Efficient Text Classification.
- Using a look-up table, bags of n-grams are converted to word representations.
- Word representations are averaged into a text representation, which is a hidden variable.
- Text representation is in turn fed to a linear classifier.
- Use the softmax function to compute the probability distribution over the predefined classes.
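The four steps above can be sketched in plain NumPy. This is a minimal illustration of the forward pass only, not the repository's Keras model; the function name `fasttext_forward`, the toy sizes, and the random weights are assumptions made for the example.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def fasttext_forward(token_ids, embedding, W, b):
    """Hypothetical helper: one forward pass for a single text.

    token_ids: (seq_len,) integer ids of words / n-grams
    embedding: (vocab_size, embed_dim) look-up table
    W, b:      linear classifier, shapes (embed_dim, n_classes) and (n_classes,)
    """
    vectors = embedding[token_ids]      # look-up: (seq_len, embed_dim)
    text_repr = vectors.mean(axis=0)    # average into one text representation
    logits = text_repr @ W + b          # linear classifier
    return softmax(logits)              # distribution over the classes

# Toy sizes and random weights, chosen only for illustration.
rng = np.random.RandomState(0)
embedding = rng.normal(size=(50, 8))
W = rng.normal(size=(8, 3))
b = np.zeros(3)
probs = fasttext_forward(np.array([3, 17, 42, 7]), embedding, W, b)
```

In a trained model the look-up table and the classifier weights are learned jointly; here they are random, so only the shapes and the order of operations are meaningful.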
Network structure of FastText:
TextCNN was proposed in the paper Convolutional Neural Networks for Sentence Classification.
- Represent sentence with static and non-static channels.
- Convolve with multiple filter widths and feature maps.
- Use max-over-time pooling.
- Use a fully connected layer with dropout and softmax output.
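The convolution and max-over-time pooling steps can be sketched in NumPy as follows. The sketch assumes a single embedding channel, a ReLU non-linearity, filter widths of 3, 4, and 5, and no dropout (as at inference time); the helper name `conv_max_over_time` and all sizes are hypothetical and not taken from the repository.

```python
import numpy as np

def conv_max_over_time(x, filters, bias):
    """Convolve one filter width over the sequence, then max-pool over time.

    x:       (seq_len, embed_dim) embedded sentence
    filters: (width, embed_dim, n_maps)
    bias:    (n_maps,)
    returns: (n_maps,) one pooled value per feature map
    """
    width = filters.shape[0]
    n_windows = x.shape[0] - width + 1
    feature_map = np.stack([
        np.tensordot(x[i:i + width], filters, axes=([0, 1], [0, 1])) + bias
        for i in range(n_windows)
    ])                                            # (n_windows, n_maps)
    feature_map = np.maximum(feature_map, 0.0)    # ReLU (assumed)
    return feature_map.max(axis=0)                # max-over-time pooling

rng = np.random.RandomState(0)
seq_len, embed_dim, n_maps, n_classes = 10, 6, 4, 2
x = rng.normal(size=(seq_len, embed_dim))

# One pooled vector per filter width, concatenated into a single feature vector.
pooled = []
for width in (3, 4, 5):
    filters = rng.normal(size=(width, embed_dim, n_maps))
    pooled.append(conv_max_over_time(x, filters, np.zeros(n_maps)))
features = np.concatenate(pooled)                 # (3 * n_maps,)

# Fully connected layer followed by softmax.
W = rng.normal(size=(features.size, n_classes))
logits = features @ W
probs = np.exp(logits - logits.max())
probs /= probs.sum()
```

Because pooling takes the maximum over all window positions, the length of `features` depends only on the number of filter widths and feature maps, not on the sentence length.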
Network structure of TextCNN:
TextRNN has been mentioned in the paper Recurrent Neural Network for Text Classification with Multi-Task Learning.
Network structure of TextRNN:
TextBiRNN is an improved model based on TextRNN. It replaces the RNN layer in the network with a bidirectional RNN layer, so that both the forward and the backward encoding information are taken into account. No related paper has been found yet.
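To make the idea of combining forward and backward encodings concrete, here is a small NumPy sketch. It uses a plain tanh recurrent cell for brevity, which is an assumption; the function names and sizes are hypothetical, and the repository's Keras layers may use a different cell type.

```python
import numpy as np

def simple_rnn(x, W_x, W_h, b):
    """Run a tanh RNN over x (seq_len, in_dim); return all states (seq_len, hid_dim)."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in x:
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        states.append(h)
    return np.stack(states)

def bidirectional(x, fwd_params, bwd_params):
    """Concatenate forward states with backward states at every time step."""
    forward = simple_rnn(x, *fwd_params)
    # Process the reversed sequence, then flip back so positions line up.
    backward = simple_rnn(x[::-1], *bwd_params)[::-1]
    return np.concatenate([forward, backward], axis=-1)

rng = np.random.RandomState(0)
seq_len, in_dim, hid_dim = 5, 3, 4

def init_params():
    return (rng.normal(size=(in_dim, hid_dim)),
            rng.normal(size=(hid_dim, hid_dim)),
            np.zeros(hid_dim))

fwd, bwd = init_params(), init_params()
x = rng.normal(size=(seq_len, in_dim))
states = bidirectional(x, fwd, bwd)   # (seq_len, 2 * hid_dim)
```

The forward half of `states[t]` has seen only inputs up to position `t`, while the backward half has seen only inputs from position `t` onward, so together they summarize the whole sequence around each position.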
Network structure of TextBiRNN:
TextAttBiRNN is an improved model that adds an attention mechanism to TextBiRNN. Given the representation vectors produced by the bidirectional RNN encoder, the attention mechanism lets the model focus on the information most relevant to the decision. The attention mechanism was first proposed in the paper Neural Machine Translation by Jointly Learning to Align and Translate; the implementation here follows the paper Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems.
In the paper Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems, the feed forward attention is simplified as follows,
The function `a`, a learnable function, is taken to be a feed-forward network. In this formulation, attention can be seen as producing a fixed-length embedding `c` of the input sequence by computing an adaptive weighted average of the state sequence `h`.
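For reference, the formulation in that paper can be written as follows, where $T$ is the sequence length, $e_t$ is the unnormalized score of state $h_t$, and $\alpha_t$ is its normalized weight. The equations are transcribed from the cited paper rather than from this repository's code.

$$
e_t = a(h_t), \qquad \alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad c = \sum_{t=1}^{T} \alpha_t h_t
$$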
The implementation of attention is not described here; please refer to the source code directly.
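As rough orientation before reading the source, the adaptive weighted average described above can be sketched in NumPy. The choice of `a` as a single tanh unit with weight vector `w` and scalar bias `b` is an assumption made for illustration, and `feed_forward_attention` is a hypothetical name; the repository's Keras layer may differ in detail.

```python
import numpy as np

def feed_forward_attention(h, w, b):
    """Collapse a state sequence into one fixed-length vector.

    h: (T, d) state sequence, e.g. bidirectional RNN outputs
    w: (d,) weight vector and b: scalar bias of the scoring function a (assumed form)
    returns: (c, alpha) with c of shape (d,) and alpha of shape (T,)
    """
    e = np.tanh(h @ w + b)               # score per time step, e_t = a(h_t)
    weights = np.exp(e - e.max())        # subtract max for numerical stability
    alpha = weights / weights.sum()      # softmax over time steps
    c = alpha @ h                        # weighted average of the states
    return c, alpha

rng = np.random.RandomState(0)
T, d = 6, 4
h = rng.normal(size=(T, d))
c, alpha = feed_forward_attention(h, rng.normal(size=d), 0.0)
```

Since the weights are non-negative and sum to one, `c` is a convex combination of the states and has the same dimensionality as a single state regardless of the sequence length.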
Network structure of TextAttBiRNN:
HAN was proposed in the paper Hierarchical Attention Networks for Document Classification.
- Word Encoder. A bidirectional GRU encodes each sentence; the annotation for a given word is obtained by concatenating its forward and backward hidden states, which summarizes the information of the whole sentence centered around that word.
- Word Attention. A one-layer MLP followed by a softmax function computes normalized importance weights over the word annotations. The sentence vector is then computed as the weighted sum of the word annotations.
- Sentence Encoder. As in the word encoder, a bidirectional GRU encodes the sentence vectors to obtain an annotation for each sentence.
- Sentence Attention. As in word attention, a one-layer MLP and a softmax function produce weights over the sentence annotations. The document vector is the weighted sum of the sentence annotations.
- Document Classification. Use the softmax function to calculate the probability of all classes.
The implementation of attention here is based on FeedForwardAttention, which is the same as the attention in TextAttBiRNN.
Network structure of HAN:
The TimeDistributed wrapper is used here because the parameters of the Embedding, Bidirectional RNN, and Attention layers are expected to be shared across the time-step dimension.
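The two attention levels and the parameter sharing can be illustrated with a short NumPy sketch. To keep it small, the GRU encoders are omitted: word annotations are assumed to be given, and sentence vectors are used directly as sentence annotations. All names and sizes are hypothetical.

```python
import numpy as np

def attention_pool(h, w, b):
    """Attention-weighted sum of annotations h (T, d); returns a vector of shape (d,)."""
    e = np.tanh(h @ w + b)
    alpha = np.exp(e - e.max())
    alpha = alpha / alpha.sum()
    return alpha @ h

def han_document_vector(word_annotations, word_att, sent_att):
    """word_annotations: (n_sents, n_words, d), assumed to come from a word encoder."""
    # The same word-level parameters are applied to every sentence; this is the
    # sharing that the TimeDistributed wrapper provides in Keras.
    sentence_vectors = np.stack(
        [attention_pool(sentence, *word_att) for sentence in word_annotations]
    )                                                    # (n_sents, d)
    # Sentence encoder omitted: sentence vectors stand in for sentence annotations.
    return attention_pool(sentence_vectors, *sent_att)   # (d,)

rng = np.random.RandomState(0)
n_sents, n_words, d = 3, 5, 4
word_annotations = rng.normal(size=(n_sents, n_words, d))
word_att = (rng.normal(size=d), 0.0)
sent_att = (rng.normal(size=d), 0.0)
doc_vec = han_document_vector(word_annotations, word_att, sent_att)
```

A final softmax classifier would take `doc_vec` as input. The point of the sketch is that one set of word-level parameters serves all sentences, while a second set operates one level up.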
RCNN was proposed in the paper Recurrent Convolutional Neural Networks for Text Classification.
- Word Representation Learning. RCNN uses a recurrent structure, a bidirectional recurrent neural network, to capture the contexts. It then combines each word with its left and right contexts to represent the word, and applies a linear transformation together with the `tanh` activation function to that representation.
- Text Representation Learning. Once the representations of all words have been computed, an element-wise max-pooling layer captures the most important information across the entire text. Finally, a linear transformation and the softmax function are applied.
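The two stages above can be sketched in NumPy. The sketch assumes tanh as the non-linearity of the context recurrence and a zero vector as the initial context; the function names, parameter names, and sizes are hypothetical and are not the repository's Keras code.

```python
import numpy as np

def contexts(emb, W_c, W_e):
    """Context of word i, built from the context and embedding of word i - 1."""
    ctx = np.zeros((emb.shape[0], W_c.shape[0]))   # first context is zero (assumed)
    for i in range(1, emb.shape[0]):
        ctx[i] = np.tanh(ctx[i - 1] @ W_c + emb[i - 1] @ W_e)
    return ctx

def rcnn_forward(emb, left, right, W2, b2, W4, b4):
    c_left = contexts(emb, *left)                       # scan left to right
    c_right = contexts(emb[::-1], *right)[::-1]         # scan right to left
    x = np.concatenate([c_left, emb, c_right], axis=1)  # word plus its contexts
    y = np.tanh(x @ W2 + b2)                            # linear transformation + tanh
    pooled = y.max(axis=0)                              # element-wise max over words
    logits = pooled @ W4 + b4                           # linear transformation
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()                          # softmax

rng = np.random.RandomState(0)
n_words, emb_dim, ctx_dim, hid_dim, n_classes = 7, 5, 3, 6, 2
emb = rng.normal(size=(n_words, emb_dim))
left = (rng.normal(size=(ctx_dim, ctx_dim)), rng.normal(size=(emb_dim, ctx_dim)))
right = (rng.normal(size=(ctx_dim, ctx_dim)), rng.normal(size=(emb_dim, ctx_dim)))
W2 = rng.normal(size=(2 * ctx_dim + emb_dim, hid_dim))
W4 = rng.normal(size=(hid_dim, n_classes))
probs = rcnn_forward(emb, left, right, W2, np.zeros(hid_dim), W4, np.zeros(n_classes))
```

The max-pooling step turns a variable number of word representations into one fixed-size vector, which is why the final linear layer can have a fixed input size.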
Network structure of RCNN:
RCNNVariant is an improved model based on RCNN with the following improvements. No related papers have been found yet.
- The three inputs are reduced to a single input; the left-context and right-context inputs are removed.
- Use bidirectional LSTM/GRU instead of traditional RNN for encoding context.
- Use multi-channel CNN to represent the semantic vectors.
- Replace the Tanh activation layer with the ReLU activation layer.
- Use both AveragePooling and MaxPooling.
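As an illustration of the last two points, the following NumPy sketch applies ReLU to a feature map and then combines average pooling and max pooling over the time dimension. Concatenating the two pooled vectors is an assumption about how they are combined, and `avg_and_max_pool` is a hypothetical name.

```python
import numpy as np

def avg_and_max_pool(feature_map):
    """feature_map: (seq_len, n_maps) -> (2 * n_maps,) pooled vector."""
    activated = np.maximum(feature_map, 0.0)      # ReLU instead of tanh
    return np.concatenate([
        activated.mean(axis=0),                   # average pooling over time
        activated.max(axis=0),                    # max pooling over time
    ])

feature_map = np.array([[1.0, -2.0],
                        [3.0,  4.0]])
pooled = avg_and_max_pool(feature_map)            # [2.0, 2.0, 3.0, 4.0]
```

Average pooling keeps information about the overall activation level, while max pooling keeps the strongest response, so the two are complementary.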
Network structure of RCNNVariant:
- Bag of Tricks for Efficient Text Classification
- Keras Example IMDB FastText
- Convolutional Neural Networks for Sentence Classification
- Keras Example IMDB CNN
- Recurrent Neural Network for Text Classification with Multi-Task Learning
- Neural Machine Translation by Jointly Learning to Align and Translate
- Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems
- cbaziotis's Attention
- Hierarchical Attention Networks for Document Classification
- Richard's HAN
- Recurrent Convolutional Neural Networks for Text Classification
- airalcorn2's RCNN