feat: Add VIS flag and update transformer related code.
RoyYang0714 committed Jul 26, 2024
1 parent 91b2fca commit 96871a5
Showing 11 changed files with 129 additions and 102 deletions.
4 changes: 3 additions & 1 deletion tests/vis/image/bounding_box_visualizer_test.py
@@ -32,8 +32,10 @@ def setUp(self) -> None:
self.scores: list[NDArrayF64] = testcase_gt["scores"]
self.tracks = [np.arange(len(b)) for b in self.boxes]

cat_mapping = {v: k for k, v in COCO_COLOR_MAPPING.items()}

self.vis = BoundingBoxVisualizer(
n_colors=20, class_id_mapping=COCO_COLOR_MAPPING, vis_freq=1
n_colors=20, cat_mapping=cat_mapping, vis_freq=1
)

def tearDown(self) -> None:
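For context, the test now inverts COCO_COLOR_MAPPING before constructing the visualizer. A minimal sketch of that inversion, using a toy dictionary since the real mapping's keys and values are not shown here (the toy names and ids are assumptions):

```python
# Toy stand-in for COCO_COLOR_MAPPING; the real keys/values are assumptions.
color_mapping = {"person": 0, "bicycle": 1, "car": 2}

# {v: k ...} swaps keys and values, e.g. turning name -> id into id -> name.
# Values must be unique, or later entries silently overwrite earlier ones.
cat_mapping = {v: k for k, v in color_mapping.items()}
assert cat_mapping == {0: "person", 1: "bicycle", 2: "car"}
```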
2 changes: 1 addition & 1 deletion vis4d/common/util.py
@@ -73,7 +73,7 @@ def set_tf32(use_tf32: bool, precision: str) -> None: # pragma: no cover

def init_random_seed() -> int:
"""Initialize random seed for the experiment."""
return np.random.randint(2**31)
return int(np.random.randint(2**31))


def set_random_seed(seed: int, deterministic: bool = False) -> None:
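For context, `np.random.randint` returns a NumPy scalar rather than a Python `int`, which is why the cast above matters once the seed is serialized or checked against strict type hints. A small sketch of the difference, relying only on standard NumPy and json behavior:

```python
import json

import numpy as np

raw_seed = np.random.randint(2**31)  # NumPy scalar, e.g. np.int64
seed = int(raw_seed)                 # plain Python int

print(type(raw_seed).__name__)  # e.g. "int64" (platform-dependent)
print(type(seed).__name__)      # "int"

# A NumPy scalar breaks JSON serialization of a config; a Python int does not.
json.dumps({"seed": seed})  # fine
# json.dumps({"seed": raw_seed})  # would raise TypeError
```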
16 changes: 15 additions & 1 deletion vis4d/engine/experiment.py
@@ -33,6 +33,7 @@
from vis4d.common.util import init_random_seed, set_random_seed, set_tf32
from vis4d.config import instantiate_classes
from vis4d.config.typing import ExperimentConfig
from vis4d.engine.callbacks import VisualizerCallback

from .optim import set_up_optimizers
from .parser import pprints_config
@@ -87,6 +88,7 @@ def run_experiment(
use_slurm: bool = False,
ckpt_path: str | None = None,
resume: bool = False,
vis: bool = False,
) -> None:
"""Entry point for running a single experiment.
@@ -99,6 +101,7 @@
required environment variables for slurm.
ckpt_path (str | None): Path to a checkpoint to load.
resume (bool): If set, resume training from the checkpoint.
vis (bool): If set, enable visualizer callback.
Raises:
ValueError: If `mode` is not `fit` or `test`.
@@ -141,7 +144,18 @@
)

# Callbacks
callbacks = [instantiate_classes(cb) for cb in config.callbacks]
callbacks = []
for cb in config.callbacks:
callback = instantiate_classes(cb)

if not vis and isinstance(callback, VisualizerCallback):
rank_zero_info(
"VisualizerCallback is not used. "
"Please set --vis=True to use it."
)
continue

callbacks.append(callback)

# Setup DDP & seed
seed = init_random_seed() if config.seed == -1 else config.seed
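The loop above replaces the previous one-line list comprehension so that `VisualizerCallback` instances are skipped unless the new `vis` flag is set. A condensed, self-contained sketch of the same gating pattern; the toy `Callback` classes below are stand-ins, not vis4d's real ones:

```python
class Callback:
    """Toy base class standing in for vis4d's callbacks."""


class VisualizerCallback(Callback):
    """Toy stand-in for vis4d.engine.callbacks.VisualizerCallback."""


def filter_callbacks(candidates: list[Callback], vis: bool) -> list[Callback]:
    kept: list[Callback] = []
    for callback in candidates:
        # Drop visualizer callbacks unless visualization was requested.
        if not vis and isinstance(callback, VisualizerCallback):
            print("VisualizerCallback is not used. Please set --vis=True.")
            continue
        kept.append(callback)
    return kept


assert len(filter_callbacks([Callback(), VisualizerCallback()], vis=False)) == 1
assert len(filter_callbacks([Callback(), VisualizerCallback()], vis=True)) == 2
```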
6 changes: 6 additions & 0 deletions vis4d/engine/flag.py
@@ -21,6 +21,11 @@
_SLURM = flags.DEFINE_bool(
"slurm", default=False, help="If set, setup slurm running jobs."
)
_VIS = flags.DEFINE_bool(
"vis",
default=False,
help="If set, running visualization using visualizer callback.",
)


__all__ = [
@@ -31,4 +36,5 @@
"_SHOW_CONFIG",
"_SWEEP",
"_SLURM",
"_VIS",
]
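The new `_VIS` flag follows the same absl-py pattern as the flags above it. A self-contained sketch of defining and reading such a boolean flag with the standard absl API (the script and message below are illustrative):

```python
from absl import app, flags

_VIS = flags.DEFINE_bool(
    "vis", default=False, help="If set, run visualization."
)


def main(argv):
    del argv  # unused
    print("visualization enabled:", _VIS.value)


if __name__ == "__main__":
    # Invoked as e.g. `python demo.py --vis` (or `--vis=True`, `--novis`).
    app.run(main)
```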
12 changes: 11 additions & 1 deletion vis4d/engine/run.py
@@ -11,7 +11,16 @@
from vis4d.config.typing import ExperimentConfig

from .experiment import run_experiment
from .flag import _CKPT, _CONFIG, _GPUS, _RESUME, _SHOW_CONFIG, _SLURM, _SWEEP
from .flag import (
_CKPT,
_CONFIG,
_GPUS,
_RESUME,
_SHOW_CONFIG,
_SLURM,
_SWEEP,
_VIS,
)


def main(argv: ArgsType) -> None:
@@ -68,6 +77,7 @@ def main(argv: ArgsType) -> None:
_SLURM.value,
_CKPT.value,
_RESUME.value,
_VIS.value,
)


7 changes: 5 additions & 2 deletions vis4d/op/layer/attention.py
@@ -120,6 +120,7 @@ def __init__(
super().__init__()
self.batch_first = batch_first
self.embed_dims = embed_dims
self.num_heads = num_heads

self.attn = nn.MultiheadAttention(
embed_dims, num_heads, dropout=attn_drop, **kwargs
@@ -193,8 +194,10 @@ def forward(
key_pos = query_pos
else:
rank_zero_warn(
"position encoding of key is"
+ f"missing in {self.__class__.__name__}."
f"Position encoding of key in {self.__class__.__name__}"
+ "is missing, and positional encodeing of query has "
+ "has different shape and cannot be usde for key. "
+ "It it is not desired, please provide key_pos."
)

if query_pos is not None:
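The reworked warning above fires when `key_pos` is missing and `query_pos` cannot stand in for it. A simplified sketch of that defaulting logic around `nn.MultiheadAttention`; this is a reduced reading of the module's forward, not its full implementation:

```python
import warnings

import torch
from torch import Tensor, nn


def attend_with_pos(
    attn: nn.MultiheadAttention,
    query: Tensor,
    key: Tensor,
    query_pos: Tensor | None = None,
    key_pos: Tensor | None = None,
) -> Tensor:
    """Sketch: add positional encodings before multi-head attention."""
    if key_pos is None and query_pos is not None:
        if query_pos.shape == key.shape:
            key_pos = query_pos  # reuse the query's encoding for the key
        else:
            warnings.warn(
                "Position encoding of key is missing and query_pos has a "
                "different shape, so it cannot be used for the key."
            )
    if query_pos is not None:
        query = query + query_pos
    if key_pos is not None:
        key = key + key_pos
    # Value stays the raw key, as in DETR-style attention blocks.
    return attn(query, key, key)[0]
```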
24 changes: 23 additions & 1 deletion vis4d/op/layer/ms_deform_attn.py
@@ -223,12 +223,15 @@ def __init__(
is_power_of_2(d_model // n_heads)

self.d_model = d_model
self.embed_dims = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.im2col_step = im2col_step

# Aligned Attributes to MHA
self.embed_dims = d_model
self.num_heads = n_heads

self.sampling_offsets = nn.Linear(
d_model, n_heads * n_levels * n_points * 2
)
@@ -359,3 +362,22 @@ def forward(
output = self.output_proj(output)

return output

def __call__(
self,
query: Tensor,
reference_points: Tensor,
input_flatten: Tensor,
input_spatial_shapes: Tensor,
input_level_start_index: Tensor,
input_padding_mask: Tensor | None = None,
) -> Tensor:
"""Type definition for call implementation."""
return self._call_impl(
query,
reference_points,
input_flatten,
input_spatial_shapes,
input_level_start_index,
input_padding_mask,
)
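The `__call__` override above routes through `nn.Module._call_impl`, which is the callable that `nn.Module.__call__` points at, so forward hooks still run while static type checkers get a precise signature instead of `Callable[..., Any]`. A minimal sketch of the same pattern on a toy module:

```python
import torch
from torch import Tensor, nn


class TypedProjection(nn.Module):
    """Toy module illustrating the typed __call__ pattern."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(x)

    def __call__(self, x: Tensor) -> Tensor:
        """Type definition for call implementation."""
        # _call_impl is what nn.Module.__call__ dispatches to, so forward
        # pre/post hooks still fire; only the static signature changes.
        return self._call_impl(x)


layer = TypedProjection(4)
out = layer(torch.randn(2, 4))  # checkers now see (Tensor) -> Tensor
```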
120 changes: 34 additions & 86 deletions vis4d/op/layer/positional_encoding.py
@@ -3,6 +3,8 @@
Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
"""

from __future__ import annotations

import math

import torch
@@ -59,24 +61,45 @@ def __init__(
self.eps = eps
self.offset = offset

def forward(self, mask: Tensor) -> Tensor:
def forward(
self, mask: Tensor | None, inputs: Tensor | None = None
) -> Tensor:
"""Forward function for `SinePositionalEncoding`.
Args:
mask (Tensor): ByteTensor mask. Non-zero values representing
mask (Tensor | None): ByteTensor mask. Non-zero values representing
ignored positions, while zero values mean valid positions
for this image. Shape [bs, h, w].
for this image. Shape [bs, h, w]. If None, it means a single
image or a batch of images with no padding.
inputs (Tensor | None): The input tensor. If mask is None, this
input tensor is required to get the shape of the input image.
Returns:
pos (Tensor): Returned position embedding with shape
[bs, num_feats*2, h, w].
"""
# For convenience of exporting to ONNX, it's required to convert
# `masks` from bool to int.
mask = mask.to(torch.int)
not_mask = 1 - mask # logical_not
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if mask is not None:
# For convenience of exporting to ONNX, it's required to convert
# `masks` from bool to int.
mask = mask.to(torch.int)
b, h, w = mask.size()
device = mask.device
not_mask = 1 - mask # logical_not
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
else:
# single image or batch image with no padding
assert isinstance(inputs, Tensor)
b, _, h, w = inputs.shape
device = inputs.device
x_embed = torch.arange(
1, w + 1, dtype=torch.float32, device=device
)
x_embed = x_embed.view(1, 1, -1).repeat(b, h, 1)
y_embed = torch.arange(
1, h + 1, dtype=torch.float32, device=device
)
y_embed = y_embed.view(1, -1, 1).repeat(b, 1, w)
if self.normalize:
y_embed = (
(y_embed + self.offset)
@@ -89,13 +112,13 @@
* self.scale
)
dim_t = torch.arange(
self.num_feats, dtype=torch.float32, device=mask.device
self.num_feats, dtype=torch.float32, device=device
)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
# use `view` instead of `flatten` for dynamically exporting to ONNX
b, h, w = mask.size()

pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).view(b, h, w, -1)
@@ -167,78 +190,3 @@ def forward(self, mask: Tensor) -> Tensor:
.repeat(mask.shape[0], 1, 1, 1)
)
return pos


class SinePositionalEncoding3D(SinePositionalEncoding):
"""3D Position encoding with sine and cosine functions."""

def forward(self, mask: Tensor) -> Tensor:
"""Forward function for `SinePositionalEncoding3D`.
Args:
mask (Tensor): ByteTensor mask. Non-zero values representing
ignored positions, while zero values means valid positions
for this image. Shape [bs, t, h, w].
Returns:
pos (Tensor): Returned position embedding with shape
[bs, num_feats*2, h, w].
"""
assert mask.dim() == 4, (
f"{mask.shape} should be a 4-dimensional Tensor,"
f" got {mask.dim()}-dimensional Tensor instead "
)
# For convenience of exporting to ONNX, it's required to convert
# `masks` from bool to int.
mask = mask.to(torch.int)
not_mask = 1 - mask # logical_not
z_embed = not_mask.cumsum(1, dtype=torch.float32)
y_embed = not_mask.cumsum(2, dtype=torch.float32)
x_embed = not_mask.cumsum(3, dtype=torch.float32)
if self.normalize:
z_embed = (
(z_embed + self.offset)
/ (z_embed[:, -1:, :, :] + self.eps)
* self.scale
)
y_embed = (
(y_embed + self.offset)
/ (y_embed[:, :, -1:, :] + self.eps)
* self.scale
)
x_embed = (
(x_embed + self.offset)
/ (x_embed[:, :, :, -1:] + self.eps)
* self.scale
)
dim_t = torch.arange(
self.num_feats, dtype=torch.float32, device=mask.device
)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)

dim_t_z = torch.arange(
(self.num_feats * 2), dtype=torch.float32, device=mask.device
)
dim_t_z = self.temperature ** (
2 * (dim_t_z // 2) / (self.num_feats * 2)
)

pos_x = x_embed[:, :, :, :, None] / dim_t
pos_y = y_embed[:, :, :, :, None] / dim_t
pos_z = z_embed[:, :, :, :, None] / dim_t_z
# use `view` instead of `flatten` for dynamically exporting to ONNX
b, t, h, w = mask.size()
pos_x = torch.stack(
(pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()),
dim=5,
).view(b, t, h, w, -1)
pos_y = torch.stack(
(pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()),
dim=5,
).view(b, t, h, w, -1)
pos_z = torch.stack(
(pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()),
dim=5,
).view(b, t, h, w, -1)
pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)
return pos
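The new mask-free branch above builds its coordinate grids with `arange`, which matches what the cumulative sum over an all-valid mask would produce (positions 1..h and 1..w). A standalone sketch of that branch with `normalize=True` behavior; the function name and defaults are illustrative, not vis4d's API:

```python
import math

import torch
from torch import Tensor


def sine_pos_enc(
    inputs: Tensor, num_feats: int = 128, temperature: float = 10000.0
) -> Tensor:
    """Sketch of the mask-free branch for inputs of shape [bs, c, h, w]."""
    b, _, h, w = inputs.shape
    device = inputs.device
    # Positions 1..w and 1..h, matching cumsum over an all-valid mask.
    x_embed = torch.arange(1, w + 1, dtype=torch.float32, device=device)
    x_embed = x_embed.view(1, 1, -1).repeat(b, h, 1)
    y_embed = torch.arange(1, h + 1, dtype=torch.float32, device=device)
    y_embed = y_embed.view(1, -1, 1).repeat(b, 1, w)
    # Normalize positions to [0, 2*pi], as in the normalize=True path.
    scale = 2 * math.pi
    y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale
    # Geometric frequency ladder shared by each sin/cos pair.
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=device)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # Interleave sin on even and cos on odd channels.
    pos_x = torch.stack(
        (pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4
    ).view(b, h, w, -1)
    pos_y = torch.stack(
        (pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4
    ).view(b, h, w, -1)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


pos = sine_pos_enc(torch.randn(2, 256, 8, 8))
assert pos.shape == (2, 256, 8, 8)  # 2 * num_feats channels
```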
2 changes: 2 additions & 0 deletions vis4d/op/layer/transformer.py
@@ -212,6 +212,8 @@ def __init__(
LayerScale. Default: 0.0
"""
super().__init__()
self.embed_dims = embed_dims

layers: list[nn.Module] = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
26 changes: 22 additions & 4 deletions vis4d/pl/run.py
@@ -16,8 +16,15 @@
from vis4d.common.util import set_tf32
from vis4d.config import instantiate_classes
from vis4d.config.typing import ExperimentConfig
from vis4d.engine.callbacks import CheckpointCallback
from vis4d.engine.flag import _CKPT, _CONFIG, _GPUS, _RESUME, _SHOW_CONFIG
from vis4d.engine.callbacks import CheckpointCallback, VisualizerCallback
from vis4d.engine.flag import (
_CKPT,
_CONFIG,
_GPUS,
_RESUME,
_SHOW_CONFIG,
_VIS,
)
from vis4d.engine.parser import pprints_config
from vis4d.pl.callbacks import CallbackWrapper, LRSchedulerCallback
from vis4d.pl.data_module import DataModule
@@ -83,12 +90,23 @@ def main(argv: ArgsType) -> None:
test_data_connector = None

# Callbacks
vis = _VIS.value

callbacks: list[Callback] = []
for cb in config.callbacks:
callback = instantiate_classes(cb)
# Skip checkpoint callback to use PL ModelCheckpoint
if not isinstance(callback, CheckpointCallback):
callbacks.append(CallbackWrapper(callback))
if isinstance(callback, CheckpointCallback):
continue

if not vis and isinstance(callback, VisualizerCallback):
rank_zero_info(
"VisualizerCallback is not used. "
"Please set --vis=True to use it."
)
continue

callbacks.append(CallbackWrapper(callback))

if "pl_callbacks" in config:
pl_callbacks = [instantiate_classes(cb) for cb in config.pl_callbacks]