-
Notifications
You must be signed in to change notification settings - Fork 15
/
normalizations.py
554 lines (480 loc) · 25.2 KB
/
normalizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
import keras_contrib_backend as K_contrib
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).

    Every sample of the batch is normalized on its own: the mean and the
    standard deviation are computed over all non-batch axes (excluding
    `axis` when one is given) and the input is shifted and rescaled so that
    the activations have a mean close to 0 and a standard deviation close
    to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            With `axis=None` all values of each instance are normalized
            together and a single scalar `gamma`/`beta` is learned.
            Axis 0 is the batch dimension and is rejected in `build`.
        epsilon: Small float added to the standard deviation to avoid
            dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True

        # Plain hyper-parameters.
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        # Resolve string identifiers / configs into Keras objects.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate `axis` against the input rank and create `gamma`/`beta`.

        # Raises
            ValueError: if `axis` is 0, or if an axis is given for an input
                that only has a batch and one further dimension.
        """
        rank = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')
        if self.axis is not None and rank == 2:
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=rank)

        # One scalar when normalizing everything, else one value per entry
        # along `axis`.
        param_shape = (1,) if self.axis is None else (input_shape[self.axis],)

        # gamma is created before beta so the weight order is unchanged.
        self.gamma = (
            self.add_weight(shape=param_shape,
                            name='gamma',
                            initializer=self.gamma_initializer,
                            regularizer=self.gamma_regularizer,
                            constraint=self.gamma_constraint)
            if self.scale else None
        )
        self.beta = (
            self.add_weight(shape=param_shape,
                            name='beta',
                            initializer=self.beta_initializer,
                            regularizer=self.beta_regularizer,
                            constraint=self.beta_constraint)
            if self.center else None
        )
        self.built = True

    def call(self, inputs, training=None):
        """Normalize `inputs` per sample and apply the optional affine terms."""
        static_shape = K.int_shape(inputs)
        rank = len(static_shape)

        # Reduce over every axis except `axis` (if any) and the batch axis.
        axes_to_reduce = list(range(rank))
        if self.axis is not None:
            axes_to_reduce.pop(self.axis)
        axes_to_reduce.pop(0)

        mean = K.mean(inputs, axes_to_reduce, keepdims=True)
        stddev = K.std(inputs, axes_to_reduce, keepdims=True) + self.epsilon
        outputs = (inputs - mean) / stddev

        # Shape that lets gamma/beta broadcast against the input.
        param_shape = [1] * rank
        if self.axis is not None:
            param_shape[self.axis] = static_shape[self.axis]

        if self.scale:
            outputs = outputs * K.reshape(self.gamma, param_shape)
        if self.center:
            outputs = outputs + K.reshape(self.beta, param_shape)
        return outputs

    def get_config(self):
        """Return the layer configuration merged over the base `Layer` config."""
        merged = dict(super(InstanceNormalization, self).get_config())
        merged.update({
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
        })
        return merged
# Register the class under the name 'InstanceNormalization' in Keras'
# custom-object table so the layer can be resolved by name when loading models.
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
class BatchRenormalization(Layer):
    """Batch renormalization layer (Sergey Ioffe, 2017).

    Normalize the activations of the previous layer at each batch,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    During training the batch-normalized activations are corrected with
    `r` and `d` (computed from the ratio/offset between the batch statistics
    and the moving statistics, with gradients stopped). `r` is clipped to
    `[1 / r_max, r_max]` and `d` to `[-d_max, d_max]`; `r_max` and `d_max`
    are relaxed from 1 and 1e-3 towards `r_max_value` and `d_max_value` as
    the internal counter `t` grows. At inference the moving statistics are
    used exactly as in batch normalization.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `BatchRenormalization`.
        momentum: Momentum of the moving averages of the batch mean
            and variance.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
        epsilon: Small float > 0 added to the variance before taking the
            square root. Theano expects epsilon >= 1e-5.
        r_max_value: Upper limit of the value of r_max.
        d_max_value: Upper limit of the value of d_max.
        t_delta: At each iteration, increment the value of t by t_delta.
        weights: Optional list of Numpy arrays passed to `set_weights` at
            the end of `build`, in the order the weights are created:
            gamma (if `scale`), beta (if `center`), moving mean,
            moving variance.
        beta_initializer: Initializer for the beta weight. Only relevant
            if you don't pass a `weights` argument.
        gamma_initializer: Initializer for the gamma weight. Only relevant
            if you don't pass a `weights` argument.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models](https://arxiv.org/abs/1702.03275)
        - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
    """

    def __init__(self, axis=-1, momentum=0.99, center=True, scale=True, epsilon=1e-3,
                 r_max_value=3., d_max_value=5., t_delta=1e-3, weights=None, beta_initializer='zero',
                 gamma_initializer='one', moving_mean_initializer='zeros',
                 moving_variance_initializer='ones', gamma_regularizer=None, beta_regularizer=None,
                 beta_constraint=None, gamma_constraint=None, **kwargs):
        # The base initializer must run first: it assigns its own default to
        # `supports_masking`, which would otherwise overwrite the value below.
        super(BatchRenormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.momentum = momentum
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.initial_weights = weights
        self.r_max_value = r_max_value
        self.d_max_value = d_max_value
        self.t_delta = t_delta
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create gamma/beta, the moving statistics and the r/d schedule state.

        # Raises
            ValueError: if the dimension of `axis` is undefined.
        """
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint,
                                         name='{}_gamma'.format(self.name))
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        name='{}_beta'.format(self.name))
        else:
            self.beta = None

        self.running_mean = self.add_weight(shape=shape,
                                            initializer=self.moving_mean_initializer,
                                            name='{}_running_mean'.format(self.name),
                                            trainable=False)
        # The weight name is kept as '..._running_std' (although it stores a
        # variance) so that previously saved weights keep their names.
        self.running_variance = self.add_weight(shape=shape,
                                                initializer=self.moving_variance_initializer,
                                                name='{}_running_std'.format(self.name),
                                                trainable=False)

        # Clipping bounds and schedule counter; updated on every training step.
        self.r_max = K.variable(1, name='{}_r_max'.format(self.name))
        self.d_max = K.variable(0, name='{}_d_max'.format(self.name))
        self.t = K.variable(0, name='{}_t'.format(self.name))
        self.t_delta_tensor = K.constant(self.t_delta)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

        self.built = True

    def call(self, inputs, training=None):
        """Apply batch renormalization (training) or batch normalization
        with the moving statistics (inference).

        # Arguments
            inputs: Input tensor.
            training: `None` to follow the Keras learning phase, or a
                static flag. When it is statically 0/False the inference
                output is returned and no updates are registered.
        """
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(inputs)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        def _broadcast(param):
            # Per-feature vectors are reshaped so they broadcast along
            # `axis`; a disabled gamma/beta (None) is passed through.
            if param is None:
                return None
            return K.reshape(param, broadcast_shape)

        def normalize_inference():
            # Inference is identical to batch normalization.
            return K.batch_normalization(
                inputs,
                _broadcast(self.running_mean),
                _broadcast(self.running_variance),
                _broadcast(self.beta),
                _broadcast(self.gamma),
                epsilon=self.epsilon)

        # Learning phase statically set to inference: use the moving
        # statistics and do not touch the moving averages or the r/d schedule.
        if training in {0, False}:
            return normalize_inference()

        mean_batch, var_batch = K_contrib.moments(inputs, reduction_axes, shift=None, keep_dims=False)
        std_batch = K.sqrt(var_batch + self.epsilon)
        running_std = K.sqrt(self.running_variance + self.epsilon)

        # r and d are treated as constants during back-propagation.
        r = std_batch / running_std
        r = K.stop_gradient(K_contrib.clip(r, 1 / self.r_max, self.r_max))

        d = (mean_batch - self.running_mean) / running_std
        d = K.stop_gradient(K_contrib.clip(d, -self.d_max, self.d_max))

        x_normed_batch = (inputs - _broadcast(mean_batch)) / _broadcast(std_batch)
        x_normed = x_normed_batch * _broadcast(r) + _broadcast(d)
        if self.scale:
            x_normed = x_normed * _broadcast(self.gamma)
        if self.center:
            x_normed = x_normed + _broadcast(self.beta)

        # explicit update to moving mean and variance
        self.add_update([K.moving_average_update(self.running_mean, mean_batch, self.momentum),
                         K.moving_average_update(self.running_variance, std_batch ** 2, self.momentum)], inputs)

        # update r_max and d_max (logistic schedules in t)
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t))
        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t)))

        self.add_update([K.update(self.r_max, r_val),
                         K.update(self.d_max, d_val),
                         K.update_add(self.t, self.t_delta_tensor)], inputs)

        # pick the normalized form of inputs corresponding to the training phase
        return K.in_train_phase(x_normed, normalize_inference, training=training)

    def get_config(self):
        """Return the layer configuration merged over the base `Layer` config.

        The `weights` constructor argument is not part of the config.
        """
        config = {'epsilon': self.epsilon,
                  'axis': self.axis,
                  'center': self.center,
                  'scale': self.scale,
                  'momentum': self.momentum,
                  'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                  'beta_regularizer': regularizers.serialize(self.beta_regularizer),
                  'beta_initializer': initializers.serialize(self.beta_initializer),
                  'gamma_initializer': initializers.serialize(self.gamma_initializer),
                  'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
                  'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
                  'beta_constraint': constraints.serialize(self.beta_constraint),
                  'gamma_constraint': constraints.serialize(self.gamma_constraint),
                  'r_max_value': self.r_max_value,
                  'd_max_value': self.d_max_value,
                  't_delta': self.t_delta}
        base_config = super(BatchRenormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Register the class under the name 'BatchRenormalization' in Keras'
# custom-object table so the layer can be resolved by name when loading models.
get_custom_objects().update({'BatchRenormalization': BatchRenormalization})
class GroupNormalization(Layer):
    """Group normalization layer

    Group Normalization divides the channels into groups and computes within each group
    the mean and variance for normalization. Group Normalization's computation is independent
    of batch sizes, and its accuracy is stable in a wide range of batch sizes.

    Relation to Layer Normalization:
    If the number of groups is set to 1, then this operation becomes identical to
    Layer Normalization.

    Relation to Instance Normalization:
    If the number of groups is set to the input dimension (number of groups is equal
    to number of channels), then this operation becomes identical to Instance Normalization.

    # Arguments
        groups: Integer, the number of groups for Group Normalization.
            Can be in the range [1, N] where N is the input dimension.
            The input dimension must be divisible by the number of groups.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNormalization`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate the channel dimension against `groups` and create weights.

        # Raises
            ValueError: if the dimension of `axis` is undefined, smaller
                than `groups`, or not divisible by `groups`.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            # It is the channel count that has to be a multiple of the
            # group count, not the other way round.
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize `inputs` within each group of channels of each sample.

        The channel axis of size C is split in place into `(groups, C // groups)`,
        statistics are taken over every axis except the batch and group
        axes, and the result is reshaped back to the input shape.
        """
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)
        ndim = len(input_shape)

        # Work with a non-negative axis index so the group axis can be
        # inserted right where the channel axis is. Splitting the channel
        # axis in place is what keeps each group made of whole channels;
        # inserting the group axis anywhere else only matches when axis == 1.
        axis = self.axis if self.axis >= 0 else self.axis + ndim
        channels_per_group = input_shape[axis] // self.groups

        # Dynamic shape of the grouped tensor: (..., groups, C // groups, ...).
        group_axes = [tensor_input_shape[i] for i in range(ndim)]
        group_axes[axis] = channels_per_group
        group_axes.insert(axis, self.groups)
        group_shape = K.stack(group_axes)

        # Matching shape for gamma/beta, 1 everywhere but the split channels.
        broadcast_shape = [1] * ndim
        broadcast_shape[axis] = channels_per_group
        broadcast_shape.insert(axis, self.groups)

        grouped = K.reshape(inputs, group_shape)

        # Reduce over everything except the batch axis (0) and the group
        # axis (now located at `axis`).
        group_reduction_axes = [i for i in range(1, ndim + 1) if i != axis]
        mean, variance = K_contrib.moments(grouped, group_reduction_axes, keep_dims=True)
        outputs = (grouped - mean) / (K.sqrt(variance + self.epsilon))

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        # finally we reshape the output back to the input shape
        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs

    def get_config(self):
        """Return the layer configuration merged over the base `Layer` config."""
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """The layer does not change the shape of its input."""
        return input_shape
# Register the class under the name 'GroupNormalization' in Keras'
# custom-object table so the layer can be resolved by name when loading models.
get_custom_objects().update({'GroupNormalization': GroupNormalization})