-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
77 lines (65 loc) · 2.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import os
import torch.nn as nn
import numpy as np
import time
from model import textCNN
import sen2inds
import textCNN_data
# Vocabulary lookup tables (word <-> index) and label lookup tables
# (label name <-> class number), built once when this module is imported.
word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

# Hyper-parameters handed to textCNN(...) in main().
textCNN_param = dict(
    vocab_size=len(word2ind),   # number of entries in the word dictionary
    embed_dim=60,
    class_num=len(label_w2n),   # number of entries in the label dictionary
    kernel_num=16,
    kernel_size=[3, 4, 5],
    dropout=0.5,
)

# Settings handed to textCNN_data.textCNN_dataLoader(...) in main().
dataLoader_param = dict(
    batch_size=128,
    shuffle=True,
)
def main():
    """Train the TextCNN classifier, checkpointing after every epoch.

    Builds the network from ``textCNN_param``, resuming from ``weight.pkl``
    when that file exists (otherwise calling ``net.init_weight()``), then
    runs 100 epochs of Adam (lr=0.01) with ``nn.NLLLoss`` over batches from
    ``textCNN_data.textCNN_dataLoader(dataLoader_param)``.

    Per-step losses are printed and appended to ``log_<yymmddhh>.txt``.
    After each epoch the weights are written to ``weight.pkl`` (the resume
    file) and to a timestamped snapshot in ``model/model/``.

    Raises:
        RuntimeError: if the data loader yields no batches in an epoch.
    """
    # init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'weight.pkl'
    modelFile = 'model/'
    # Use the GPU when present; fall back to CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if os.path.exists(weightFile):
        print('load weight')
        # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
        net.load_state_dict(torch.load(weightFile, map_location=device))
    else:
        net.init_weight()
    print(net)
    net.to(device)

    # init dataset
    print('init dataset...')
    dataLoader = textCNN_data.textCNN_dataLoader(dataLoader_param)
    # NOTE(review): validation data is loaded (and log_test is created below)
    # but no evaluation step exists yet.
    valdata = textCNN_data.get_valdata()

    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    # Negative log-likelihood loss, used to train an n-class classifier.
    # It expects log-probabilities, so the model presumably ends in
    # log_softmax -- TODO confirm in model.textCNN.
    criterion = nn.NLLLoss()

    # Snapshot directory; the original "model/" + "model\\..." path meant
    # model/model/ on Windows. Create it so torch.save cannot fail on it.
    snapshot_dir = os.path.join(modelFile, 'model')
    os.makedirs(snapshot_dir, exist_ok=True)
    log_every = 1  # print/log the loss every `log_every` steps

    # One timestamp for both log files so they always pair up.
    run_stamp = time.strftime('%y%m%d%H')
    with open('log_{}.txt'.format(run_stamp), 'w') as log, \
            open('log_test_{}.txt'.format(run_stamp), 'w') as log_test:
        log.write('epoch step loss\n')
        log_test.write('epoch step test_acc\n')
        print("training...")
        for epoch in range(100):
            loss = None
            for i, (clas, sentences) in enumerate(dataLoader):
                optimizer.zero_grad()
                sentences = sentences.long().to(device)
                clas = clas.long().to(device)
                out = net(sentences)
                # Input: (N, C), where C is the number of classes.
                # Target: (N), with each value in 0 <= targets[i] <= C-1.
                loss = criterion(out, clas)
                loss.backward()
                optimizer.step()
                if (i + 1) % log_every == 0:
                    print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
                    log.write('{} {} {}\n'.format(epoch + 1, i + 1, loss.item()))
            if loss is None:
                # Without this, `i`/`loss` below would be undefined.
                raise RuntimeError('dataLoader yielded no batches; nothing to train on')
            # Persist this epoch's log lines even if a later epoch crashes.
            log.flush()
            print("save model...")
            torch.save(net.state_dict(), weightFile)
            snapshot_name = "{}_model_iter_{}_{}_loss_{:.2f}.pkl".format(
                time.strftime('%y%m%d%H'), epoch, i, loss.item())
            torch.save(net.state_dict(), os.path.join(snapshot_dir, snapshot_name))
            print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())


if __name__ == "__main__":
    main()