-
Notifications
You must be signed in to change notification settings - Fork 72
/
main.py
26 lines (20 loc) · 955 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pandas as pd
from torch.utils.data import DataLoader
from es_rnn.data_loading import create_datasets, SeriesDataset
from es_rnn.config import get_config
from es_rnn.trainer import ESRNNTrainer
from es_rnn.model import ESRNN
import time
# ---------------------------------------------------------------------------
# Configuration: load the preset named 'Monthly'. Every later step below
# reads its settings (paths, sizes, device, batch size) from this mapping.
# ---------------------------------------------------------------------------
print('loading config')
config = get_config('Monthly')

# ---------------------------------------------------------------------------
# Data: read the series metadata plus the train/test CSVs for the configured
# variable, split them, and wrap the result in a shuffling DataLoader.
# ---------------------------------------------------------------------------
print('loading data')
# NOTE(review): these paths are resolved against the current working
# directory, not this file's location — launch the script from a directory
# where '../data' points at the dataset.
info = pd.read_csv('../data/info.csv')
series_name = config['variable']
train_path, test_path = (
    '../data/Train/%s-train.csv' % series_name,
    '../data/Test/%s-test.csv' % series_name,
)
train, val, test = create_datasets(train_path, test_path, config['output_size'])
dataset = SeriesDataset(
    train,
    val,
    test,
    info,
    series_name,
    config['chop_val'],
    config['device'],
)
# shuffle=True makes the loader draw batches in a fresh random order per epoch.
dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)

# ---------------------------------------------------------------------------
# Model and training: one model sized to the number of series in the
# dataset, handed to the trainer together with the loader and config.
# ---------------------------------------------------------------------------
# Whole seconds since the Unix epoch, used as this run's identifier.
run_id = str(int(time.time()))
model = ESRNN(num_series=len(dataset), config=config)
# presumably dataInfoCatHeaders are the one-hot category column names
# (the keyword is `ohe_headers`) — confirm against SeriesDataset.
tr = ESRNNTrainer(
    model,
    dataloader,
    run_id,
    config,
    ohe_headers=dataset.dataInfoCatHeaders,
)
tr.train_epochs()