-
Notifications
You must be signed in to change notification settings - Fork 35
/
module.py
110 lines (92 loc) · 3.63 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torch import nn, einsum
from einops import rearrange
class SepConv2d(torch.nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (depthwise) convolution, batch normalisation, then a 1x1
    (pointwise) convolution that mixes channels and sets the output width.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1):
        super().__init__()
        # groups=in_channels: each input channel is convolved independently.
        self.depthwise = torch.nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=in_channels,
        )
        self.bn = torch.nn.BatchNorm2d(in_channels)
        # 1x1 conv projects to the requested number of output channels.
        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply depthwise conv -> batch norm -> pointwise conv."""
        return self.pointwise(self.bn(self.depthwise(x)))
class Residual(nn.Module):
    """Skip connection: adds the wrapped module's output back to its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # the transformation being wrapped

    def forward(self, x, **kwargs):
        """Return fn(x, **kwargs) + x."""
        out = self.fn(x, **kwargs)
        return out + x
class PreNorm(nn.Module):
    """Pre-normalisation wrapper: LayerNorm the input, then apply a module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # normalises over the last `dim` features
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return fn(LayerNorm(x), **kwargs)."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Transformer MLP block: Linear(dim->hidden) -> GELU -> Dropout ->
    Linear(hidden->dim) -> Dropout. Feature width in equals feature width out.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the two-layer MLP."""
        return self.net(x)
class ConvAttention(nn.Module):
    """Multi-head self-attention whose q/k/v projections are depthwise-separable
    convolutions (SepConv2d) applied over the 2-D token grid.

    Args:
        dim: embedding dimension of the input tokens.
        img_size: side length of the square token grid; the tokens are reshaped
            to (img_size, img_size) before the convolutional projections.
        heads: number of attention heads.
        dim_head: per-head width; total projection width is heads * dim_head.
        kernel_size: kernel size of the q/k/v separable convolutions.
        q_stride, k_stride, v_stride: strides of the q/k/v convolutions.
        dropout: dropout probability after the output projection.
        last_stage: if True, the first input token is treated as a class token
            that bypasses the convolutions and is re-attached to q/k/v.
    """
    def __init__(self, dim, img_size, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0.,
                 last_stage=False):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # Output projection is skipped when it would be a same-width map
        # (single head whose width already equals dim).
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d) attention scaling
        # NOTE(review): padding is derived from q_stride only; if k_stride or
        # v_stride differ, the k/v conv output sizes may not match q's —
        # confirm callers always pass equal strides.
        pad = (kernel_size - q_stride)//2
        self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad)
        self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad)
        self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    def forward(self, x):
        """Compute convolutional self-attention over x of shape (b, n, dim)."""
        # b: batch, n: token count, h: heads. n is unpacked but unused below;
        # the spatial extent comes from self.img_size instead.
        b, n, _, h = *x.shape, self.heads
        if self.last_stage:
            # Split off the class token; it is reshaped to per-head form
            # directly (no conv) and re-attached after the projections.
            cls_token = x[:, 0]
            x = x[:, 1:]
            cls_token = rearrange(cls_token.unsqueeze(1), 'b n (h d) -> b h n d', h = h)
        # Fold the token sequence back into a 2-D grid so the separable
        # convolutions can exploit spatial locality.
        x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)
        # Project q/k/v and flatten each to (b, heads, tokens, dim_head).
        q = self.to_q(x)
        q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)
        v = self.to_v(x)
        v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
        k = self.to_k(x)
        k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
        if self.last_stage:
            # Prepend the class token along the sequence dimension.
            q = torch.cat((cls_token, q), dim=2)
            v = torch.cat((cls_token, v), dim=2)
            k = torch.cat((cls_token, k), dim=2)
        # Scaled dot-product attention scores: (b h i d)·(b h j d) -> (b h i j).
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        # Attention-weighted sum of values, then merge heads into channels.
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out