Add argument types to be able to use torch JIT #54

Open · wants to merge 3 commits into master
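Rewriting the free functions as nn.Module subclasses with explicitly typed forward() arguments (no **kwargs, Optional defaults annotated) is what makes the attention code scriptable with TorchScript. A minimal usage sketch, not part of this diff, assuming the package is importable as unimatch:

# Hypothetical example: scripting one of the converted modules.
# single_head_full_attention is the nn.Module version introduced in this change.
import torch

from unimatch.attention import single_head_full_attention

attn = single_head_full_attention()
scripted = torch.jit.script(attn)  # typed forward() lets TorchScript compile it

q = k = v = torch.randn(2, 196, 128)  # [B, L, C]
out = scripted(q, k, v)
print(out.shape)  # expected: torch.Size([2, 196, 128])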
265 changes: 138 additions & 127 deletions unimatch/attention.py
@@ -1,166 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d

class single_head_full_attention(nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: [B, L, C]
        assert q.dim() == k.dim() == v.dim() == 3

        scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
        attn = torch.softmax(scores, dim=2)  # [B, L, L]
        out = torch.matmul(attn, v)  # [B, L, C]

        return out


class single_head_full_attention_1d(nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                h: Optional[int] = None,
                w: Optional[int] = None
                ) -> torch.Tensor:
        # q, k, v: [B, L, C]
        assert h is not None and w is not None
        assert q.size(1) == h * w

        b, _, c = q.size()

        q = q.view(b, h, w, c)  # [B, H, W, C]
        k = k.view(b, h, w, c)
        v = v.view(b, h, w, c)

        scale_factor = c ** 0.5

        scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor  # [B, H, W, W]

        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v).view(b, -1, c)  # [B, H*W, C]

        return out


class single_head_split_window_attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.split_feature = split_feature()
        self.merge_splits = merge_splits()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                num_splits: int = 1,
                with_shift: bool = False,
                h: Optional[int] = None,
                w: Optional[int] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ) -> torch.Tensor:
        # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
        # q, k, v: [B, L, C]
        assert q.dim() == k.dim() == v.dim() == 3

        assert h is not None and w is not None
        assert q.size(1) == h * w

        b, _, c = q.size()

        b_new = b * num_splits * num_splits

        window_size_h = int(h // num_splits)
        window_size_w = int(w // num_splits)

        q = q.view(b, h, w, c)  # [B, H, W, C]
        k = k.view(b, h, w, c)
        v = v.view(b, h, w, c)

        scale_factor = c ** 0.5

        shift_size_w = 0
        shift_size_h = 0

        if with_shift:
            assert attn_mask is not None  # compute once
            shift_size_h = window_size_h // 2
            shift_size_w = window_size_w // 2

            q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
            k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
            v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

        q = self.split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
        k = self.split_feature(k, num_splits=num_splits, channel_last=True)
        v = self.split_feature(v, num_splits=num_splits, channel_last=True)

        scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                              ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

        if with_shift and attn_mask is not None:
            scores += attn_mask.repeat(b, 1, 1)

        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

        out = self.merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
                                num_splits=num_splits, channel_last=True)  # [B, H, W, C]

        # shift back
        if with_shift:
            out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

        out = out.view(b, -1, c)

        return out


class single_head_split_window_attention_1d(nn.Module):
    def __init__(self):
        super().__init__()
        self.split_feature_1d = split_feature_1d()
        self.merge_splits_1d = merge_splits_1d()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                num_splits: int = 1,
                with_shift: bool = False,
                h: Optional[int] = None,
                w: Optional[int] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ) -> torch.Tensor:
        # q, k, v: [B, L, C]
        assert h is not None and w is not None
        assert q.size(1) == h * w

        b, _, c = q.size()

        b_new = b * num_splits * h

        window_size_w = w // num_splits

        q = q.view(b * h, w, c)  # [B*H, W, C]
        k = k.view(b * h, w, c)
        v = v.view(b * h, w, c)

        scale_factor = c ** 0.5

        shift_size_w = 0

        if with_shift:
            assert attn_mask is not None  # compute once
            shift_size_w = window_size_w // 2

            q = torch.roll(q, shifts=-shift_size_w, dims=1)
            k = torch.roll(k, shifts=-shift_size_w, dims=1)
            v = torch.roll(v, shifts=-shift_size_w, dims=1)

        q = self.split_feature_1d(q, num_splits=num_splits)  # [B*H*K, W/K, C]
        k = self.split_feature_1d(k, num_splits=num_splits)
        v = self.split_feature_1d(v, num_splits=num_splits)

        scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                              ) / scale_factor  # [B*H*K, W/K, W/K]

        if with_shift and attn_mask is not None:
            # attn_mask: [K, W/K, W/K]
            scores += attn_mask.repeat(b * h, 1, 1)  # [B*H*K, W/K, W/K]

        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*H*K, W/K, C]

        out = self.merge_splits_1d(out, h, num_splits=num_splits)  # [B, H, W, C]

        # shift back
        if with_shift:
            out = torch.roll(out, shifts=shift_size_w, dims=2)

        out = out.view(b, -1, c)

        return out


@@ -169,9 +183,7 @@ class SelfAttnPropagation(nn.Module):
    query: feature0, key: feature0, value: flow
    """

    def __init__(self, in_channels: int):
        super(SelfAttnPropagation, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
@@ -181,11 +193,10 @@ def __init__(self, in_channels,
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0: torch.Tensor, flow: torch.Tensor,
                local_window_attn: bool = False,
                local_window_radius: int = 1
                ) -> torch.Tensor:
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
@@ -214,9 +225,9 @@ def forward(self, feature0, flow,

        return out

    def forward_local_window_attn(self, feature0: torch.Tensor, flow: torch.Tensor,
                                  local_window_radius: int = 1
                                  ) -> torch.Tensor:
        assert flow.size(1) == 2 or flow.size(1) == 1  # flow or disparity or depth
        assert local_window_radius > 0

@@ -227,21 +238,21 @@ def forward_local_window_attn(self, feature0, flow,
        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        kernel_size = int(2 * local_window_radius + 1)

        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)

        feature0_window = F.unfold(feature0_proj, kernel_size=int(kernel_size),
                                   padding=int(local_window_radius))  # [B, C*(2R+1)^2, H*W]

        feature0_window = feature0_window.view(b, c, int(kernel_size ** 2), h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, int(kernel_size ** 2))  # [B*H*W, C, (2R+1)^2]

        flow_window = F.unfold(flow, kernel_size=int(kernel_size),
                               padding=int(local_window_radius))  # [B, 2*(2R+1)^2, H*W]

        flow_window = flow_window.view(b, value_channel, int(kernel_size ** 2), h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, int(kernel_size ** 2), value_channel)  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]