from keras.engine.topology import Layer
from keras.utils.generic_utils import get_custom_objects
from keras import backend as K
import itertools


''' Theano backend function '''
def depth_to_space_th(input, scale, data_format=None):
    ''' Uses the phase shift algorithm to convert channels/depth into spatial resolution '''
    import theano.tensor as T

    if data_format is None:
        data_format = K.image_data_format()
    data_format = data_format.lower()
    input = K._preprocess_conv2d_input(input, data_format)

    b, k, row, col = input.shape
    output_shape = (b, k // (scale ** 2), row * scale, col * scale)
    out = T.zeros(output_shape)
    r = scale

    # Scatter each group of scale ** 2 channels onto the (y, x) offsets of the
    # enlarged spatial grid (the "periodic shuffling" / phase shift step).
    for y, x in itertools.product(range(scale), repeat=2):
        out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x::r * r, :, :])

    out = K._postprocess_conv2d_output(out, input, None, None, None, data_format)
    return out


''' TensorFlow backend function (NOT TESTED) '''
# TODO: Test on TensorFlow backend
def depth_to_space_tf(input, scale, data_format=None):
    ''' Uses the phase shift algorithm to convert channels/depth into spatial resolution '''
    import tensorflow as tf

    if data_format is None:
        data_format = K.image_data_format()
    data_format = data_format.lower()
    input = K._preprocess_conv2d_input(input, data_format)
    out = tf.depth_to_space(input, scale)
    out = K._postprocess_conv2d_output(out, data_format)
    return out
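

''' NumPy reference (illustrative only) '''
# A minimal, framework-free sketch of the same phase-shift rearrangement, using NumPy.
# It mirrors the channel grouping of depth_to_space_th above and is included purely
# for illustration; `depth_to_space_np` is our own helper name, not a Keras API, and
# it assumes a 'channels_first' array of shape (batch, k * scale ** 2, rows, cols).
import numpy as np


def depth_to_space_np(x, scale):
    ''' Phase shift on a NumPy array: (b, k * s * s, r, c) -> (b, k, r * s, c * s) '''
    b, channels, rows, cols = x.shape
    k = channels // (scale ** 2)
    out = np.zeros((b, k, rows * scale, cols * scale), dtype=x.dtype)
    for y, x_off in itertools.product(range(scale), repeat=2):
        # Channel group (scale * y + x_off) fills the (y, x_off) offset of each
        # scale x scale patch in the enlarged spatial grid.
        out[:, :, y::scale, x_off::scale] = x[:, scale * y + x_off::scale * scale, :, :]
    return out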


class SubPixelUpscaling(Layer):
    """ Sub-pixel convolutional upscaling layer, based on the paper "Real-Time Single Image
    and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"
    (https://arxiv.org/abs/1609.05158).

    This layer requires a Convolution2D prior to it, whose number of output filters is
    computed according to the formula:

        filters = k * (scale_factor * scale_factor)

    where k = a user-defined number of filters (generally larger than 32)
          scale_factor = the upscaling factor (generally 2)

    This layer performs the depth-to-space operation on the convolution filters and
    returns a tensor with the shape defined below.

    # Example :
    ```python
    # A standard sub-pixel upscaling block
    x = Convolution2D(256, (3, 3), padding='same', activation='relu')(...)
    u = SubPixelUpscaling(scale_factor=2)(x)

    # Optional
    x = Convolution2D(256, (3, 3), padding='same', activation='relu')(u)
    ```

    In practice, it is useful to have a second convolution layer after the
    SubPixelUpscaling layer to speed up the learning process. However, if you are
    stacking multiple SubPixelUpscaling blocks, this may greatly increase the number
    of parameters, so the convolution layer after each SubPixelUpscaling layer can
    be removed.

    # Arguments
        scale_factor: Upscaling factor.
        data_format: Can be None, 'channels_first' or 'channels_last'.

    # Input shape
        4D tensor with shape:
        `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` if data_format='channels_first'
        or 4D tensor with shape:
        `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` if data_format='channels_last'.

    # Output shape
        4D tensor with shape:
        `(samples, k channels, rows * scale_factor, cols * scale_factor)` if data_format='channels_first'
        or 4D tensor with shape:
        `(samples, rows * scale_factor, cols * scale_factor, k channels)` if data_format='channels_last'.
    """

    def __init__(self, scale_factor=2, data_format=None, **kwargs):
        super(SubPixelUpscaling, self).__init__(**kwargs)

        self.scale_factor = scale_factor
        self.data_format = K.image_data_format() if data_format is None else data_format

    def build(self, input_shape):
        pass

    def call(self, x, mask=None):
        if K.backend() == 'theano':
            y = depth_to_space_th(x, self.scale_factor, self.data_format)
        else:
            y = depth_to_space_tf(x, self.scale_factor, self.data_format)
        return y

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            b, k, r, c = input_shape
            return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor)
        else:
            b, r, c, k = input_shape
            return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2))

    def get_config(self):
        config = {'scale_factor': self.scale_factor,
                  'data_format': self.data_format}
        base_config = super(SubPixelUpscaling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling})
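

if __name__ == '__main__':
    # Quick, illustrative sanity check (not part of the layer itself): wire the layer into
    # a tiny functional model and confirm the upscaled output shape. The filter count,
    # kernel size and scale_factor below are arbitrary choices for demonstration, and a
    # 'channels_last' image data format is assumed.
    from keras.layers import Input, Convolution2D
    from keras.models import Model

    k, scale = 32, 2
    ip = Input(shape=(None, None, k * (scale ** 2)))
    x = Convolution2D(k * (scale ** 2), (3, 3), padding='same', activation='relu')(ip)
    out = SubPixelUpscaling(scale_factor=scale)(x)

    model = Model(ip, out)
    model.summary()  # the final output shape should be (None, None, None, 32)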