In an attempt to learn TensorFlow, I implemented the model from the paper "A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction" using TensorFlow 1.13.
- The Nasdaq data used for testing comes from the da-rnn repo.
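- Based on the discussion, I implemented both the case where the current exogenous factor is included, i.e. $\hat{y}_T = F(y_1, \ldots, y_{T-1}, \mathbf{x}_1, \ldots, \mathbf{x}_T)$, and the case where it is excluded, i.e. $\hat{y}_T = F(y_1, \ldots, y_{T-1}, \mathbf{x}_1, \ldots, \mathbf{x}_{T-1})$. The switch between the two modes is controlled by FLAGS.use_cur_exg in da_rnn/main.py.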
- To reduce overfitting, a flag to shuffle the training data has been added; it is activated by FLAGS.shuffle_train in da_rnn/main.py.
- A ModelRunner class has been added to manage the model training and evaluation pipeline.
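To make the two `use_cur_exg` modes and the `shuffle_train` option concrete, here is a minimal NumPy sketch of how input windows could be built and how only the training samples could be shuffled. It is not the code in da_rnn/main.py: the function names `make_windows` and `split_train_test`, the chronological train/test split and the `train_ratio` default are all hypothetical assumptions for illustration. `np.random.RandomState` is used because the pinned numpy 1.16 predates `np.random.default_rng`.

```python
import numpy as np


def make_windows(X, y, T, use_cur_exg):
    """Build samples with window length T (hypothetical helper).

    X: (n_steps, n_features) exogenous series; y: (n_steps,) target series.
    For the label y[t], the target history is y[t-T+1 .. t-1] (T-1 values).
    With use_cur_exg=True the exogenous window also contains x_t (T steps),
    otherwise it stops at x_{t-1} (T-1 steps).
    """
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    x_len = T if use_cur_exg else T - 1
    xs, ys, labels = [], [], []
    for t in range(T - 1, len(y)):
        s = t - T + 1                   # first time step of the window
        xs.append(X[s:s + x_len])       # ends at t (inclusive) or t-1
        ys.append(y[s:t])               # target history, excludes y[t]
        labels.append(y[t])             # value to predict
    return np.stack(xs), np.stack(ys), np.array(labels)


def split_train_test(x, y_hist, labels, train_ratio=0.8, shuffle_train=False, seed=0):
    """Chronological split (assumption); optionally shuffle training samples only.

    Shuffling permutes whole samples, so each window stays aligned with its label
    and the time order inside a window is untouched. The test set keeps its order.
    """
    n_train = int(len(labels) * train_ratio)
    train = [a[:n_train] for a in (x, y_hist, labels)]
    test = [a[n_train:] for a in (x, y_hist, labels)]
    if shuffle_train:
        perm = np.random.RandomState(seed).permutation(n_train)
        train = [a[perm] for a in train]
    return train, test
```

For example, `make_windows(X, y, T=10, use_cur_exg=True)` yields exogenous windows one step longer than with `use_cur_exg=False`, while the target history has T-1 values in both modes.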
Put the downloaded Nasdaq CSV file under data/data_nasdaq:
da_rnn
|__data
    |__data_nasdaq
        |__nasdaq100_padding.csv
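A quick way to sanity-check the downloaded file is to load it into feature and target arrays. This is a sketch, not the repo's loader: it assumes the CSV has a header row, purely numeric values, and a target column named `NDX` (the NASDAQ 100 index); `parse_nasdaq_rows` and `load_nasdaq` are hypothetical helpers, and the default path assumes you run from inside da_rnn.

```python
import csv

import numpy as np


def parse_nasdaq_rows(lines, target_col="NDX"):
    """Split CSV lines into exogenous features X and target y (hypothetical helper).

    Assumes a header row, numeric cells, and a target column named target_col.
    """
    reader = csv.reader(lines)
    header = next(reader)
    t = header.index(target_col)  # raises ValueError if the column is absent
    data = np.array([[float(v) for v in row] for row in reader if row], dtype=np.float32)
    X = np.delete(data, t, axis=1)  # every column except the target
    y = data[:, t]
    names = header[:t] + header[t + 1:]
    return X, y, names


def load_nasdaq(path="data/data_nasdaq/nasdaq100_padding.csv", target_col="NDX"):
    """Open the CSV (path relative to da_rnn, an assumption) and parse it."""
    with open(path, newline="") as f:
        return parse_nasdaq_rows(f, target_col)
```

If `load_nasdaq()` raises `ValueError`, the target column has a different name in your copy of the file and `target_col` needs adjusting.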
Suppose we want to run 100 epochs and use TensorBoard to visualize the training process:
cd da_rnn
python main.py --write_summary True --max_epoch 100
To see the descriptions of all flags:
python main.py --helpfull
To launch TensorBoard:
tensorboard --logdir=path
where path can be found in the training log, which prints the relative directory in which the model is saved, e.g. logs/ModelWrapper/lr-0.001_encoder-32_decoder-32/20190922-103703/saved_model/tfb_dir.
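If you have several runs and do not want to search the log, a small script can pick the most recent TensorBoard directory. This is a hypothetical helper, not part of the repo: it assumes the layout `<log_root>/<model>/<hyper-parameter tag>/<YYYYMMDD-HHMMSS>/saved_model/tfb_dir`, inferred only from the example path above.

```python
import glob
import os


def latest_tfb_dir(log_root="logs"):
    """Return the newest .../saved_model/tfb_dir under log_root, or None.

    Assumed layout (inferred from one example path):
    <log_root>/<model>/<hyper-parameter tag>/<YYYYMMDD-HHMMSS>/saved_model/tfb_dir
    """
    pattern = os.path.join(log_root, "*", "*", "*", "saved_model", "tfb_dir")
    candidates = [p for p in glob.glob(pattern) if os.path.isdir(p)]
    if not candidates:
        return None
    # The run folder is a YYYYMMDD-HHMMSS timestamp, so string order is chronological.
    return max(candidates, key=lambda p: os.path.basename(os.path.dirname(os.path.dirname(p))))


if __name__ == "__main__":
    path = latest_tfb_dir()
    if path is not None:
        print("tensorboard --logdir=" + path)
```

Sorting on the timestamp folder rather than on file modification times keeps the choice stable even if files are copied between machines.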
Results of my experiments are listed below. Running more epochs or using larger encoder/decoder dimensions may further improve the results.
Epochs | Shuffle Train | Use Current Exg | Encoder/Decoder Dim | RMSE | MAE | MAPE |
---|---|---|---|---|---|---|
100 | False | False | 32 | 105.671 | 104.60 | 2.15% |
100 | True | False | 32 | 29.849 | 29.033 | 0.59% |
100 | False | True | 32 | 46.287 | 32.398 | 0.66% |
100 | True | True | 32 | 1.491 | 1.172 | 0.024% |
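For reference, the three metrics in the table have the standard definitions sketched below in NumPy. This is a sketch under the assumption that the repo uses these textbook formulas on the de-normalized predictions; it is not taken from da_rnn itself, and MAPE is reported here as a percentage to match the table.

```python
import numpy as np


def rmse(y_true, y_pred):
    """Root mean squared error."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def mae(y_true, y_pred):
    """Mean absolute error."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean(np.abs(y_true - y_pred)))


def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent (y_true must be non-zero)."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(100.0 * np.mean(np.abs((y_true - y_pred) / y_true)))
```

Because the quadratic mean is never smaller than the arithmetic mean of the absolute errors, RMSE is always at least MAE, which every row of the table satisfies; a large gap between them (as in the third row) points to a few large errors.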
# To shuffle the training data and use the current exogenous factors
python main.py --write_summary True --max_epoch 100 --shuffle_train True --use_cur_exg True
After 100 epochs (with the training data shuffled and the current exogenous factors used), the prediction is plotted as follows:
Requirements:
tensorflow==1.13.1
scikit-learn==0.21.3
numpy==1.16.4
Although I have not tested it, I expect the code to work under TensorFlow 1.12 and 1.14 as well.