From 29f9db150b244dc3d4983706248160baa09e5573 Mon Sep 17 00:00:00 2001 From: AdrianEddy Date: Mon, 1 Apr 2024 04:18:33 +0200 Subject: [PATCH 1/3] Add types to be able to use torch JIT --- unimatch/attention.py | 265 +++++++++++++++++----------------- unimatch/backbone.py | 8 +- unimatch/geometry.py | 77 ++++++---- unimatch/matching.py | 51 ++++--- unimatch/transformer.py | 101 +++++++------ unimatch/trident_conv.py | 3 +- unimatch/unimatch.py | 74 ++++++---- unimatch/utils.py | 298 +++++++++++++++++++-------------------- utils/utils.py | 5 +- 9 files changed, 469 insertions(+), 413 deletions(-) diff --git a/unimatch/attention.py b/unimatch/attention.py index a10f758..240d5c7 100755 --- a/unimatch/attention.py +++ b/unimatch/attention.py @@ -1,166 +1,180 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Optional from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d +class single_head_full_attention(nn.Module): + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 -def single_head_full_attention(q, k, v): - # q, k, v: [B, L, C] - assert q.dim() == k.dim() == v.dim() == 3 + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] - scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] - attn = torch.softmax(scores, dim=2) # [B, L, L] - out = torch.matmul(attn, v) # [B, L, C] - - return out + return out +class single_head_full_attention_1d(nn.Module): + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + h: Optional[int] = None, + w: Optional[int] = None + ) -> torch.Tensor: + # q, k, v: [B, L, C] + assert h is not None and w is not None + assert q.size(1) == h * w -def single_head_full_attention_1d(q, k, v, - h=None, - w=None, - ): - # q, k, v: [B, L, C] + b, _, c = q.size() - assert h is not None and w is not None - assert q.size(1) == h * w + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) - b, _, c = q.size() + scale_factor = c ** 0.5 - q = q.view(b, h, w, c) # [B, H, W, C] - k = k.view(b, h, w, c) - v = v.view(b, h, w, c) + scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] - scale_factor = c ** 0.5 + attn = torch.softmax(scores, dim=-1) - scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] + out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] - attn = torch.softmax(scores, dim=-1) + return out - out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] +class single_head_split_window_attention(nn.Module): + def __init__(self): + super().__init__() + self.split_feature = split_feature() + self.merge_splits = merge_splits() - return out + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + num_splits: int = 1, + with_shift: bool = False, + h: Optional[int] = None, + w: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + assert h is not None and w is not None + assert q.size(1) == h * w -def single_head_split_window_attention(q, k, v, - num_splits=1, - with_shift=False, - h=None, - w=None, - attn_mask=None, - ): - # ref: 
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py - # q, k, v: [B, L, C] - assert q.dim() == k.dim() == v.dim() == 3 + b, _, c = q.size() - assert h is not None and w is not None - assert q.size(1) == h * w + b_new = b * num_splits * num_splits - b, _, c = q.size() + window_size_h = int(h // num_splits) + window_size_w = int(w // num_splits) - b_new = b * num_splits * num_splits + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) - window_size_h = h // num_splits - window_size_w = w // num_splits + scale_factor = c ** 0.5 - q = q.view(b, h, w, c) # [B, H, W, C] - k = k.view(b, h, w, c) - v = v.view(b, h, w, c) + shift_size_w = 0 + shift_size_h = 0 - scale_factor = c ** 0.5 + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 - if with_shift: - assert attn_mask is not None # compute once - shift_size_h = window_size_h // 2 - shift_size_w = window_size_w // 2 + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) - q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) - k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) - v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + q = self.split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = self.split_feature(k, num_splits=num_splits, channel_last=True) + v = self.split_feature(v, num_splits=num_splits, channel_last=True) - q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] - k = split_feature(k, num_splits=num_splits, channel_last=True) - v = split_feature(v, num_splits=num_splits, channel_last=True) + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] - scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) - ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + if with_shift and attn_mask is not None: + scores += attn_mask.repeat(b, 1, 1) - if with_shift: - scores += attn_mask.repeat(b, 1, 1) + attn = torch.softmax(scores, dim=-1) - attn = torch.softmax(scores, dim=-1) + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] - out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + out = self.merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] - out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), - num_splits=num_splits, channel_last=True) # [B, H, W, C] + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) - # shift back - if with_shift: - out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + out = out.view(b, -1, c) - out = out.view(b, -1, c) + return out - return out +class single_head_split_window_attention_1d(nn.Module): + def __init__(self): + super().__init__() + self.split_feature_1d = split_feature_1d() + self.merge_splits_1d = merge_splits_1d() + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + num_splits: int = 1, + with_shift: bool = False, + h: Optional[int] = None, + w: Optional[int] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # q, k, v: [B, L, C] -def 
single_head_split_window_attention_1d(q, k, v, - relative_position_bias=None, - num_splits=1, - with_shift=False, - h=None, - w=None, - attn_mask=None, - ): - # q, k, v: [B, L, C] + assert h is not None and w is not None + assert q.size(1) == h * w - assert h is not None and w is not None - assert q.size(1) == h * w + b, _, c = q.size() - b, _, c = q.size() + b_new = b * num_splits * h - b_new = b * num_splits * h + window_size_w = w // num_splits - window_size_w = w // num_splits + q = q.view(b * h, w, c) # [B*H, W, C] + k = k.view(b * h, w, c) + v = v.view(b * h, w, c) - q = q.view(b * h, w, c) # [B*H, W, C] - k = k.view(b * h, w, c) - v = v.view(b * h, w, c) + scale_factor = c ** 0.5 - scale_factor = c ** 0.5 + shift_size_w = 0 - if with_shift: - assert attn_mask is not None # compute once - shift_size_w = window_size_w // 2 + if with_shift: + assert attn_mask is not None # compute once + shift_size_w = window_size_w // 2 - q = torch.roll(q, shifts=-shift_size_w, dims=1) - k = torch.roll(k, shifts=-shift_size_w, dims=1) - v = torch.roll(v, shifts=-shift_size_w, dims=1) + q = torch.roll(q, shifts=-shift_size_w, dims=1) + k = torch.roll(k, shifts=-shift_size_w, dims=1) + v = torch.roll(v, shifts=-shift_size_w, dims=1) - q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] - k = split_feature_1d(k, num_splits=num_splits) - v = split_feature_1d(v, num_splits=num_splits) + q = self.split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] + k = self.split_feature_1d(k, num_splits=num_splits) + v = self.split_feature_1d(v, num_splits=num_splits) - scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) - ) / scale_factor # [B*H*K, W/K, W/K] + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*H*K, W/K, W/K] - if with_shift: - # attn_mask: [K, W/K, W/K] - scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] + if with_shift and attn_mask is not None: + # attn_mask: [K, W/K, W/K] + scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] - attn = torch.softmax(scores, dim=-1) + attn = torch.softmax(scores, dim=-1) - out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] - out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] + out = self.merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] - # shift back - if with_shift: - out = torch.roll(out, shifts=shift_size_w, dims=2) + # shift back + if with_shift: + out = torch.roll(out, shifts=shift_size_w, dims=2) - out = out.view(b, -1, c) + out = out.view(b, -1, c) - return out + return out class SelfAttnPropagation(nn.Module): @@ -169,9 +183,7 @@ class SelfAttnPropagation(nn.Module): query: feature0, key: feature0, value: flow """ - def __init__(self, in_channels, - **kwargs, - ): + def __init__(self, in_channels: int): super(SelfAttnPropagation, self).__init__() self.q_proj = nn.Linear(in_channels, in_channels) @@ -181,11 +193,10 @@ def __init__(self, in_channels, if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, feature0, flow, - local_window_attn=False, - local_window_radius=1, - **kwargs, - ): + def forward(self, feature0: torch.Tensor, flow: torch.Tensor, + local_window_attn: bool = False, + local_window_radius: int = 1 + ) -> torch.Tensor: # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] if local_window_attn: return self.forward_local_window_attn(feature0, flow, @@ -214,9 +225,9 @@ def forward(self, feature0, flow, return out - 
def forward_local_window_attn(self, feature0, flow, - local_window_radius=1, - ): + def forward_local_window_attn(self, feature0: torch.Tensor, flow: torch.Tensor, + local_window_radius: int = 1 + ) -> torch.Tensor: assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth assert local_window_radius > 0 @@ -227,21 +238,21 @@ def forward_local_window_attn(self, feature0, flow, feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] - kernel_size = 2 * local_window_radius + 1 + kernel_size = int(2 * local_window_radius + 1) feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) - feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, - padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + feature0_window = F.unfold(feature0_proj, kernel_size=int(kernel_size), + padding=int(local_window_radius)) # [B, C*(2R+1)^2), H*W] - feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( - 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + feature0_window = feature0_window.view(b, c, int(kernel_size ** 2), h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, int(kernel_size ** 2)) # [B*H*W, C, (2R+1)^2] - flow_window = F.unfold(flow, kernel_size=kernel_size, - padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + flow_window = F.unfold(flow, kernel_size=int(kernel_size), + padding=int(local_window_radius)) # [B, 2*(2R+1)^2), H*W] - flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute( - 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2] + flow_window = flow_window.view(b, value_channel, int(kernel_size ** 2), h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, int(kernel_size ** 2), value_channel) # [B*H*W, (2R+1)^2, 2] scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] diff --git a/unimatch/backbone.py b/unimatch/backbone.py index d5c92b7..5d967ab 100755 --- a/unimatch/backbone.py +++ b/unimatch/backbone.py @@ -14,10 +14,10 @@ def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, di dilation=dilation, padding=dilation, bias=False) self.relu = nn.ReLU(inplace=True) - self.norm1 = norm_layer(planes) - self.norm2 = norm_layer(planes) + self.norm1 = norm_layer(planes, track_running_stats=True) + self.norm2 = norm_layer(planes, track_running_stats=True) if not stride == 1 or in_planes != planes: - self.norm3 = norm_layer(planes) + self.norm3 = norm_layer(planes, track_running_stats=True) if stride == 1 and in_planes == planes: self.downsample = None @@ -48,7 +48,7 @@ def __init__(self, output_dim=128, feature_dims = [64, 96, 128] self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 - self.norm1 = norm_layer(feature_dims[0]) + self.norm1 = norm_layer(feature_dims[0], track_running_stats=True) self.relu1 = nn.ReLU(inplace=True) self.in_planes = feature_dims[0] diff --git a/unimatch/geometry.py b/unimatch/geometry.py index 775a957..ddee53a 100755 --- a/unimatch/geometry.py +++ b/unimatch/geometry.py @@ -1,9 +1,9 @@ import torch import torch.nn.functional as F +from typing import Optional, Tuple - -def coords_grid(b, h, w, homogeneous=False, device=None): - y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] +def coords_grid(b: int, h: int, w: int, homogeneous: bool = False, device: Optional[torch.device] = None): + y, x = 
torch.meshgrid(torch.arange(h), torch.arange(w), indexing = 'ij') # [H, W] stacks = [x, y] @@ -21,24 +21,28 @@ def coords_grid(b, h, w, homogeneous=False, device=None): return grid -def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): +def generate_window_grid(h_min: int, h_max: int, w_min: int, w_max: int, len_h: int, len_w: int, device: Optional[torch.device] = None): assert device is not None x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), torch.linspace(h_min, h_max, len_h, device=device)], - ) + indexing = 'ij') grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] return grid -def normalize_coords(coords, h, w): +def normalize_coords(coords: torch.Tensor, h: int, w: int): # coords: [B, H, W, 2] - c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + c = torch.tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) return (coords - c) / c # [-1, 1] -def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): +def bilinear_sample(img: torch.Tensor, sample_coords: torch.Tensor, + mode: str = 'bilinear', + padding_mode: str = 'zeros', + return_mask: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # img: [B, C, H, W] # sample_coords: [B, 2, H, W] in image scale if sample_coords.size(1) != 2: # [B, H, W, 2] @@ -59,10 +63,13 @@ def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', r return img, mask - return img + return img, None -def flow_warp(feature, flow, mask=False, padding_mode='zeros'): +def flow_warp(feature:torch.Tensor, flow: torch.Tensor, + mask: bool = False, + padding_mode: str = 'zeros' + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: b, c, h, w = feature.size() assert flow.size(1) == 2 @@ -72,9 +79,9 @@ def flow_warp(feature, flow, mask=False, padding_mode='zeros'): return_mask=mask) -def forward_backward_consistency_check(fwd_flow, bwd_flow, - alpha=0.01, - beta=0.5 +def forward_backward_consistency_check(fwd_flow: torch.Tensor, bwd_flow: torch.Tensor, + alpha: float = 0.01, + beta: float = 0.5 ): # fwd_flow, bwd_flow: [B, 2, H, W] # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) @@ -82,8 +89,8 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] - warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] - warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + warped_bwd_flow, _ = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow, _ = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) @@ -96,7 +103,7 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, return fwd_occ, bwd_occ -def back_project(depth, intrinsics): +def back_project(depth: torch.Tensor, intrinsics: torch.Tensor): # Back project 2D pixel coords to 3D points # depth: [B, H, W] # intrinsics: [B, 3, 3] @@ -110,7 +117,10 @@ def back_project(depth, intrinsics): return points -def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): +def camera_transform(points_ref: torch.Tensor, + extrinsics_ref: Optional[torch.Tensor] = None, + extrinsics_tgt: Optional[torch.Tensor] = None, + extrinsics_rel: Optional[torch.Tensor] = None): # Transform 3D points from reference camera to 
target camera # points_ref: [B, 3, H, W] # extrinsics_ref: [B, 4, 4] @@ -119,6 +129,8 @@ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extri b, _, h, w = points_ref.shape if extrinsics_rel is None: + assert extrinsics_tgt is not None + assert extrinsics_ref is not None extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], @@ -129,7 +141,9 @@ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extri return points_tgt -def reproject(points_tgt, intrinsics, return_mask=False): +def reproject(points_tgt: torch.Tensor, intrinsics: torch.Tensor, + return_mask: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # reproject to target view # points_tgt: [B, 3, H, W] # intrinsics: [B, 3, 3] @@ -151,11 +165,15 @@ def reproject(points_tgt, intrinsics, return_mask=False): return pixel_coords, mask - return pixel_coords + return pixel_coords, None -def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, - return_mask=False): +def reproject_coords(depth_ref: torch.Tensor, intrinsics: torch.Tensor, + extrinsics_ref: Optional[torch.Tensor] = None, + extrinsics_tgt: Optional[torch.Tensor] = None, + extrinsics_rel: Optional[torch.Tensor] = None, + return_mask: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Compute reprojection sample coords points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) @@ -166,15 +184,18 @@ def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt= return reproj_coords, mask - reproj_coords = reproject(points_tgt, intrinsics, + reproj_coords, _ = reproject(points_tgt, intrinsics, return_mask=return_mask) # [B, 2, H, W] in image scale - return reproj_coords + return reproj_coords, None -def compute_flow_with_depth_pose(depth_ref, intrinsics, - extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, - return_mask=False): +def compute_flow_with_depth_pose(depth_ref: torch.Tensor, intrinsics: torch.Tensor, + extrinsics_ref: Optional[torch.Tensor] = None, + extrinsics_tgt: Optional[torch.Tensor] = None, + extrinsics_rel: Optional[torch.Tensor] = None, + return_mask: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: b, h, w = depth_ref.shape coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] @@ -186,10 +207,10 @@ def compute_flow_with_depth_pose(depth_ref, intrinsics, return rigid_flow, mask - reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + reproj_coords, _ = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel, return_mask=return_mask) # [B, 2, H, W] rigid_flow = reproj_coords - coords_init - return rigid_flow + return rigid_flow, None diff --git a/unimatch/matching.py b/unimatch/matching.py index 6471025..52d676f 100755 --- a/unimatch/matching.py +++ b/unimatch/matching.py @@ -1,12 +1,13 @@ import torch import torch.nn.functional as F +from typing import Tuple from .geometry import coords_grid, generate_window_grid, normalize_coords -def global_correlation_softmax(feature0, feature1, - pred_bidir_flow=False, - ): +def global_correlation_softmax(feature0: torch.Tensor, feature1: torch.Tensor, + pred_bidir_flow: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: # global correlation b, c, h, w = 
feature0.shape feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] @@ -36,9 +37,9 @@ def global_correlation_softmax(feature0, feature1, return flow, prob -def local_correlation_softmax(feature0, feature1, local_radius, - padding_mode='zeros', - ): +def local_correlation_softmax(feature0: torch.Tensor, feature1: torch.Tensor, local_radius: int, + padding_mode: str = 'zeros', + ) -> Tuple[torch.Tensor, torch.Tensor]: b, c, h, w = feature0.size() coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] @@ -83,12 +84,12 @@ def local_correlation_softmax(feature0, feature1, local_radius, return flow, match_prob -def local_correlation_with_flow(feature0, feature1, - flow, - local_radius, - padding_mode='zeros', - dilation=1, - ): +def local_correlation_with_flow(feature0: torch.Tensor, feature1: torch.Tensor, + flow: torch.Tensor, + local_radius: int, + padding_mode: str = 'zeros', + dilation: int = 1, + ) -> torch.Tensor: b, c, h, w = feature0.size() coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] @@ -123,8 +124,7 @@ def local_correlation_with_flow(feature0, feature1, return corr -def global_correlation_softmax_stereo(feature0, feature1, - ): +def global_correlation_softmax_stereo(feature0: torch.Tensor, feature1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # global correlation on horizontal direction b, c, h, w = feature0.shape @@ -151,8 +151,7 @@ def global_correlation_softmax_stereo(feature0, feature1, return disparity.unsqueeze(1), prob # feature resolution -def local_correlation_softmax_stereo(feature0, feature1, local_radius, - ): +def local_correlation_softmax_stereo(feature0: torch.Tensor, feature1: torch.Tensor, local_radius: int) -> Tuple[torch.Tensor, torch.Tensor]: b, c, h, w = feature0.size() coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] @@ -200,13 +199,13 @@ def local_correlation_softmax_stereo(feature0, feature1, local_radius, return flow_x, match_prob -def correlation_softmax_depth(feature0, feature1, - intrinsics, - pose, - depth_candidates, - depth_from_argmax=False, - pred_bidir_depth=False, - ): +def correlation_softmax_depth(feature0: torch.Tensor, feature1: torch.Tensor, + intrinsics: torch.Tensor, + pose: torch.Tensor, + depth_candidates: torch.Tensor, + depth_from_argmax: bool = False, + pred_bidir_depth: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: b, c, h, w = feature0.size() assert depth_candidates.dim() == 4 # [B, D, H, W] scale_factor = c ** 0.5 @@ -236,9 +235,9 @@ def correlation_softmax_depth(feature0, feature1, return depth, match_prob -def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, - clamp_min_depth=1e-3, - ): +def warp_with_pose_depth_candidates(feature1: torch.Tensor, intrinsics: torch.Tensor, pose: torch.Tensor, depth: torch.Tensor, + clamp_min_depth: float = 1e-3, + ) -> torch.Tensor: """ feature1: [B, C, H, W] intrinsics: [B, 3, 3] diff --git a/unimatch/transformer.py b/unimatch/transformer.py index a93660c..1d38332 100755 --- a/unimatch/transformer.py +++ b/unimatch/transformer.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from typing import Optional, Tuple from .attention import (single_head_full_attention, single_head_split_window_attention, single_head_full_attention_1d, single_head_split_window_attention_1d) @@ -8,10 
+9,10 @@ class TransformerLayer(nn.Module): def __init__(self, - d_model=128, - nhead=1, - no_ffn=False, - ffn_dim_expansion=4, + d_model: int = 128, + nhead: int = 1, + no_ffn: bool = False, + ffn_dim_expansion: int = 4, ): super(TransformerLayer, self).__init__() @@ -28,7 +29,14 @@ def __init__(self, self.norm1 = nn.LayerNorm(d_model) + self.single_head_split_window_attention = single_head_split_window_attention() + self.single_head_split_window_attention_1d = single_head_split_window_attention_1d() + self.single_head_full_attention = single_head_full_attention() + self.single_head_full_attention_1d = single_head_full_attention_1d() + # no ffn after self-attn, with ffn after cross-attn + self.mlp = nn.Sequential() + self.norm2 = nn.Sequential() if not self.no_ffn: in_channels = d_model * 2 self.mlp = nn.Sequential( @@ -39,15 +47,14 @@ def __init__(self, self.norm2 = nn.LayerNorm(d_model) - def forward(self, source, target, - height=None, - width=None, - shifted_window_attn_mask=None, - shifted_window_attn_mask_1d=None, - attn_type='swin', - with_shift=False, - attn_num_splits=None, - ): + def forward(self, source: torch.Tensor, target: torch.Tensor, + height: Optional[int] = None, + width: Optional[int] = None, + shifted_window_attn_mask: Optional[torch.Tensor] = None, + shifted_window_attn_mask_1d: Optional[torch.Tensor] = None, + attn_type: str = 'swin', + with_shift: bool = False, + attn_num_splits: int = 0) -> torch.Tensor: # source, target: [B, L, C] query, key, value = source, target, target @@ -65,7 +72,7 @@ def forward(self, source, target, # without bringing obvious performance gains and thus the implementation is removed raise NotImplementedError else: - message = single_head_split_window_attention(query, key, value, + message = self.single_head_split_window_attention(query, key, value, num_splits=attn_num_splits, with_shift=with_shift, h=height, @@ -79,7 +86,7 @@ def forward(self, source, target, else: if is_self_attn: if attn_num_splits > 1: - message = single_head_split_window_attention(query, key, value, + message = self.single_head_split_window_attention(query, key, value, num_splits=attn_num_splits, with_shift=with_shift, h=height, @@ -88,11 +95,11 @@ def forward(self, source, target, ) else: # full 2d attn - message = single_head_full_attention(query, key, value) # [N, L, C] + message = self.single_head_full_attention(query, key, value) # [N, L, C] else: # cross attn 1d - message = single_head_full_attention_1d(query, key, value, + message = self.single_head_full_attention_1d(query, key, value, h=height, w=width, ) @@ -104,7 +111,7 @@ def forward(self, source, target, if is_self_attn: if attn_num_splits > 1: # self attn shift window - message = single_head_split_window_attention(query, key, value, + message = self.single_head_split_window_attention(query, key, value, num_splits=attn_num_splits, with_shift=with_shift, h=height, @@ -113,12 +120,12 @@ def forward(self, source, target, ) else: # full 2d attn - message = single_head_full_attention(query, key, value) # [N, L, C] + message = self.single_head_full_attention(query, key, value) # [N, L, C] else: if attn_num_splits > 1: assert shifted_window_attn_mask_1d is not None # cross attn 1d shift - message = single_head_split_window_attention_1d(query, key, value, + message = self.single_head_split_window_attention_1d(query, key, value, num_splits=attn_num_splits, with_shift=with_shift, h=height, @@ -126,13 +133,13 @@ def forward(self, source, target, attn_mask=shifted_window_attn_mask_1d, ) else: - message = 
single_head_full_attention_1d(query, key, value, + message = self.single_head_full_attention_1d(query, key, value, h=height, w=width, ) else: - message = single_head_full_attention(query, key, value) # [B, L, C] + message = self.single_head_full_attention(query, key, value) # [B, L, C] message = self.merge(message) # [B, L, C] message = self.norm1(message) @@ -148,9 +155,9 @@ class TransformerBlock(nn.Module): """self attention + cross attention + FFN""" def __init__(self, - d_model=128, - nhead=1, - ffn_dim_expansion=4, + d_model: int = 128, + nhead: int = 1, + ffn_dim_expansion: int = 4, ): super(TransformerBlock, self).__init__() @@ -162,18 +169,19 @@ def __init__(self, self.cross_attn_ffn = TransformerLayer(d_model=d_model, nhead=nhead, + no_ffn=False, ffn_dim_expansion=ffn_dim_expansion, ) - def forward(self, source, target, - height=None, - width=None, - shifted_window_attn_mask=None, - shifted_window_attn_mask_1d=None, - attn_type='swin', - with_shift=False, - attn_num_splits=None, - ): + def forward(self, source: torch.Tensor, target: torch.Tensor, + height: Optional[int] = None, + width: Optional[int] = None, + shifted_window_attn_mask: Optional[torch.Tensor] = None, + shifted_window_attn_mask_1d: Optional[torch.Tensor] = None, + attn_type:str = 'swin', + with_shift:bool = False, + attn_num_splits: int = 0 + ) -> torch.Tensor: # source, target: [B, L, C] # self attention @@ -181,6 +189,7 @@ def forward(self, source, target, height=height, width=width, shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=None, attn_type=attn_type, with_shift=with_shift, attn_num_splits=attn_num_splits, @@ -202,15 +211,17 @@ def forward(self, source, target, class FeatureTransformer(nn.Module): def __init__(self, - num_layers=6, - d_model=128, - nhead=1, - ffn_dim_expansion=4, + num_layers: int = 6, + d_model: int = 128, + nhead: int = 1, + ffn_dim_expansion: int = 4, ): super(FeatureTransformer, self).__init__() self.d_model = d_model self.nhead = nhead + self.generate_shift_window_attn_mask = generate_shift_window_attn_mask() + self.generate_shift_window_attn_mask_1d = generate_shift_window_attn_mask_1d() self.layers = nn.ModuleList([ TransformerBlock(d_model=d_model, @@ -223,11 +234,11 @@ def __init__(self, if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, feature0, feature1, - attn_type='swin', - attn_num_splits=None, - **kwargs, - ): + + def forward(self, feature0: torch.Tensor, feature1: torch.Tensor, + attn_type: str = 'swin', + attn_num_splits: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor]: b, c, h, w = feature0.shape assert self.d_model == c @@ -242,7 +253,7 @@ def forward(self, feature0, feature1, window_size_w = w // attn_num_splits # compute attn mask once - shifted_window_attn_mask = generate_shift_window_attn_mask( + shifted_window_attn_mask = self.generate_shift_window_attn_mask( input_resolution=(h, w), window_size_h=window_size_h, window_size_w=window_size_w, @@ -258,7 +269,7 @@ def forward(self, feature0, feature1, window_size_w = w // attn_num_splits # compute attn mask once - shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( + shifted_window_attn_mask_1d = self.generate_shift_window_attn_mask_1d( input_w=w, window_size_w=window_size_w, shift_size_w=window_size_w // 2, diff --git a/unimatch/trident_conv.py b/unimatch/trident_conv.py index 445663c..baa5837 100755 --- a/unimatch/trident_conv.py +++ b/unimatch/trident_conv.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn import functional as F from 
torch.nn.modules.utils import _pair +from typing import List class MultiScaleTridentConv(nn.Module): @@ -61,7 +62,7 @@ def __init__( if self.bias is not None: nn.init.constant_(self.bias, 0) - def forward(self, inputs): + def forward(self, inputs: List[torch.Tensor]): num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 assert len(inputs) == num_branch diff --git a/unimatch/unimatch.py b/unimatch/unimatch.py index 96db16e..9889ace 100755 --- a/unimatch/unimatch.py +++ b/unimatch/unimatch.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Optional, List, Dict from .backbone import CNNEncoder from .transformer import FeatureTransformer @@ -43,6 +44,8 @@ def __init__(self, # propagation with self-attn self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) + self.feature_add_position = feature_add_position(feature_channels) + self.upsampler = nn.Sequential() if not self.reg_refine or task == 'depth': # convex upsampling simiar to RAFT @@ -78,13 +81,14 @@ def extract_feature(self, img0, img1): return feature0, feature1 - def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, - is_depth=False): + def upsample_flow(self, flow: torch.Tensor, feature: Optional[torch.Tensor], bilinear: bool = False, upsample_factor: float = 8, + is_depth: bool = False) -> torch.Tensor: if bilinear: - multiplier = 1 if is_depth else upsample_factor + multiplier = 1.0 if is_depth else upsample_factor up_flow = F.interpolate(flow, scale_factor=upsample_factor, mode='bilinear', align_corners=True) * multiplier else: + assert feature is not None concat = torch.cat((flow, feature), dim=1) mask = self.upsampler(concat) up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, @@ -92,23 +96,21 @@ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, return up_flow - def forward(self, img0, img1, - attn_type=None, - attn_splits_list=None, - corr_radius_list=None, - prop_radius_list=None, - num_reg_refine=1, - pred_bidir_flow=False, - task='flow', - intrinsics=None, - pose=None, # relative pose transform - min_depth=1. / 0.5, # inverse depth range - max_depth=1. / 10, - num_depth_candidates=64, - depth_from_argmax=False, - pred_bidir_depth=False, - **kwargs, - ): + def forward(self, img0: torch.Tensor, img1: torch.Tensor, + attn_type: str, + attn_splits_list: List[int], + corr_radius_list: List[int], + prop_radius_list: List[int], + num_reg_refine: int = 1, + pred_bidir_flow: bool = False, + task: str = 'flow', + intrinsics: Optional[torch.Tensor] = None, + pose: torch.Tensor = None, # relative pose transform + min_depth:float = 1. / 0.5, # inverse depth range + max_depth:float = 1. 
/ 10, + num_depth_candidates: int = 64, + depth_from_argmax: bool = False, + pred_bidir_depth: bool = False): if pred_bidir_flow: assert task == 'flow' @@ -116,8 +118,8 @@ def forward(self, img0, img1, if task == 'depth': assert self.num_scales == 1 # multi-scale depth model is not supported yet - results_dict = {} - flow_preds = [] + results_dict: Dict[str, List[torch.Tensor]] = {} + flow_preds: List[torch.Tensor] = [] if task == 'flow': # stereo and depth tasks have normalized img in dataloader @@ -126,7 +128,7 @@ def forward(self, img0, img1, # list of features, resolution low to high feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features - flow = None + flow: Optional[torch.Tensor] = None if task != 'depth': assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales @@ -146,14 +148,19 @@ def forward(self, img0, img1, if task == 'depth': # scale intrinsics + assert intrinsics is not None intrinsics_curr = intrinsics.clone() intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor + else: + intrinsics_curr = torch.zeros(1, 1) if scale_idx > 0: assert task != 'depth' # not supported for multi-scale depth model - flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + assert flow is not None + flow = F.interpolate(flow, scale_factor=2.0, mode='bilinear', align_corners=True) * 2 if flow is not None: + assert flow is not None assert task != 'depth' flow = flow.detach() @@ -163,19 +170,21 @@ def forward(self, img0, img1, zeros = torch.zeros_like(flow) # [B, 1, H, W] # NOTE: reverse disp, disparity is positive displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] - feature1 = flow_warp(feature1, displace) # [B, C, H, W] + feature1, _ = flow_warp(feature1, displace) # [B, C, H, W] elif task == 'flow': - feature1 = flow_warp(feature1, flow) # [B, C, H, W] + feature1, _ = flow_warp(feature1, flow) # [B, C, H, W] else: raise NotImplementedError attn_splits = attn_splits_list[scale_idx] if task != 'depth': corr_radius = corr_radius_list[scale_idx] + else: + corr_radius = 0 prop_radius = prop_radius_list[scale_idx] # add position to features - feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + feature0, feature1 = self.feature_add_position(feature0, feature1, attn_splits, self.feature_channels) # Transformer feature0, feature1 = self.transformer(feature0, feature1, @@ -216,7 +225,12 @@ def forward(self, img0, img1, raise NotImplementedError # flow or residual flow - flow = flow + flow_pred if flow is not None else flow_pred + if flow is not None: + assert flow is not None + flow = flow + flow_pred + else: + assert flow is None + flow = flow_pred if task == 'stereo': flow = flow.clamp(min=0) # positive disparity @@ -269,6 +283,8 @@ def forward(self, img0, img1, is_depth=task == 'depth') flow_preds.append(flow_up) + if isinstance(num_reg_refine, tuple): + num_reg_refine = num_reg_refine[0] assert num_reg_refine > 0 for refine_iter_idx in range(num_reg_refine): flow = flow.detach() @@ -292,7 +308,7 @@ def forward(self, img0, img1, dim=0), torch.cat((feature1_ori, feature0_ori), dim=0) - flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1), + flow_from_depth, _ = compute_flow_with_depth_pose(1. 
/ flow.squeeze(1), intrinsics_curr, extrinsics_rel=pose, ) diff --git a/unimatch/utils.py b/unimatch/utils.py index 0c3dbea..2b778c4 100755 --- a/unimatch/utils.py +++ b/unimatch/utils.py @@ -1,26 +1,27 @@ import torch +import torch.nn as nn import torch.nn.functional as F from .position import PositionEmbeddingSine +from typing import Tuple - -def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): +def generate_window_grid(h_min: int, h_max: int, w_min: int, w_max: int, len_h: int, len_w: int, device: torch.device = None) -> torch.Tensor: assert device is not None x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), torch.linspace(h_min, h_max, len_h, device=device)], - ) + indexing = 'ij') grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] return grid -def normalize_coords(coords, h, w): +def normalize_coords(coords: torch.Tensor, h: int, w: int) -> torch.Tensor: # coords: [B, H, W, 2] c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) return (coords - c) / c # [-1, 1] -def normalize_img(img0, img1): +def normalize_img(img0: torch.Tensor, img1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # loaded images are in [0, 255] # normalize by ImageNet mean and std mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) @@ -30,109 +31,113 @@ def normalize_img(img0, img1): return img0, img1 - -def split_feature(feature, - num_splits=2, - channel_last=False, - ): - if channel_last: # [B, H, W, C] - b, h, w, c = feature.size() - assert h % num_splits == 0 and w % num_splits == 0 - - b_new = b * num_splits * num_splits - h_new = h // num_splits - w_new = w // num_splits - - feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c - ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] - else: # [B, C, H, W] - b, c, h, w = feature.size() - assert h % num_splits == 0 and w % num_splits == 0 - - b_new = b * num_splits * num_splits - h_new = h // num_splits - w_new = w // num_splits - - feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits - ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] - - return feature - - -def merge_splits(splits, - num_splits=2, - channel_last=False, - ): - if channel_last: # [B*K*K, H/K, W/K, C] - b, h, w, c = splits.size() - new_b = b // num_splits // num_splits - - splits = splits.view(new_b, num_splits, num_splits, h, w, c) - merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( - new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] - else: # [B*K*K, C, H/K, W/K] - b, c, h, w = splits.size() - new_b = b // num_splits // num_splits - - splits = splits.view(new_b, num_splits, num_splits, c, h, w) - merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( - new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] - - return merge - - -def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, - shift_size_h, shift_size_w, device=torch.device('cuda')): - # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py - # calculate attention mask for SW-MSA - h, w = input_resolution - img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 - h_slices = (slice(0, -window_size_h), - slice(-window_size_h, -shift_size_h), - slice(-shift_size_h, None)) - w_slices = (slice(0, -window_size_w), - slice(-window_size_w, -shift_size_w), - slice(-shift_size_w, None)) - cnt = 0 - for h in 
h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) - - mask_windows = mask_windows.view(-1, window_size_h * window_size_w) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - -def feature_add_position(feature0, feature1, attn_splits, feature_channels): - pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) - - if attn_splits > 1: # add position in splited window - feature0_splits = split_feature(feature0, num_splits=attn_splits) - feature1_splits = split_feature(feature1, num_splits=attn_splits) - - position = pos_enc(feature0_splits) - - feature0_splits = feature0_splits + position - feature1_splits = feature1_splits + position - - feature0 = merge_splits(feature0_splits, num_splits=attn_splits) - feature1 = merge_splits(feature1_splits, num_splits=attn_splits) - else: - position = pos_enc(feature0) - - feature0 = feature0 + position - feature1 = feature1 + position - - return feature0, feature1 - - -def upsample_flow_with_mask(flow, up_mask, upsample_factor, - is_depth=False): +class split_feature(nn.Module): + def forward(self, feature: torch.Tensor, num_splits: int = 2, channel_last: bool = False) -> torch.Tensor: + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + +class merge_splits(nn.Module): + def forward(self, splits: torch.Tensor, num_splits: int = 2, channel_last: bool = False) -> torch.Tensor: + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + +class generate_shift_window_attn_mask(nn.Module): + def __init__(self): + super().__init__() + self.split_feature = split_feature() + + def forward(self, input_resolution: Tuple[int, int], window_size_h: int, window_size_w: int, shift_size_h: int, shift_size_w: int, device: torch.device = torch.device('cuda')) -> torch.Tensor: + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + + mask1 = torch.ones((h - window_size_h, w - window_size_w )).to(device) * 0 + mask2 = torch.ones((h - 
window_size_h, window_size_w - shift_size_w)).to(device) * 1 + mask3 = torch.ones((h - window_size_h, shift_size_w )).to(device) * 2 + mask4 = torch.ones((window_size_h - shift_size_h, w - window_size_w )).to(device) * 3 + mask5 = torch.ones((window_size_h - shift_size_h, window_size_w - shift_size_w)).to(device) * 4 + mask6 = torch.ones((window_size_h - shift_size_h, shift_size_w )).to(device) * 5 + mask7 = torch.ones((shift_size_h, w - window_size_w )).to(device) * 6 + mask8 = torch.ones((shift_size_h, window_size_w - shift_size_w)).to(device) * 7 + mask9 = torch.ones((shift_size_h, shift_size_w )).to(device) * 8 + # Concatenate the masks to create the full mask + upper_mask = torch.cat([mask1, mask2, mask3], dim=1) + middle_mask = torch.cat([mask4, mask5, mask6], dim=1) + lower_mask = torch.cat([mask7, mask8, mask9], dim=1) + img_mask = torch.cat([upper_mask, middle_mask, lower_mask], dim=0).unsqueeze(0).unsqueeze(-1) # Add extra dimensions for batch size and channels + + mask_windows = self.split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + +class feature_add_position(nn.Module): + def __init__(self, feature_channels: int): + super().__init__() + self.split_feature = split_feature() + self.merge_splits = merge_splits() + self.pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + def forward(self, feature0: torch.Tensor, feature1: torch.Tensor, attn_splits: int, feature_channels: int) -> Tuple[torch.Tensor, torch.Tensor]: + if attn_splits > 1: # add position in splited window + feature0_splits = self.split_feature(feature0, num_splits=attn_splits) + feature1_splits = self.split_feature(feature1, num_splits=attn_splits) + + position = self.pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = self.merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = self.merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = self.pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 + + +def upsample_flow_with_mask(flow: torch.Tensor, up_mask: torch.Tensor, upsample_factor: int, + is_depth: bool = False) -> torch.Tensor: # convex upsampling following raft mask = up_mask @@ -151,38 +156,32 @@ def upsample_flow_with_mask(flow, up_mask, upsample_factor, return up_flow +class split_feature_1d(nn.Module): + def forward(self, feature: torch.Tensor, num_splits: int = 2) -> torch.Tensor: + # feature: [B, W, C] + b, w, c = feature.size() + assert w % num_splits == 0 -def split_feature_1d(feature, - num_splits=2, - ): - # feature: [B, W, C] - b, w, c = feature.size() - assert w % num_splits == 0 - - b_new = b * num_splits - w_new = w // num_splits - - feature = feature.view(b, num_splits, w // num_splits, c - ).view(b_new, w_new, c) # [B*K, W/K, C] - - return feature + b_new = b * num_splits + w_new = w // num_splits + feature = feature.view(b, num_splits, w // num_splits, c + ).view(b_new, w_new, c) # [B*K, W/K, C] -def merge_splits_1d(splits, - h, - num_splits=2, - ): - b, w, c = splits.size() - new_b = b // num_splits // h + return feature - splits = splits.view(new_b, h, num_splits, w, c) - merge = splits.view( 
- new_b, h, num_splits * w, c) # [B, H, W, C] +class merge_splits_1d(nn.Module): + def forward(self, splits: torch.Tensor, h: int, num_splits: int = 2) -> torch.Tensor: + b, w, c = splits.size() + new_b = b // num_splits // h - return merge + splits = splits.view(new_b, h, num_splits, w, c) + merge = splits.view( + new_b, h, num_splits * w, c) # [B, H, W, C] + return merge -def window_partition_1d(x, window_size_w): +def window_partition_1d(x: torch.Tensor, window_size_w: int) -> torch.Tensor: """ Args: x: (B, W, C) @@ -195,22 +194,19 @@ def window_partition_1d(x, window_size_w): x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) return x +class generate_shift_window_attn_mask_1d(nn.Module): + def forward(self, input_w: int, window_size_w: int, shift_size_w: int, device: torch.device = torch.device('cuda')) -> torch.Tensor: + # calculate attention mask for SW-MSA + + mask1 = torch.ones((0, input_w - window_size_w )).to(device) * 0 + mask2 = torch.ones((0, window_size_w - shift_size_w)).to(device) * 1 + mask3 = torch.ones((0, shift_size_w )).to(device) * 2 + # Concatenate the masks to create the full mask + img_mask = torch.cat([mask1, mask2, mask3], dim=1).unsqueeze(0).unsqueeze(-1) + + mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 + mask_windows = mask_windows.view(-1, window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) -def generate_shift_window_attn_mask_1d(input_w, window_size_w, - shift_size_w, device=torch.device('cuda')): - # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 - w_slices = (slice(0, -window_size_w), - slice(-window_size_w, -shift_size_w), - slice(-shift_size_w, None)) - cnt = 0 - for w in w_slices: - img_mask[:, w, :] = cnt - cnt += 1 - - mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 - mask_windows = mask_windows.view(-1, window_size_w) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask + return attn_mask diff --git a/utils/utils.py b/utils/utils.py index 73d780f..f8e87f8 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -61,9 +61,10 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zer def coords_grid(batch, ht, wd, normalize=False): if normalize: # [-1, 1] coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1, - 2 * torch.arange(wd) / (wd - 1) - 1) + 2 * torch.arange(wd) / (wd - 1) - 1, + indexing = 'ij') else: - coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing = 'ij') coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W] From 0141bcb6a6f785789deef101ec813a7bf9120071 Mon Sep 17 00:00:00 2001 From: AdrianEddy Date: Mon, 1 Apr 2024 04:58:05 +0200 Subject: [PATCH 2/3] Remove track_running_stats from norm_layer --- unimatch/backbone.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unimatch/backbone.py b/unimatch/backbone.py index 5d967ab..d5c92b7 100755 --- a/unimatch/backbone.py +++ b/unimatch/backbone.py @@ -14,10 +14,10 @@ def __init__(self, in_planes, planes, 
norm_layer=nn.InstanceNorm2d, stride=1, di dilation=dilation, padding=dilation, bias=False) self.relu = nn.ReLU(inplace=True) - self.norm1 = norm_layer(planes, track_running_stats=True) - self.norm2 = norm_layer(planes, track_running_stats=True) + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) if not stride == 1 or in_planes != planes: - self.norm3 = norm_layer(planes, track_running_stats=True) + self.norm3 = norm_layer(planes) if stride == 1 and in_planes == planes: self.downsample = None @@ -48,7 +48,7 @@ def __init__(self, output_dim=128, feature_dims = [64, 96, 128] self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 - self.norm1 = norm_layer(feature_dims[0], track_running_stats=True) + self.norm1 = norm_layer(feature_dims[0]) self.relu1 = nn.ReLU(inplace=True) self.in_planes = feature_dims[0] From da140fac169d58fba4ffde9c4ef10c906fb5040b Mon Sep 17 00:00:00 2001 From: AdrianEddy Date: Fri, 10 May 2024 14:02:53 +0200 Subject: [PATCH 3/3] Fix generate_shift_window_attn_mask_1d --- unimatch/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unimatch/utils.py b/unimatch/utils.py index 2b778c4..3ef9cf5 100755 --- a/unimatch/utils.py +++ b/unimatch/utils.py @@ -197,12 +197,11 @@ def window_partition_1d(x: torch.Tensor, window_size_w: int) -> torch.Tensor: class generate_shift_window_attn_mask_1d(nn.Module): def forward(self, input_w: int, window_size_w: int, shift_size_w: int, device: torch.device = torch.device('cuda')) -> torch.Tensor: # calculate attention mask for SW-MSA - - mask1 = torch.ones((0, input_w - window_size_w )).to(device) * 0 - mask2 = torch.ones((0, window_size_w - shift_size_w)).to(device) * 1 - mask3 = torch.ones((0, shift_size_w )).to(device) * 2 + mask1 = torch.ones((input_w - window_size_w )).to(device) * 0 + mask2 = torch.ones((window_size_w - shift_size_w)).to(device) * 1 + mask3 = torch.ones((shift_size_w )).to(device) * 2 # Concatenate the masks to create the full mask - img_mask = torch.cat([mask1, mask2, mask3], dim=1).unsqueeze(0).unsqueeze(-1) + img_mask = torch.cat([mask1, mask2, mask3], dim=0).unsqueeze(0).unsqueeze(-1) mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 mask_windows = mask_windows.view(-1, window_size_w)
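
These patches are aimed at making the model usable with torch JIT (per the subject of PATCH 1/3). A minimal scripting sketch follows; the UniMatch constructor arguments are assumptions inferred from attributes referenced in the diff (num_scales, feature_channels, upsample_factor, reg_refine, task) rather than a verified signature, so adjust them to the actual definition in unimatch/unimatch.py.

import torch
from unimatch.unimatch import UniMatch

# NOTE: constructor arguments below are assumed from attributes used in the diff;
# check the real signature in unimatch/unimatch.py before running.
model = UniMatch(num_scales=1,
                 feature_channels=128,
                 upsample_factor=8,
                 reg_refine=False,
                 task='flow').eval()

# With the type annotations applied, scripting is expected to succeed;
# forward() now takes attn_type and the *_list arguments explicitly.
scripted = torch.jit.script(model)
scripted.save('unimatch_flow_scripted.pt')  # serialized TorchScript module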