-
Notifications
You must be signed in to change notification settings - Fork 0
/
transform_net.py
46 lines (40 loc) · 1.64 KB
/
transform_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D, Add, Layer, Conv2DTranspose, Activation
from layers import ConvLayer, ResBlock, ConvTLayer
class TransformNet:
    """Image transformation network assembled with the Keras functional API.

    Layer stack, as constructed in ``__init__``:

    * a 9x9 conv with 32 filters;
    * two stride-2 3x3 convs with 64 and 128 filters;
    * five 128-filter residual blocks;
    * two stride-2 3x3 transposed convs with 64 and 32 filters;
    * a final 9x9 conv to 3 channels, followed by ``tanh``.

    The ``tanh`` output in [-1, 1] is rescaled to [0, 255].

    NOTE(review): this layout resembles the Johnson et al. (2016) fast
    style-transfer "transform net". Confirm against the training code.

    NOTE(review): assume ``ConvLayer``/``ConvTLayer`` honour
    ``strides``/``padding`` like Keras ``Conv2D``/``Conv2DTranspose``. Then
    the two stride-2 downsamplings followed by two stride-2 upsamplings
    reproduce the input height/width only when both are divisible by 4.
    TODO confirm.

    Attributes:
        model: the assembled ``tf.keras.Model`` named ``"transformnet"``.
    """

    def __init__(self):
        # Encoder: full-resolution 9x9 conv, then two stride-2 downsamplings.
        self.conv1 = ConvLayer(32, (9, 9), strides=(1, 1), padding='same', name="conv_1")
        self.conv2 = ConvLayer(64, (3, 3), strides=(2, 2), padding='same', name="conv_2")
        self.conv3 = ConvLayer(128, (3, 3), strides=(2, 2), padding='same', name="conv_3")

        # Residual trunk operating on 128 channels.
        self.res1 = ResBlock(128, prefix="res_1")
        self.res2 = ResBlock(128, prefix="res_2")
        self.res3 = ResBlock(128, prefix="res_3")
        self.res4 = ResBlock(128, prefix="res_4")
        self.res5 = ResBlock(128, prefix="res_5")

        # Decoder: two stride-2 transposed convs upsample back.
        self.convt1 = ConvTLayer(64, (3, 3), strides=(2, 2), padding='same', name="conv_t_1")
        self.convt2 = ConvTLayer(32, (3, 3), strides=(2, 2), padding='same', name="conv_t_2")

        # Final projection to 3 channels. `activate=False` presumably turns
        # off ConvLayer's built-in activation, so that tanh below is the
        # output nonlinearity. TODO confirm against layers.ConvLayer.
        self.conv4 = ConvLayer(3, (9, 9), strides=(1, 1), padding='same', activate=False, name="conv_4")
        self.tanh = Activation('tanh')

        self.model = self._get_model()

    def _get_model(self):
        """Wire the layers created in ``__init__`` into a ``tf.keras.Model``.

        The input shape is declared as ``(None, None, 3)``. That means any
        height and width, three trailing (channels-last) channels, plus the
        implicit batch dimension.
        """
        inputs = tf.keras.Input(shape=(None, None, 3))
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)
        x = self.convt1(x)
        x = self.convt2(x)
        x = self.conv4(x)
        x = self.tanh(x)
        # Map tanh's [-1, 1] range onto [0, 255].
        x = (x + 1) * (255. / 2)
        return tf.keras.Model(inputs, x, name="transformnet")

    def get_variables(self):
        """Return the trainable variables of the underlying Keras model."""
        return self.model.trainable_variables

    def preprocess(self, img):
        """Scale pixel values by 1/255 (e.g. [0, 255] -> [0, 1]).

        Integer-typed TensorFlow tensors are cast to float32 first. An
        example is the uint8 output of ``tf.io.decode_image``. By default,
        TensorFlow's ``/`` does not promote an integer tensor against a
        Python float and raises ``TypeError``.

        Every other input is divided exactly as before. That includes float
        tensors, NumPy arrays and Python numbers.
        """
        if tf.is_tensor(img) and img.dtype.is_integer:
            img = tf.cast(img, tf.float32)
        return img / 255.0

    def postprocess(self, img):
        """Clip pixel values into the displayable range [0, 255].

        ``_get_model`` already scales outputs into [0, 255]. This clip
        presumably guards against floating-point overshoot, or against
        inputs that did not come from the model.
        """
        return tf.clip_by_value(img, 0.0, 255.0)