fix: Fix mypy error.
RoyYang0714 committed Apr 10, 2024
1 parent a9aba28 commit 5249e18
Showing 29 changed files with 92 additions and 103 deletions.
2 changes: 0 additions & 2 deletions vis4d/data/datasets/coco.py
@@ -15,8 +15,6 @@
from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData

from ..const import CommonKeys as K
from ..typing import DictData
from .base import Dataset
from .util import CacheMappingMixin, im_decode

2 changes: 1 addition & 1 deletion vis4d/data/datasets/util.py
@@ -53,7 +53,7 @@ def im_decode(
}, f"{mode} not supported for image decoding!"
if backend == "PIL":
pil_img = Image.open(BytesIO(bytearray(im_bytes)))
pil_img = ImageOps.exif_transpose(pil_img)
pil_img = ImageOps.exif_transpose(pil_img) # type: ignore
if pil_img.mode == "L": # pragma: no cover
if mode == "L":
img: NDArrayUI8 = np.array(pil_img)[..., None]
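The ignore added here works around recent Pillow type stubs, which annotate `ImageOps.exif_transpose` as possibly returning `None`. A minimal sketch of an ignore-free alternative, assuming a hypothetical standalone helper (not part of this commit):

```python
from io import BytesIO

from PIL import Image, ImageOps


def load_oriented_image(im_bytes: bytes) -> Image.Image:
    """Decode image bytes and apply the EXIF orientation tag."""
    pil_img = Image.open(BytesIO(im_bytes))
    transposed = ImageOps.exif_transpose(pil_img)
    # Recent Pillow stubs allow a None return (in-place mode), so fall back
    # to the original image to keep the return type a plain Image.
    return transposed if transposed is not None else pil_img
```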
25 changes: 9 additions & 16 deletions vis4d/engine/loss_module.py
@@ -201,22 +201,15 @@ def forward(

# Convert loss_dict to total loss and metrics dictionary
metrics: dict[str, float] = {}
if isinstance(loss_dict, Tensor):
total_loss = loss_dict
elif isinstance(loss_dict, dict):
keep_loss_dict: LossesType = {}
for k, v in loss_dict.items():
metrics[k] = v.detach().cpu().item()
if (
self.exclude_attributes is None
or k not in self.exclude_attributes
):
keep_loss_dict[k] = v
total_loss = sum(keep_loss_dict.values()) # type: ignore
else:
raise TypeError(
"Loss function must return a Tensor or a dict of Tensor"
)
keep_loss_dict: LossesType = {}
for k, v in loss_dict.items():
metrics[k] = v.detach().cpu().item()
if (
self.exclude_attributes is None
or k not in self.exclude_attributes
):
keep_loss_dict[k] = v
total_loss: Tensor = sum(keep_loss_dict.values()) # type: ignore
metrics["loss"] = total_loss.detach().cpu().item()

return total_loss, metrics
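With the isinstance branches removed, forward assumes every loss function returns a dict of tensors. A minimal sketch of the aggregation it performs, using hypothetical names rather than the module's API: every entry is logged as a metric, excluded keys are skipped in the sum, and the remaining values form the total loss.

```python
import torch
from torch import Tensor


def aggregate_losses(
    loss_dict: dict[str, Tensor], exclude: list[str] | None = None
) -> tuple[Tensor, dict[str, float]]:
    """Log every loss as a metric and sum the non-excluded ones."""
    metrics = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
    kept = [
        v for k, v in loss_dict.items() if exclude is None or k not in exclude
    ]
    total_loss = torch.stack(kept).sum()  # assumes at least one kept loss
    metrics["loss"] = total_loss.detach().cpu().item()
    return total_loss, metrics


total, metrics = aggregate_losses(
    {"loss_cls": torch.tensor(0.5), "loss_box": torch.tensor(0.2)}
)
print(total)  # tensor(0.7000)
```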
4 changes: 2 additions & 2 deletions vis4d/model/adapter/ema.py
@@ -62,7 +62,7 @@ def update(self, steps: int) -> None: # pylint: disable=unused-argument
"""Update the internal EMA model."""
self._update(
self.model,
update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m, # type: ignore # pylint: disable=line-too-long
update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m,
)

def set(self, model: nn.Module) -> None:
@@ -114,5 +114,5 @@ def update(self, steps: int) -> None:
)
self._update(
self.model,
update_fn=lambda e, m: decay * e + (1.0 - decay) * m, # type: ignore # pylint: disable=line-too-long
update_fn=lambda e, m: decay * e + (1.0 - decay) * m,
)
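Both lambdas implement the standard exponential moving average update, new = decay * ema + (1 - decay) * model; the ignores can be dropped because the expressions now type-check as-is. A small worked sketch of the rule with a hypothetical standalone helper:

```python
import torch
from torch import Tensor


def ema_update(
    ema_param: Tensor, model_param: Tensor, decay: float = 0.999
) -> Tensor:
    """One EMA step: move (1 - decay) of the way toward the model value."""
    return decay * ema_param + (1.0 - decay) * model_param


ema = torch.tensor([1.0])
model = torch.tensor([0.0])
print(ema_update(ema, model, decay=0.9))  # tensor([0.9000])
```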
62 changes: 29 additions & 33 deletions vis4d/op/base/pointnetpp.py
@@ -19,11 +19,11 @@
class PointNetSetAbstractionOut(NamedTuple):
"""Ouput of PointNet set abstraction."""

coordinates: torch.Tensor # [B, C, S]
features: torch.Tensor # [B, D', S]
coordinates: Tensor # [B, C, S]
features: Tensor # [B, D', S]


def square_distance(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
def square_distance(src: Tensor, dst: Tensor) -> Tensor:
"""Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
@@ -44,10 +44,10 @@ def square_distance(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(bs, n_pts_in, 1)
dist += torch.sum(dst**2, -1).view(bs, 1, n_pts_out)
return dist # type: ignore
return dist


def index_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
def index_points(points: Tensor, idx: Tensor) -> Tensor:
"""Indexes points.
Input:
@@ -73,7 +73,7 @@ def index_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return new_points


def farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
def farthest_point_sample(xyz: Tensor, npoint: int) -> Tensor:
"""Farthest point sampling.
Input:
@@ -100,8 +100,8 @@ def farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:


def query_ball_point(
radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor
) -> torch.Tensor:
radius: float, nsample: int, xyz: Tensor, new_xyz: Tensor
) -> Tensor:
"""Query around a ball with given radius.
Input:
@@ -137,9 +137,9 @@ def sample_and_group(
npoint: int,
radius: float,
nsample: int,
xyz: torch.Tensor,
points: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
xyz: Tensor,
points: Tensor,
) -> tuple[Tensor, Tensor]:
"""Samples and groups.
Input:
@@ -170,9 +170,7 @@ def sample_and_group(
return new_xyz, new_points


def sample_and_group_all(
xyz: torch.Tensor, points: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
def sample_and_group_all(xyz: Tensor, points: Tensor) -> tuple[Tensor, Tensor]:
"""Sample and groups all.
Input:
@@ -243,7 +241,7 @@ def __init__(
self.group_all = group_all

def __call__(
self, coordinates: torch.Tensor, features: torch.Tensor
self, coordinates: Tensor, features: Tensor
) -> PointNetSetAbstractionOut:
"""Call function.
@@ -259,7 +257,7 @@ def __call__(
return self._call_impl(coordinates, features)

def forward(
self, xyz: torch.Tensor, points: torch.Tensor
self, xyz: Tensor, points: Tensor
) -> PointNetSetAbstractionOut:
"""Pointnet++ set abstraction layer forward.
@@ -327,11 +325,11 @@ def __init__(

def __call__(
self,
xyz1: torch.Tensor,
xyz2: torch.Tensor,
points1: torch.Tensor | None,
points2: torch.Tensor,
) -> torch.Tensor:
xyz1: Tensor,
xyz2: Tensor,
points1: Tensor | None,
points2: Tensor,
) -> Tensor:
"""Call function.
Input:
@@ -347,11 +345,11 @@ def __call__(

def forward(
self,
xyz1: torch.Tensor,
xyz2: torch.Tensor,
points1: torch.Tensor | None,
points2: torch.Tensor,
) -> torch.Tensor:
xyz1: Tensor,
xyz2: Tensor,
points1: Tensor | None,
points2: Tensor,
) -> Tensor:
"""Forward Implementation.
Input:
@@ -377,7 +375,7 @@ def forward(
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]

dist_recip: Tensor = 1.0 / (dists + 1e-8) # type: ignore
dist_recip: Tensor = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(
@@ -387,9 +385,7 @@ def forward(

if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat(
[points1, interpolated_points], dim=-1 # type: ignore
)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points

@@ -403,7 +399,7 @@ def forward(
class PointNet2SegmentationOut(NamedTuple):
"""Prediction for the pointnet++ semantic segmentation network."""

class_logits: torch.Tensor
class_logits: Tensor


class PointNet2Segmentation(nn.Module): # TODO, probably move to module?
@@ -445,7 +441,7 @@ def __init__(self, num_classes: int, in_channels: int = 3):
self.conv2 = nn.Conv1d(128, num_classes, 1)
self.in_channels = in_channels

def __call__(self, xyz: torch.Tensor) -> PointNet2SegmentationOut:
def __call__(self, xyz: Tensor) -> PointNet2SegmentationOut:
"""Call implementation.
Args:
@@ -456,7 +452,7 @@ def __call__(self, xyz: torch.Tensor) -> PointNet2SegmentationOut:
"""
return self._call_impl(xyz)

def forward(self, xyz: torch.Tensor) -> PointNet2SegmentationOut:
def forward(self, xyz: Tensor) -> PointNet2SegmentationOut:
"""Predicts the semantic class logits for each point.
Args:
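Most of this file only swaps `torch.Tensor` annotations for the imported `Tensor` alias. The one numeric expression touched, in `square_distance`, expands the squared Euclidean distance batch-wise; a small self-contained check of that identity (illustrative only, with made-up shapes):

```python
import torch

src = torch.randn(2, 5, 3)  # [B, N, C]
dst = torch.randn(2, 7, 3)  # [B, M, C]

# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>, evaluated batch-wise.
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(2, 5, 1)
dist += torch.sum(dst**2, -1).view(2, 1, 7)

# torch.cdist gives the plain Euclidean distance, so square it for comparison.
print(torch.allclose(dist, torch.cdist(src, dst) ** 2, atol=1e-5))  # True
```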
2 changes: 1 addition & 1 deletion vis4d/op/box/encoder/qd_3dt.py
@@ -82,7 +82,7 @@ def __call__(
2 * np.pi / self.num_rotation_bins,
device=alpha.device,
)
bin_centers += np.pi / self.num_rotation_bins # type: ignore
bin_centers += np.pi / self.num_rotation_bins
for i in range(alpha.shape[0]):
overlap_value = (
np.pi * 2 / self.num_rotation_bins * self.bin_overlap
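The line losing its ignore shifts every rotation-bin center by half a bin width so the centers sit in the middle of their bins. An illustrative evaluation with an assumed bin count:

```python
import numpy as np
import torch

num_rotation_bins = 4  # assumed value for illustration
bin_centers = torch.arange(0, 2 * np.pi, 2 * np.pi / num_rotation_bins)
bin_centers += np.pi / num_rotation_bins  # shift by half a bin width
print(bin_centers)  # tensor([0.7854, 2.3562, 3.9270, 5.4978])
```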
4 changes: 2 additions & 2 deletions vis4d/op/box/matchers/sim_ota.py
@@ -116,7 +116,7 @@ def forward( # pylint: disable=arguments-differ # type: ignore[override]

valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
# disable AMP autocast and calculate BCE with FP32 to avoid overflow
with torch.cuda.amp.autocast(enabled=False): # type: ignore[attr-defined] # pylint: disable=line-too-long
with torch.cuda.amp.autocast(enabled=False):
cls_cost = (
F.binary_cross_entropy(
valid_pred_scores.to(dtype=torch.float32),
@@ -216,7 +216,7 @@ def dynamic_k_matching(
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[:, gt_idx],
k=dynamic_ks[gt_idx].item(),
k=dynamic_ks[gt_idx].item(), # type: ignore
largest=False,
)
matching_matrix[:, gt_idx][pos_idx] = 1
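The new ignore is needed because `torch.topk` expects an `int` for `k`, while `Tensor.item()` is typed as returning a plain Python number. A sketch of an alternative that casts explicitly instead (hypothetical standalone snippet, not the matcher's code):

```python
import torch

cost = torch.rand(10, 3)               # candidate-to-ground-truth cost matrix
dynamic_ks = torch.tensor([2, 3, 1])   # per-ground-truth k values

for gt_idx in range(cost.shape[1]):
    _, pos_idx = torch.topk(
        cost[:, gt_idx], k=int(dynamic_ks[gt_idx].item()), largest=False
    )
    print(gt_idx, pos_idx.tolist())
```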
2 changes: 1 addition & 1 deletion vis4d/op/box/poolers/utils.py
@@ -35,7 +35,7 @@ def assign_boxes_to_levels(
)
# Eqn.(1) in FPN paper
level_assignments = torch.floor(
canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8) # type: ignore # pylint: disable=line-too-long
canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
)
# clamp level to (min, max), in case the box size is too large or too small
# for the available feature maps
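The expression losing its ignore is Eqn. (1) of the FPN paper, level = floor(k0 + log2(sqrt(area) / canonical_box_size)), with a small epsilon for numerical safety. An illustrative evaluation with assumed values:

```python
import torch

canonical_level = 4         # k0: the level the canonical box maps to
canonical_box_size = 224.0  # sqrt(area) of the canonical box
box_sizes = torch.tensor([112.0, 224.0, 448.0])  # sqrt of box areas

levels = torch.floor(
    canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
)
print(levels)  # tensor([3., 4., 5.])
```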
6 changes: 3 additions & 3 deletions vis4d/op/box/samplers/combined.py
@@ -136,9 +136,9 @@ def forward(self, matching: MatchResult) -> SamplingResult:
"""Sample boxes according to strategies defined in cfg."""
pos_sample_size = int(self.batch_size * self.positive_fraction)

positive_mask: Tensor = ( # type:ignore
matching.assigned_labels != -1
) & (matching.assigned_labels != self.bg_label)
positive_mask: Tensor = (matching.assigned_labels != -1) & (
matching.assigned_labels != self.bg_label
)
negative_mask = torch.eq(matching.assigned_labels, self.bg_label)

positive = positive_mask.nonzero()[:, 0]
2 changes: 1 addition & 1 deletion vis4d/op/box/samplers/pseudo.py
@@ -30,6 +30,6 @@ def _sample_labels(
labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Randomly sample indices from given labels."""
positive = ((labels != -1) & (labels != 0)).nonzero()[:, 0] # type: ignore # pylint: disable=line-too-long
positive = ((labels != -1) & (labels != 0)).nonzero()[:, 0]
negative = torch.eq(labels, 0).nonzero()[:, 0]
return positive, negative
2 changes: 1 addition & 1 deletion vis4d/op/box/samplers/random.py
@@ -40,7 +40,7 @@ def _sample_labels(
self, labels: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Randomly sample indices from given labels."""
positive = ((labels != -1) & (labels != self.bg_label)).nonzero()[:, 0] # type: ignore # pylint: disable=line-too-long
positive = ((labels != -1) & (labels != self.bg_label)).nonzero()[:, 0]
negative = torch.eq(labels, self.bg_label).nonzero()[:, 0]

num_pos = int(self.batch_size * self.positive_fraction)
4 changes: 3 additions & 1 deletion vis4d/op/detect/retinanet.py
@@ -312,7 +312,9 @@ def forward(
# since feature map sizes of all images are the same, we only compute
# anchors for one time
device = cls_outs[0].device
featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
featmap_sizes: list[tuple[int, int]] = [
featmap.size()[-2:] for featmap in cls_outs # type: ignore
]
assert len(featmap_sizes) == self.anchor_generator.num_levels
anchor_grids = self.anchor_generator.grid_priors(
featmap_sizes, device=device
4 changes: 3 additions & 1 deletion vis4d/op/detect/rpn.py
@@ -302,7 +302,9 @@ def forward(
# since feature map sizes of all images are the same, we only compute
# anchors for one time
device = class_outs[0].device
featmap_sizes = [featmap.size()[-2:] for featmap in class_outs]
featmap_sizes: list[tuple[int, int]] = [
featmap.size()[-2:] for featmap in class_outs # type: ignore
]
assert len(featmap_sizes) == self.anchor_generator.num_levels
anchor_grids = self.anchor_generator.grid_priors(
featmap_sizes, device=device
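In both retinanet.py and rpn.py the feature-map sizes are now annotated as `list[tuple[int, int]]`, with an ignore because `featmap.size()[-2:]` is a `torch.Size`, typed as `tuple[int, ...]`. A sketch of an ignore-free alternative that builds explicit 2-tuples (hypothetical snippet, not the project's code):

```python
import torch

cls_outs = [torch.zeros(2, 9, 64, 80), torch.zeros(2, 9, 32, 40)]  # dummy maps

featmap_sizes: list[tuple[int, int]] = [
    (int(featmap.size(-2)), int(featmap.size(-1))) for featmap in cls_outs
]
print(featmap_sizes)  # [(64, 80), (32, 40)]
```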
4 changes: 2 additions & 2 deletions vis4d/op/detect/yolox.py
@@ -246,7 +246,7 @@ def bboxes_nms(
max_scores, labels = torch.max(cls_scores, 1)
valid_mask = objectness * max_scores >= score_thr
valid_idxs = valid_mask.nonzero()[:, 0]
num_topk = min(nms_pre, valid_mask.sum())
num_topk = min(nms_pre, valid_mask.sum()) # type: ignore

scores, idxs = (max_scores[valid_mask] * objectness[valid_mask]).sort(
descending=True
@@ -288,7 +288,7 @@ def preprocess_outputs(
num_imgs = len(images_hw)
num_classes = cls_outs[0].shape[1]
featmap_sizes: list[tuple[int, int]] = [
tuple(featmap.size()[-2:]) for featmap in cls_outs
tuple(featmap.size()[-2:]) for featmap in cls_outs # type: ignore
]
assert len(featmap_sizes) == point_generator.num_levels
mlvl_points = point_generator.grid_priors(
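The first yolox.py change silences mypy because `min()` over an `int` and a 0-d `Tensor` does not leave the result typed as a plain `int`. A sketch of an alternative that casts the count instead (illustrative only):

```python
import torch

nms_pre = 1000
valid_mask = torch.tensor([True, False, True, True])
num_topk = min(nms_pre, int(valid_mask.sum().item()))  # plain int, no ignore
print(num_topk)  # 3
```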
6 changes: 3 additions & 3 deletions vis4d/op/detect3d/bevformer/encoder.py
@@ -303,9 +303,9 @@ def forward(
batch_size, len_bev, num_bev_level, _ = ref_2d.shape
if prev_bev is not None:
prev_bev = prev_bev.permute(1, 0, 2)
prev_bev = torch.stack(
[prev_bev, bev_query], 1 # type: ignore
).reshape(batch_size * 2, len_bev, -1)
prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(
batch_size * 2, len_bev, -1
)
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
batch_size * 2, len_bev, num_bev_level, 2
)
10 changes: 5 additions & 5 deletions vis4d/op/detect3d/bevformer/temporal_self_attention.py
@@ -156,18 +156,18 @@ def forward(
value = value.permute(1, 0, 2)

bs, num_query, embed_dims = query.shape
_, num_value, _ = value.shape # type: ignore
_, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
assert self.num_bev_queue == 2

query = torch.cat([value[:bs], query], -1) # type: ignore
query = torch.cat([value[:bs], query], -1)
value = self.value_proj(value)
assert isinstance(value, Tensor)

if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)

value = value.reshape( # type: ignore
value = value.reshape(
bs * self.num_bev_queue, num_value, self.num_heads, -1
)

@@ -246,7 +246,7 @@ def forward(
f" 2 or 4, but get {reference_points.shape[-1]} instead."
)

if torch.cuda.is_available() and value.is_cuda: # type: ignore
if torch.cuda.is_available() and value.is_cuda:
output = MSDeformAttentionFunction.apply(
value,
spatial_shapes,
@@ -257,7 +257,7 @@ def forward(
)
else:
output = ms_deformable_attention_cpu(
value, # type: ignore
value,
spatial_shapes,
sampling_locations,
attention_weights,