-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
180 lines (150 loc) · 5.44 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
""" Parts of the U-Net model """
from flax import linen as nn
import math
import jax.numpy as jnp
from typing import Optional
import jax
from functools import partial
@partial(jax.jit, static_argnames=['d_model', 'length'])
def positionalencoding1d(d_model, length):
    """Build a sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, where ``w_k = 10000 ** (-2k / d_model)``.

    :param d_model: dimension of the model (number of columns); odd values are
        supported, in which case the last column is a sine column
    :param length: length of positions (number of rows)
    :return: length*d_model position matrix
    """
    pe = jnp.zeros((length, d_model))
    position = jnp.expand_dims(jnp.arange(0, length), 1).astype(jnp.float32)
    div_term = jnp.exp(
        jnp.arange(0, d_model, 2, dtype=jnp.float32) * -(math.log(10000.0) / d_model)
    )
    # (length, ceil(d_model / 2)) matrix of pos * w_k.
    angles = position * div_term
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    # There are only d_model // 2 odd columns; for odd d_model that is one fewer
    # than the number of frequencies, so drop the last frequency for cosines.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : d_model // 2]))
    return pe
def get_position_embeddings(t, d_model=32, max_len=1000):
    """Return rows of a sinusoidal positional-encoding table for indices ``t``.

    The table is ``positionalencoding1d(d_model, max_len)``. It is jitted with
    static sizes, so it is compiled once per ``(d_model, max_len)`` pair.

    :param t: integer index or integer array of indices in ``[0, max_len)``
        (presumably diffusion timesteps -- confirm with caller). Note that JAX
        clamps out-of-range gather indices rather than raising.
    :param d_model: embedding width; the default 32 matches the 32-wide
        conditioning vector consumed by ``UNet`` (``class_embed`` is ``Dense(32)``)
    :param max_len: number of rows in the table
    :return: array of shape ``t.shape + (d_model,)``
    """
    table = positionalencoding1d(d_model, max_len)
    return table[t]
class DoubleConv(nn.Module):
    """Two stages of (3x3 conv without bias => BatchNorm => ReLU).

    Attributes:
        out_channels: feature count of the second convolution, i.e. the output.
        mid_channels: feature count of the first convolution; any falsy value
            (None or 0) falls back to ``out_channels``.
    """
    out_channels: int
    mid_channels: Optional[int] = None

    @nn.compact
    def __call__(self, x, train: bool):
        hidden = self.mid_channels if self.mid_channels else self.out_channels
        # BatchNorm uses batch statistics while training and running averages otherwise.
        for features in (hidden, self.out_channels):
            x = nn.Conv(features, kernel_size=(3, 3), padding=1, use_bias=False)(x)
            x = nn.BatchNorm(use_running_average=not train)(x)
            x = nn.relu(x)
        return x
class Down(nn.Module):
    """Halve the spatial size with 2x2 / stride-2 max pooling, then apply DoubleConv."""
    out_channels: int

    @nn.compact
    def __call__(self, x, train: bool):
        pooled = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        block = DoubleConv(self.out_channels)
        return block(pooled, train)
class Up(nn.Module):
    """Upscale ``x1`` 2x, pad it to the skip tensor ``x2``, concatenate, then DoubleConv.

    Tensors are channels-last: ``x1.shape`` is unpacked as ``(B, H, W, C)`` and
    concatenation is on axis 3.

    Attributes:
        in_channels: channel count after concatenating skip and upsampled tensors;
            the transposed conv outputs ``in_channels // 2`` channels.
        out_channels: output channels of the final DoubleConv.
        bilinear: if True, upsample with ``jax.image.resize`` (no parameters);
            otherwise use a learned 2x2 / stride-2 transposed convolution.
    """
    in_channels: int
    out_channels: int
    bilinear: Optional[bool] = False

    def setup(self):
        if self.bilinear:
            self.conv = DoubleConv(self.out_channels, self.in_channels // 2)
        else:
            self.up = nn.ConvTranspose(self.in_channels // 2, [2, 2], (2, 2))
            self.conv = DoubleConv(self.out_channels)

    def __call__(self, x1, x2, train: bool):
        B, H, W, C = x1.shape
        if self.bilinear:
            # Resize only the spatial axes; batch size and channels are unchanged.
            x1 = jax.image.resize(x1, (B, H * 2, W * 2, C), method="bilinear")
        else:
            x1 = self.up(x1)
        # Pad the *upsampled* tensor so its H (axis 1) and W (axis 2) match x2.
        diff_y = x2.shape[1] - x1.shape[1]
        diff_x = x2.shape[2] - x1.shape[2]
        x1 = jnp.pad(
            x1,
            [
                (0, 0),
                (diff_y // 2, diff_y - diff_y // 2),
                (diff_x // 2, diff_x - diff_x // 2),
                (0, 0),
            ],
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = jnp.concatenate([x2, x1], axis=3)
        return self.conv(x, train)
class OutConv(nn.Module):
    """Final 1x1 convolution that maps features to ``out_channels`` outputs."""
    out_channels: int

    @nn.compact
    def __call__(self, x):
        projection = nn.Conv(features=self.out_channels, kernel_size=(1, 1))
        return projection(x)
class UNet(nn.Module):
    """U-Net whose every stage is shifted by a projection of a conditioning vector.

    Attributes:
        n_channels: number of input channels (not read here; Flax infers input
            feature counts from the data).
        n_classes: number of output channels of the final 1x1 convolution.
        bilinear: if True, the decoder upsamples with ``jax.image.resize`` and
            the deepest stages are narrower by a factor of 2; otherwise it uses
            transposed convolutions.
    """
    n_channels: int
    n_classes: int
    bilinear: bool

    def setup(self):
        self.inc = DoubleConv(64)
        self.down1 = Down(128)
        self.down2 = Down(256)
        self.down3 = Down(512)
        factor = 2 if self.bilinear else 1
        self.down4 = Down(1024 // factor)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64, self.bilinear)
        self.outc = OutConv(self.n_classes)
        self.class_embed = nn.Dense(32)
        # One projection of the conditioning vector per stage (inc, down1..down4,
        # up1..up4). Each width must equal the channel count of the feature map
        # it is added to, so the stages narrowed by `factor` shrink here too.
        stage_channels = [
            64, 128, 256, 512, 1024 // factor,
            512 // factor, 256 // factor, 128 // factor, 64,
        ]
        self.linears = [nn.Dense(c) for c in stage_channels]

    def _condition(self, h, t, stage):
        """Add ``linears[stage](t)`` to ``h``, broadcasting over the spatial axes.

        Assumes ``t`` is ``(batch, features)`` so the projection becomes
        ``(batch, 1, 1, channels)`` -- TODO confirm with callers.
        """
        emb = self.linears[stage](t)
        return h + jnp.expand_dims(emb, (1, 2))

    def __call__(self, x, t, y=None, train=True):
        """Run the network.

        :param x: input image batch (channels-last).
        :param t: conditioning vector; must be 32 wide when ``y`` is given,
            since it is summed with ``class_embed(y)`` (presumably the output of
            ``get_position_embeddings`` -- confirm with caller).
        :param y: optional class conditioning input fed through ``Dense(32)``.
        :param train: selects batch statistics vs. running averages in BatchNorm.
        :return: logits with ``n_classes`` channels.
        """
        x1 = self.inc(x, train)
        if y is not None:
            t = t + self.class_embed(y)
        x1 = self._condition(x1, t, 0)
        x2 = self._condition(self.down1(x1, train), t, 1)
        x3 = self._condition(self.down2(x2, train), t, 2)
        x4 = self._condition(self.down3(x3, train), t, 3)
        x5 = self._condition(self.down4(x4, train), t, 4)
        x = self._condition(self.up1(x5, x4, train), t, 5)
        x = self._condition(self.up2(x, x3, train), t, 6)
        x = self._condition(self.up3(x, x2, train), t, 7)
        x = self._condition(self.up4(x, x1, train), t, 8)
        logits = self.outc(x)
        return logits