forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
gru_cell.py
172 lines (150 loc) · 5.01 KB
/
gru_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import functools
from caffe2.python import brew, rnn_cell
class GRUCell(rnn_cell.RNNCell):
    """Gated recurrent unit cell assembled as a fragment of a Caffe2 net.

    The gate pre-activations (reset, update, output) are built here in the
    graph definition; the final state update is delegated to the ``GRUUnit``
    operator.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused! Values here will be ignored.
        memory_optimization,
        drop_states=False,
        linear_before_reset=False,
        **kwargs
    ):
        """Record the cell configuration.

        Args:
            input_size: ``dim_in`` of the input projection built by
                ``prepare_input``.
            hidden_size: Width used for every hidden-to-hidden FC layer and
                reported by ``get_output_dim``.
            forget_bias: Converted to ``float`` and forwarded to ``GRUUnit``;
                flagged by the original author as currently unused.
            memory_optimization: Stored on the instance; not read by any
                method of this class.
            drop_states: Forwarded unchanged to ``GRUUnit``.
            linear_before_reset: When true, the output-gate FC is applied to
                the previous hidden state *before* the reset gate is
                multiplied in (matches cudnn semantics); otherwise the reset
                gate is applied first.
            **kwargs: Passed through to the ``rnn_cell.RNNCell`` constructor.
        """
        super(GRUCell, self).__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.linear_before_reset = linear_before_reset

    # A GRU differs from an LSTM in that one gate's output (reset) feeds
    # another gate (output). For that reason the reset-gate activation and
    # the modified output-gate input are computed here, in the graph
    # definition. The remaining logic lives in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Add the ops for a single GRU step to ``model.net``.

        Args:
            model: Model helper whose ``net`` receives the operators.
            input_t: Input blob for this step; it is split three ways along
                axis 2 (presumably the ``3 * hidden_size`` projection from
                ``prepare_input`` -- confirm against the caller).
            seq_lengths: Optional sequence-lengths blob; when not ``None`` it
                is passed to ``GRUUnit`` and ``sequence_lengths`` is set.
            states: Sequence whose first element is the previous hidden state.
            timestep: Timestep blob handed to ``GRUUnit``.
            extra_inputs: Accepted for interface compatibility; not used.

        Returns:
            A one-element tuple holding the new hidden-state blob.
        """
        prev_hidden = states[0]
        net = model.net

        # Every hidden-to-hidden FC layer shares the same shape arguments.
        hidden_fc_args = {
            'dim_in': self.hidden_size,
            'dim_out': self.hidden_size,
            'axis': 2,
        }

        # Carve the input into its per-gate slices.
        split_names = [
            self.scope(suffix)
            for suffix in ('input_t_reset', 'input_t_update', 'input_t_output')
        ]
        in_reset, in_update, in_output = net.Split(
            [input_t],
            split_names,
            axis=2,
        )

        # Hidden-state contributions to the reset and update gates.
        reset_fc = brew.fc(
            model, prev_hidden, self.scope('reset_gate_t'), **hidden_fc_args
        )
        update_fc = brew.fc(
            model, prev_hidden, self.scope('update_gate_t'), **hidden_fc_args
        )

        # Reset gate: add the input slice (written back to the same blob
        # name) and squash it.
        reset_pre_act = net.Sum(
            [reset_fc, in_reset],
            self.scope('reset_gate_t'),
        )
        reset_act = net.Sigmoid(
            reset_pre_act,
            self.scope('reset_gate_t_sigmoid'),
        )

        if not self.linear_before_reset:
            # Gate the previous hidden state first, then project it.
            gated_prev_hidden = net.Mul(
                [reset_act, prev_hidden],
                self.scope('modified_hidden_t_prev'),
            )
            output_hidden_part = brew.fc(
                model,
                gated_prev_hidden,
                self.scope('output_gate_t'),
                **hidden_fc_args
            )
        else:
            # `self.linear_before_reset = True` matches cudnn semantics:
            # project the previous hidden state, then gate the projection.
            projected_prev_hidden = brew.fc(
                model,
                prev_hidden,
                self.scope('output_gate_t'),
                **hidden_fc_args
            )
            output_hidden_part = net.Mul(
                [reset_act, projected_prev_hidden],
                self.scope('output_gate_t_mul'),
            )

        # Fold the input slices into the update and output gates; the reset
        # gate already received its slice above.
        update_pre_act = net.Sum(
            [update_fc, in_update],
            self.scope('update_gate_t'),
        )
        output_pre_act = net.Sum(
            [output_hidden_part, in_output],
            self.scope('output_gate_t_summed'),
        )

        # Stack the three gate blobs back together for GRUUnit.
        gates_t, _gates_t_concat_dims = net.Concat(
            [reset_pre_act, update_pre_act, output_pre_act],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )

        # seq_lengths is only part of the op inputs when it was supplied.
        has_seq_lengths = seq_lengths is not None
        unit_inputs = [prev_hidden, gates_t]
        if has_seq_lengths:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t = net.GRUUnit(
            unit_inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_seq_lengths,
        )
        net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        """Project ``input_blob`` to ``3 * hidden_size`` with one FC layer.

        The layer is created under the scoped name ``i2h`` and applied along
        axis 2; its output blob is returned.
        """
        projected_width = 3 * self.hidden_size
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=projected_width,
            axis=2,
        )

    def get_state_names(self):
        """Return the scoped state blob names: a 1-tuple with ``hidden_t``."""
        hidden_name = self.scope('hidden_t')
        return (hidden_name,)

    def get_output_dim(self):
        """Return the output width of the cell, i.e. ``hidden_size``."""
        return self.hidden_size
# Module-level GRU builder: `rnn_cell._LSTM` with `GRUCell` pre-bound as its
# first positional argument, so callers supply only the remaining arguments.
# NOTE(review): presumably `_LSTM` takes the cell class as its first
# parameter -- confirm against rnn_cell.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)