-
Notifications
You must be signed in to change notification settings - Fork 92
/
main.py
73 lines (55 loc) · 2.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import random
import torch
import torch.nn as nn
import torch.optim as optim
from model import GGNN
from utils.train import train
from utils.test import test
from utils.data.dataset import bAbIDataset
from utils.data.dataloader import bAbIDataloader
# Command-line configuration for training/evaluating a GGNN on a bAbI task.
parser = argparse.ArgumentParser()
parser.add_argument('--task_id', default=4, type=int, help='bAbI task id')
parser.add_argument('--question_id', default=0, type=int, help='question types')
parser.add_argument('--workers', default=2, type=int, help='number of data loading workers')
parser.add_argument('--batchSize', default=10, type=int, help='input batch size')
parser.add_argument('--state_dim', default=4, type=int, help='GGNN hidden state size')
parser.add_argument('--n_steps', default=5, type=int, help='propogation steps number of GGNN')
parser.add_argument('--niter', default=10, type=int, help='number of epochs to train for')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--verbal', action='store_true', help='print training info or not')
parser.add_argument('--manualSeed', type=int, help='manual seed')
opt = parser.parse_args()
# Echo the options before the seed is filled in, so an unset seed shows as None.
print(opt)

# Draw a seed only when none was supplied, then seed Python's and torch's RNGs.
opt.manualSeed = random.randint(1, 10000) if opt.manualSeed is None else opt.manualSeed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Graph file for the selected task. Both the train and test datasets are
# built from this path (see main).
opt.dataroot = 'babi_data/processed_1/train/%d_graphs.txt' % opt.task_id

if opt.cuda:
    # Also seed every CUDA device's generator when running on GPU.
    torch.cuda.manual_seed_all(opt.manualSeed)
def main(opt):
    """Train a GGNN on a bAbI task and evaluate it after every epoch.

    Args:
        opt: Parsed command-line namespace.

            Attributes read here:
            ``dataroot``, ``question_id``, ``batchSize``, ``workers``,
            ``cuda``, ``lr`` and ``niter``.

            The namespace is also passed through to ``GGNN``, ``train`` and
            ``test``.

            Attributes written onto ``opt`` before the model is built:
            ``annotation_dim``, ``n_edge_types`` and ``n_node``.
    """
    # Honour --workers instead of a hard-coded 2. Fall back to 2 (the CLI
    # default) for callers whose namespace lacks the attribute.
    num_workers = getattr(opt, 'workers', 2)

    # Both datasets are read from the same file. The third argument
    # presumably selects the train vs. held-out portion; confirm in
    # bAbIDataset.
    train_dataset = bAbIDataset(opt.dataroot, opt.question_id, True)
    train_dataloader = bAbIDataloader(train_dataset, batch_size=opt.batchSize,
                                      shuffle=True, num_workers=num_workers)

    test_dataset = bAbIDataset(opt.dataroot, opt.question_id, False)
    test_dataloader = bAbIDataloader(test_dataset, batch_size=opt.batchSize,
                                     shuffle=False, num_workers=num_workers)

    # Model dimensions are taken from the training data before building GGNN.
    opt.annotation_dim = 1  # for bAbI
    opt.n_edge_types = train_dataset.n_edge_types
    opt.n_node = train_dataset.n_node

    net = GGNN(opt)
    net.double()  # cast floating-point parameters/buffers to float64
    print(net)

    criterion = nn.CrossEntropyLoss()

    if opt.cuda:
        net.cuda()
        criterion.cuda()

    # Create the optimizer after moving the model so it holds the final parameters.
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    for epoch in range(opt.niter):
        train(epoch, train_dataloader, net, criterion, optimizer, opt)
        test(test_dataloader, net, criterion, optimizer, opt)
if __name__ == "__main__":
    # Run training only when executed as a script, using the options parsed above.
    main(opt=opt)