--[[
Trainer class for GalaxyZoo
By Xiang Zhang @ New York University
--]]
require("torch")
require("sys")

local Train = torch.class("Train")
-- Initialization of the trainer class
-- data: the data object
-- model: the model object
-- loss: the loss object
-- config: (optional) the configuration table
--   .rates: (optional) table of learning rates, indexed by the epoch at which each rate takes effect
--   .epoch: (optional) current epoch
--   .momentum: (optional) momentum coefficient (default 0)
--   .decay: (optional) weight-decay coefficient (default 0)
--   .normalize: (optional) if true, normalize each sample to zero mean and unit variance
function Train:__init(data, model, loss, config)
   -- Store the objects
   self.data = data
   self.model = model
   self.loss = loss
   -- Store the configurations and states
   local config = config or {}
   self.rates = config.rates or {1e-3}
   self.epoch = config.epoch or 1
   -- Get the parameters and gradients
   self.params, self.grads = self.model:getParameters()
   self.old_grads = self.grads:clone():zero()
   -- Convert the loss to the same tensor type as the model
   self.loss:type(self.model:type())
   -- Find the current rate: the scheduled rate with the largest epoch index
   -- that does not exceed the current epoch
   local max_epoch = 1
   self.rate = self.rates[1]
   for i, v in pairs(self.rates) do
      if i <= self.epoch and i > max_epoch then
         max_epoch = i
         self.rate = v
      end
   end
   -- Timing table
   self.time = {}
   -- Store the configurations
   self.momentum = config.momentum or 0
   self.decay = config.decay or 0
   self.normalize = config.normalize
end
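
--[[
Illustrative example (not part of the original file): since the rates table
is indexed by the epoch at which each rate takes effect, a schedule that
starts at 1e-2 and decays tenfold at epochs 5001 and 10001 would look like
   local trainer = Train(data, model, loss,
                         {rates = {[1] = 1e-2, [5001] = 1e-3, [10001] = 1e-4},
                          momentum = 0.9, decay = 1e-5, normalize = true})
where data, model and loss follow this repository's object interfaces.
--]]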

-- Run for a number of steps
-- epochs: number of batch steps to run (each step advances the epoch counter)
-- logfunc: (optional) a function to execute after each step
function Train:run(epochs, logfunc)
   -- Release the flattened parameter and gradient views so that
   -- getParameters() can safely re-flatten the model storages
   self.params, self.grads = nil, nil
   collectgarbage()
   self.params, self.grads = self.model:getParameters()
   for i = 1, epochs do
      self:batchStep()
      if logfunc then logfunc(self, i) end
   end
end
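
--[[
Usage sketch (illustrative; the callback fields below mirror the state this
class exposes, but the surrounding setup is assumed):
   trainer:run(10000, function(trainer, step)
      if step % 100 == 0 then
         print(string.format("step %d: objective %g, rate %g, data %.3fs, fwd %.3fs, bwd %.3fs",
                             step, trainer.objective, trainer.rate,
                             trainer.time.data, trainer.time.forward, trainer.time.backward))
      end
   end)
--]]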

-- Run for one batch step
function Train:batchStep()
   self.clock = sys.clock()
   -- Get a batch of data
   self.batch_untyped, self.labels_untyped = self.data:getBatch(self.batch_untyped, self.labels_untyped)
   -- Convert the data to the same tensor type as the model
   self.batch = self.batch or self.batch_untyped:type(self.model:type())
   self.labels = self.labels or self.labels_untyped:type(self.model:type())
   self.batch:copy(self.batch_untyped)
   self.labels:copy(self.labels_untyped)
   -- Normalize each sample to zero mean and unit variance
   if self.normalize then
      for i = 1, self.batch:size(1) do
         self.batch[i]:add(-self.batch[i]:mean()):div(self.batch[i]:std() + 1e-8)
      end
   end
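   -- The loop above standardizes each sample in place:
   --   x <- (x - mean(x)) / (std(x) + 1e-8)
   -- where the small epsilon guards against division by zero on constant inputs.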
   -- Record time (synchronize first so GPU timing is accurate)
   if self.model:type() == "torch.CudaTensor" then cutorch.synchronize() end
   self.time.data = sys.clock() - self.clock
   self.clock = sys.clock()
   -- Forward propagation
   self.output = self.model:forward(self.batch)
   self.objective = self.loss:forward(self.output, self.labels)
   if type(self.objective) ~= "number" then self.objective = self.objective[1] end
   -- Record time
   if self.model:type() == "torch.CudaTensor" then cutorch.synchronize() end
   self.time.forward = sys.clock() - self.clock
   self.clock = sys.clock()
   -- Backward propagation
   self.grads:zero()
   self.gradOutput = self.loss:backward(self.output, self.labels)
   self.gradBatch = self.model:backward(self.batch, self.gradOutput)
   -- Record time
   if self.model:type() == "torch.CudaTensor" then cutorch.synchronize() end
   self.time.backward = sys.clock() - self.clock
   self.clock = sys.clock()
   -- Update step: SGD with momentum and weight decay
   self.old_grads:mul(self.momentum):add(self.grads:mul(-self.rate))
   self.params:mul(1 - self.decay * self.rate):add(self.old_grads)
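   -- The two lines above compute, with r = rate, m = momentum, d = decay:
   --   old_grads <- m * old_grads - r * grads
   --   params    <- (1 - d * r) * params + old_grads
   -- Note that self.grads is scaled in place by -r; this is safe only because
   -- it is zeroed again at the start of the next backward pass.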
   if self.model:type() == "torch.CudaTensor" then cutorch.synchronize() end
   self.time.update = sys.clock() - self.clock
   -- Increment the epoch counter (one batch step counts as one epoch here)
   self.epoch = self.epoch + 1
   -- Switch to the newly scheduled learning rate, if any
   self.rate = self.rates[self.epoch] or self.rate
end
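
--[[
Interfaces assumed by this trainer, as used above (a summary, not new API):
   data:getBatch(batch, labels) -> batch, labels
   model:getParameters() -> params, grads
   model:forward(batch); model:backward(batch, gradOutput); model:type()
   loss:forward(output, labels); loss:backward(output, labels); loss:type(type)
--]]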