diff --git a/pyproject.toml b/pyproject.toml index cf64f4eeb..a0d7213a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,3 +140,4 @@ include = ["vis4d*"] [project.scripts] vis4d = "vis4d.engine.run:entrypoint" vis4d-pl = "vis4d.pl.run:entrypoint" +vis4d-zoo = "vis4d.zoo.run:entrypoint" diff --git a/tests/common/dict_test.py b/tests/common/dict_test.py index e4f308778..2363e4501 100644 --- a/tests/common/dict_test.py +++ b/tests/common/dict_test.py @@ -2,12 +2,20 @@ import unittest -from vis4d.common.dict import get_dict_nested, set_dict_nested +from vis4d.common.dict import flatten_dict, get_dict_nested, set_dict_nested class TestDictUtils(unittest.TestCase): """Test cases for array conversion ops.""" + def test_flatten_dict(self) -> None: + """Tests the flatten_dict function.""" + d = {"a": {"b": {"c": 10}}} + self.assertEqual(flatten_dict(d, "."), ["a.b.c"]) + + d = {"a": {"b": {"c": 10, "d": 20}}} + self.assertEqual(flatten_dict(d, "/"), ["a/b/c", "a/b/d"]) + def test_set_dict_nested(self) -> None: """Tests the set_dict_nested function.""" d = {} # type:ignore diff --git a/tests/config/registry_test.py b/tests/config/registry_test.py new file mode 100644 index 000000000..83ce62046 --- /dev/null +++ b/tests/config/registry_test.py @@ -0,0 +1,71 @@ +"""Test config registry.""" +from __future__ import annotations + +import unittest + +import pytest + +from tests.util import get_test_data +from vis4d.config.util.registry import get_config_by_name, register_config + + +class TestRegistry(unittest.TestCase): + """Test the config registry.""" + + def test_yaml(self) -> None: + """Test reading a yaml config file.""" + file = get_test_data( + "config_test/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.yaml" + ) + + # Config can be resolved + config = get_config_by_name(file) + self.assertTrue(config is not None) + + # Config does not exist + with pytest.raises(ValueError) as err: + config = get_config_by_name(file.replace("r50", "r91")) + self.assertTrue("Could not find" in str(err.value)) + + def test_py(self) -> None: + """Test reading a py config file from the model zoo.""" + file = "/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py" + cfg = get_config_by_name(file) + self.assertTrue(cfg is not None) + + # Only by file name + file = "faster_rcnn_r50_1x_bdd100k.py" + cfg = get_config_by_name(file) + self.assertTrue(cfg is not None) + + # Check did you mean message + file = "faster_rcnn_r90_1x_bdd100k" + with pytest.raises(ValueError) as err: + cfg = get_config_by_name(file) + self.assertTrue("faster_rcnn_r50_1x_bdd100k" in str(err.value)) + + def test_zoo(self) -> None: + """Test reading a registered config from the zoo.""" + config = get_config_by_name("faster_rcnn_r50_1x_bdd100k") + self.assertTrue(config is not None) + + # Full Qualified Name + config = get_config_by_name("bdd100k/faster_rcnn_r50_1x_bdd100k") + self.assertTrue(config is not None) + + # Check did you mean message + with pytest.raises(ValueError) as err: + config = get_config_by_name("faster_rcnn_r90_1x_bdd100k") + self.assertTrue("faster_rcnn_r50_1x_bdd100k" in str(err.value)) + + def test_decorator(self) -> None: + """Test registering a config.""" + + @register_config("cat", "test") # type: ignore + def get_config() -> dict[str, str]: + """Test config.""" + return {"test": "test"} + + config = get_config_by_name("cat/test") + self.assertTrue(config is not None) + self.assertEqual(config["test"], "test") diff --git a/vis4d/common/dict.py b/vis4d/common/dict.py index 6f9dee8da..9f0098cc2 100644 --- a/vis4d/common/dict.py +++ b/vis4d/common/dict.py @@ -6,6 +6,35 @@ from vis4d.common import DictStrAny +def flatten_dict(dictionary: DictStrAny, seperator: str) -> list[str]: + """Flatten a nested dictionary. + + Args: + dictionary (DictStrAny): The dictionary to flatten. + seperator (str): The seperator to use between keys. + + Returns: + List[str]: A list of flattened keys. + + Examples: + >>> d = {'a': {'b': {'c': 10}}} + >>> flatten_dict(d, '.') + ['a.b.c'] + """ + flattened = [] + for key, value in dictionary.items(): + if isinstance(value, dict): + flattened.extend( + [ + f"{key}{seperator}{subkey}" + for subkey in flatten_dict(value, seperator) + ] + ) + else: + flattened.append(key) + return flattened + + def get_dict_nested( # type: ignore dictionary: DictStrAny, keys: list[str], allow_missing: bool = False ) -> Any: diff --git a/vis4d/common/imports.py b/vis4d/common/imports.py index be4e82b5a..0e74a70d3 100644 --- a/vis4d/common/imports.py +++ b/vis4d/common/imports.py @@ -37,6 +37,9 @@ def package_available(package_name: str) -> bool: OPEN3D_AVAILABLE = package_available("open3d") PLOTLY_AVAILABLE = package_available("plotly") +# vis4d cuda ops +VIS4D_CUDA_OPS_AVAILABLE = package_available("vis4d_cuda_ops") + # logging TENSORBOARD_AVAILABLE = package_available("tensorboardX") or package_available( "tensorboard" diff --git a/vis4d/common/util.py b/vis4d/common/util.py index 3be88ce13..2373c6291 100644 --- a/vis4d/common/util.py +++ b/vis4d/common/util.py @@ -1,5 +1,6 @@ """Utility functions for common usage.""" import random +from difflib import get_close_matches import numpy as np import torch @@ -8,6 +9,30 @@ from .logging import rank_zero_warn +def create_did_you_mean_msg(keys: list[str], query: str) -> str: + """Create a did you mean message. + + Args: + keys (list[str]): List of available keys. + query (str): Query. + + Returns: + str: Did you mean message. + + Examples: + >>> keys = ["foo", "bar", "baz"] + >>> query = "fo" + >>> print(create_did_you_mean_msg(keys, query)) + Did you mean: + foo + """ + msg = "" + if len(keys) > 0: + msg = "Did you mean:\n\t" + msg += "\n\t".join(get_close_matches(query, keys, cutoff=0.75)) + return msg + + def set_tf32(use_tf32: bool = False) -> None: # pragma: no cover """Set torch TF32. diff --git a/vis4d/config/util/registry.py b/vis4d/config/util/registry.py new file mode 100644 index 000000000..60b64348e --- /dev/null +++ b/vis4d/config/util/registry.py @@ -0,0 +1,262 @@ +"""Utility function for registering config files.""" +from __future__ import annotations + +import glob +import os +import pathlib +import warnings +from typing import Callable, Union + +import yaml +from ml_collections import ConfigDict +from ml_collections.config_flags.config_flags import _LoadConfigModule + +from vis4d.common.dict import flatten_dict, get_dict_nested +from vis4d.common.typing import ArgsType +from vis4d.common.util import create_did_you_mean_msg +from vis4d.config.config_dict import FieldConfigDict +from vis4d.zoo import AVAILABLE_MODELS + +MODEL_ZOO_FOLDER = str( + (pathlib.Path(os.path.dirname(__file__)) / ".." / ".." / "zoo").resolve() +) + +# Paths that are used to search for config files. +REGISTERED_CONFIG_PATHS = [MODEL_ZOO_FOLDER] + + +TFunc = Callable[[ArgsType], ArgsType] +TfuncConfDict = Union[Callable[[ArgsType], ConfigDict], type] + + +def register_config( + category: str, name: str +) -> Callable[[TfuncConfDict], None]: + """Register a config in the model zoo for the given name and category. + + The config will then be available via `get_config_by_name` utilities and + located in the AVAILABLE_MODELS dictionary located at + [category][name]. + + Args: + category: Category of the config. + name: Name of the config. + + Returns: + The decorator. + """ + + def decorator(fnc_or_clazz: TfuncConfDict) -> None: + """Decorator for registering a config. + + Args: + fnc_or_clazz: Function or class to register. If a function is + passed, it will be wrapped in a class and the class will be + registered. If a class is passed, it will be registered + directly. + """ + if callable(fnc_or_clazz): + # Directly annotated get_config function. Wrap it and register it. + class Wrapper: + """Wrapper class.""" + + def get_config( + self, *args: ArgsType, **kwargs: ArgsType + ) -> ConfigDict: + """Resolves the get_config function.""" + return fnc_or_clazz(*args, **kwargs) + + module = Wrapper() + else: + # Directly annotated class. Register it. + module = fnc_or_clazz + + # Register the config + if category not in AVAILABLE_MODELS: + AVAILABLE_MODELS[category] = {} + + assert isinstance(AVAILABLE_MODELS[category], dict) + + AVAILABLE_MODELS[category][name] = module + + return decorator + + +def _resolve_config_path(path: str) -> str: + """Resolve the path of a config file. + + Args: + path: Name or path of the config. + If the config is not found at this location, + the function will look for the config in the model zoo folder. + + Returns: + The resolved path of the config. + + Raises: + ValueError: If the config is not found. + """ + if os.path.exists(path): + return path + + # Check for duplicate paths. + found_paths: list[str] = [] + all_paths = [] + + for p in REGISTERED_CONFIG_PATHS: + paths = sorted( + glob.glob( + os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"), + recursive=True, + ) + ) + print( + paths, + "lookup", + os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"), + ) + for cfg_path in paths: + if cfg_path.endswith(path): + found_paths.append(cfg_path) + all_paths.extend(paths) + + if len(found_paths) > 1: + warnings.warn( + f"Found multiple paths for config {path}:" + f"{found_paths}. Will load the config from the first one!" + ) + elif len(found_paths) == 0: + hint = create_did_you_mean_msg( + [*all_paths, *[os.path.basename(p) for p in all_paths]], path + ) + raise ValueError( + f"Could not find config {path}. \n" + f"The file does not exists at the path {path} or " + f"in the dedicated locations at {REGISTERED_CONFIG_PATHS}. \n" + f"Please check the path or add the config to the model zoo. \n" + f"Current working directory: {os.getcwd()}\n {hint}" + ) + return found_paths[0] + + +def _load_yaml_config(name_or_path: str) -> FieldConfigDict: + """Loads a .yaml configuration file. + + Args: + name_or_path: Name or path of the config. + If the config is not found at this location, $ + the function will look for the config in the model zoo folder. + + Returns: + The config for the experiment. + """ + path = _resolve_config_path(name_or_path) + with open(path, "r", encoding="utf-8") as yaml_file: + return FieldConfigDict(yaml.load(yaml_file, Loader=yaml.UnsafeLoader)) + + +def _load_py_config( + name_or_path: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Loads a .py configuration file. + + Args: + name_or_path: Name or path of the config. + If the config is not found at this location, + the function will look for the config in the model zoo folder. + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Returns: + The config for the experiment. + """ + path = _resolve_config_path(name_or_path) + config_module = _LoadConfigModule(f"{os.path.basename(path)}_config", path) + print("args", args, *args) + cfg = getattr(config_module, method_name)(*args) + assert isinstance(cfg, ConfigDict) + return cfg + + +def _get_registered_configs( + config_name: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Get a model from the registered config locations. + + Args: + config_name: Name of the config. This can either be + the full path of the config relative to the registered locations + or the name of the config. + If the config matches multiple configs (e.g. if there are two + conflicting config a/cfg and b/cfg) or if it is not found, + a ValueError is raised. + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Raises: + ValueError: If the config is not found. + + Returns: + The Config. + """ + models = flatten_dict(AVAILABLE_MODELS, os.path.sep) + # check if there is an absolute match for the config + if config_name in models: + module = get_dict_nested( + AVAILABLE_MODELS, config_name.split(os.path.sep) + ) + return getattr(module, method_name)(*args) + # check if there is a partial match for the config + matches = {} + for model in models: + if model.endswith(config_name): + matches[model] = get_dict_nested( + AVAILABLE_MODELS, model.split(os.path.sep) + ) + + if len(matches) > 1: + raise ValueError( + f"Found multiple configs matching {config_name}:" + f"{matches.keys()}.\nPlease specify a unique config name." + ) + if len(matches) == 0: + msg = create_did_you_mean_msg( + [*models, *[os.path.basename(m) for m in models]], config_name + ) + raise ValueError(msg) + + module = list(matches.values())[0] + return getattr(module, method_name)(*args) + + +def get_config_by_name( + name_or_path: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Get a config by name or path. + + Args: + name_or_path: Name or path of the config. + If the path has a .yaml or .py extension, the function will + load the config from the file. + Otherwise, the function will try to resolve the config from the + registered config locations. You can specify a config by its full + registered path (e.g. "a/b/cfg") or by its name (e.g. "cfg"). + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Returns: + The config. + + Raises: + ValueError: If the config is not found. + """ + if name_or_path.endswith(".yaml"): + return _load_yaml_config(name_or_path) + if name_or_path.endswith(".py"): + return _load_py_config(name_or_path, *args, method_name=method_name) + return _get_registered_configs( + name_or_path, *args, method_name=method_name + ) diff --git a/vis4d/engine/parser.py b/vis4d/engine/parser.py index d567da1f9..586de230e 100644 --- a/vis4d/engine/parser.py +++ b/vis4d/engine/parser.py @@ -7,20 +7,19 @@ import traceback from typing import Any -import yaml from absl import flags from ml_collections import ConfigDict, FieldReference from ml_collections.config_flags.config_flags import ( _ConfigFlag, _ErrorConfig, - _LoadConfigModule, _LockConfig, ) from vis4d.config import copy_and_resolve_references +from vis4d.config.util.registry import get_config_by_name -class _ConfigFileParser(flags.ArgumentParser): # type: ignore +class ConfigFileParser(flags.ArgumentParser): # type: ignore """Parser for config files.""" def __init__( @@ -29,6 +28,15 @@ def __init__( lock_config: bool = True, method_name: str = "get_config", ) -> None: + """Initializes the parser. + + Args: + name (str): The name of the flag (e.g. config for --config flag) + lock_config (bool, optional): Whether or not to lock the config. + Defaults to True. + method_name (str, optional): Name of the method to call in the + config. Defaults to "get_config". + """ self.name = name self._lock_config = lock_config self.method_name = method_name @@ -58,11 +66,13 @@ def parse( # pylint: disable=arguments-renamed # This will be a 2 element list iff extra configuration args are # present. split_path = path.split(":", 1) + try: - config_module = _LoadConfigModule( - f"{self.name}_config", split_path[0] + config = get_config_by_name( + split_path[0], + *split_path[1:], + method_name=self.method_name, ) - config = getattr(config_module, self.method_name)(*split_path[1:]) if config is None: logging.warning( "%s:%s() returned None, did you forget a return " @@ -73,7 +83,7 @@ def parse( # pylint: disable=arguments-renamed except IOError as e: # Don't raise the error unless/until the config is # actually accessed. - config = _ErrorConfig(e) + return _ErrorConfig(e) # Third party flags library catches TypeError and ValueError # and rethrows, # removing useful information unless it is added here (b/63877430): @@ -89,48 +99,7 @@ def parse( # pylint: disable=arguments-renamed return config def flag_type(self) -> str: - return "config object" - - -class ConfigFileParser(_ConfigFileParser): - """Parser for config files. - - Note, this wraps internal functions of the ml_collections code and might - be fragile! - """ - - def parse(self, path: str) -> ConfigDict: - """Returns the config object for a given path. - - If a colon is present in `path`, everything to the right of the first - colon is passed to `get_config` as an argument. This allows the - structure of what is returned to be modified. - - Works with .py file that contain a get_config() function and .yaml. - - Args: - path (string): path pointing to the config file to execute. May also - contain a config_string argument, e.g. be of the form - "config.py:some_configuration" or "config.yaml". - Returns (ConfigDict): - ConfigDict located at 'path' - """ - if path.split(".")[-1] == "yaml": - with open(path, "r", encoding="utf-8") as yaml_file: - data_dict = ConfigDict(yaml.safe_load(yaml_file)) - - if self._lock_config: - data_dict.lock() - return data_dict - else: - return super().parse(path) - - def flag_type(self) -> str: - """The flag type of this object. - - Returns: - str: config object - """ + """Returns the type of the flag.""" return "config object" diff --git a/vis4d/model/seg/semantic_fpn.py b/vis4d/model/seg/semantic_fpn.py index 42347e3a6..870519129 100644 --- a/vis4d/model/seg/semantic_fpn.py +++ b/vis4d/model/seg/semantic_fpn.py @@ -45,7 +45,8 @@ class SemanticFPN(nn.Module): num_classes (int): Number of classes. resize (bool): Resize output to input size. weights (None | str): Pre-trained weights. - basemodel (BaseModel): Base model. + basemodel (None | BaseModel): Base model to use. If None is passed, + this will default to ResNetV1c """ def __init__( @@ -53,16 +54,19 @@ def __init__( num_classes: int, resize: bool = True, weights: None | str = None, - basemodel: BaseModel = ResNetV1c( - "resnet50_v1c", - pretrained=True, - trainable_layers=5, - norm_freezed=False, - ), + basemodel: None | BaseModel = None, ): """Init.""" super().__init__() self.resize = resize + if basemodel is None: + basemodel = ResNetV1c( + "resnet50_v1c", + pretrained=True, + trainable_layers=5, + norm_freezed=False, + ) + self.basemodel = basemodel self.fpn = FPN(self.basemodel.out_channels[2:], 256, extra_blocks=None) self.seg_head = SemanticFPNHead(num_classes, 256) diff --git a/vis4d/op/detect3d/util.py b/vis4d/op/detect3d/util.py index 8276e4f9e..dbe2c85c0 100644 --- a/vis4d/op/detect3d/util.py +++ b/vis4d/op/detect3d/util.py @@ -3,7 +3,11 @@ import torch from torch import Tensor -from vis4d_cuda_ops import nms_rotated # pylint: disable=no-name-in-module + +from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE + +if VIS4D_CUDA_OPS_AVAILABLE: + from vis4d_cuda_ops import nms_rotated # pylint: disable=no-name-in-module def bev_3d_nms( @@ -102,5 +106,10 @@ def batched_nms_rotated( boxes.clone() ) # avoid modifying the original values in boxes boxes_for_nms[:, :2] += offsets[:, None] + + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "Please install vis4d_cuda_ops to use batched_nms_rotated" + ) keep = nms_rotated(boxes_for_nms, scores, iou_threshold) return keep diff --git a/vis4d/op/layer/ms_deform_attn.py b/vis4d/op/layer/ms_deform_attn.py index a2c885dd9..906c60747 100644 --- a/vis4d/op/layer/ms_deform_attn.py +++ b/vis4d/op/layer/ms_deform_attn.py @@ -1,7 +1,7 @@ # pylint: disable=no-name-in-module, abstract-method, arguments-differ """Multi-Scale Deformable Attention Module. -Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py) # pylint: disable=line-too-long +Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py) # pylint: disable=line-too-long """ from __future__ import annotations @@ -13,10 +13,13 @@ from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.init import constant_, xavier_uniform_ -from vis4d_cuda_ops import ms_deform_attn_backward, ms_deform_attn_forward +from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE from vis4d.common.logging import rank_zero_warn +if VIS4D_CUDA_OPS_AVAILABLE: + from vis4d_cuda_ops import ms_deform_attn_backward, ms_deform_attn_forward + class MSDeformAttentionFunction(Function): # pragma: no cover """Multi-Scale Deformable Attention Function module.""" @@ -32,6 +35,10 @@ def forward( # type: ignore im2col_step: int, ) -> Tensor: """Forward pass.""" + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "MSDeformAttentionFunction requires vis4d cuda ops to run." + ) ctx.im2col_step = im2col_step output = ms_deform_attn_forward( value, @@ -56,6 +63,10 @@ def backward( ctx, grad_output: Tensor ) -> tuple[Tensor, None, None, Tensor, Tensor, None]: """Backward pass.""" + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "MSDeformAttentionFunction requires vis4d cuda ops to run." + ) ( value, value_spatial_shapes, diff --git a/vis4d/zoo/__init__.py b/vis4d/zoo/__init__.py index 9fa2ac165..553819b75 100644 --- a/vis4d/zoo/__init__.py +++ b/vis4d/zoo/__init__.py @@ -1 +1,30 @@ """Model Zoo.""" +from __future__ import annotations + +from vis4d.common.typing import ArgsType + +from .bdd100k import AVAILABLE_MODELS as BDD100K_MODELS +from .bevformer import AVAILABLE_MODELS as BEVFORMER_MODELS +from .cc_3dt import AVAILABLE_MODELS as CC_3DT_MODELS +from .faster_rcnn import AVAILABLE_MODELS as FASTER_RCNN_MODELS +from .fcn_resnet import AVAILABLE_MODELS as FCN_RESNET_MODELS +from .mask_rcnn import AVAILABLE_MODELS as MASK_RCNN_MODELS +from .qdtrack import AVAILABLE_MODELS as QDTRACK_MODELS +from .retinanet import AVAILABLE_MODELS as RETINANET_MODELS +from .shift import AVAILABLE_MODELS as SHIFT_MODELS +from .vit import AVAILABLE_MODELS as VIT_MODELS +from .yolox import AVAILABLE_MODELS as YOLOX_MODELS + +AVAILABLE_MODELS: dict[str, dict[str, ArgsType]] = { + "bdd100k": BDD100K_MODELS, + "cc_3dt": CC_3DT_MODELS, + "bevformer": BEVFORMER_MODELS, + "faster_rcnn": FASTER_RCNN_MODELS, + "fcn_resnet": FCN_RESNET_MODELS, + "mask_rcnn": MASK_RCNN_MODELS, + "qdtrack": QDTRACK_MODELS, + "retinanet": RETINANET_MODELS, + "shift": SHIFT_MODELS, + "vit": VIT_MODELS, + "yolox": YOLOX_MODELS, +} diff --git a/vis4d/zoo/bdd100k/__init__.py b/vis4d/zoo/bdd100k/__init__.py index 28f38be3e..e1925b452 100644 --- a/vis4d/zoo/bdd100k/__init__.py +++ b/vis4d/zoo/bdd100k/__init__.py @@ -1 +1,24 @@ """BDD100K Model Zoo.""" +from .faster_rcnn import faster_rcnn_r50_1x_bdd100k, faster_rcnn_r50_3x_bdd100k +from .mask_rcnn import ( + mask_rcnn_r50_1x_bdd100k, + mask_rcnn_r50_3x_bdd100k, + mask_rcnn_r50_5x_bdd100k, +) +from .semantic_fpn import ( + semantic_fpn_r50_40k_bdd100k, + semantic_fpn_r50_80k_bdd100k, + semantic_fpn_r101_80k_bdd100k, +) + +# Lists of available models in BDD100K Model Zoo. +AVAILABLE_MODELS = { + "faster_rcnn_r50_1x_bdd100k": faster_rcnn_r50_1x_bdd100k, + "faster_rcnn_r50_3x_bdd100k": faster_rcnn_r50_3x_bdd100k, + "mask_rcnn_r50_1x_bdd100k": mask_rcnn_r50_1x_bdd100k, + "mask_rcnn_r50_3x_bdd100k": mask_rcnn_r50_3x_bdd100k, + "mask_rcnn_r50_5x_bdd100k": mask_rcnn_r50_5x_bdd100k, + "semantic_fpn_r50_40k_bdd100k": semantic_fpn_r50_40k_bdd100k, + "semantic_fpn_r50_80k_bdd100k": semantic_fpn_r50_80k_bdd100k, + "semantic_fpn_r101_80k_bdd100k": semantic_fpn_r101_80k_bdd100k, +} diff --git a/vis4d/zoo/bevformer/__init__.py b/vis4d/zoo/bevformer/__init__.py index 9cf665bf1..1886bbde9 100644 --- a/vis4d/zoo/bevformer/__init__.py +++ b/vis4d/zoo/bevformer/__init__.py @@ -1 +1,8 @@ """BEVFormer model zoo.""" +from . import bevformer_base, bevformer_tiny, bevformer_vis + +AVAILABLE_MODELS = { + "bevformer_base": bevformer_base, + "bevformer_tiny": bevformer_tiny, + "bevformer_vis": bevformer_vis, +} diff --git a/vis4d/zoo/cc_3dt/__init__.py b/vis4d/zoo/cc_3dt/__init__.py index 5e4c98dd3..bb0afc9f9 100644 --- a/vis4d/zoo/cc_3dt/__init__.py +++ b/vis4d/zoo/cc_3dt/__init__.py @@ -1 +1,14 @@ """CC-3DT Model Zoo.""" +from . import ( + cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc, + cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc, + cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc, + cc_3dt_nusc_vis, +) + +AVAILABLE_MODELS = { + "cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc": cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc, + "cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc": cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc, + "cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc": cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc, # pylint: disable=line-too-long + "cc_3dt_nusc_vis": cc_3dt_nusc_vis, +} diff --git a/vis4d/zoo/faster_rcnn/__init__.py b/vis4d/zoo/faster_rcnn/__init__.py index 7a5ab289a..70c525d56 100644 --- a/vis4d/zoo/faster_rcnn/__init__.py +++ b/vis4d/zoo/faster_rcnn/__init__.py @@ -1 +1,6 @@ """Faster-RCNN Model Zoo.""" +from . import faster_rcnn_coco + +AVAILABLE_MODELS = { + "faster_rcnn_coco": faster_rcnn_coco, +} diff --git a/vis4d/zoo/fcn_resnet/__init__.py b/vis4d/zoo/fcn_resnet/__init__.py index 74f16fdcc..fc533f78d 100644 --- a/vis4d/zoo/fcn_resnet/__init__.py +++ b/vis4d/zoo/fcn_resnet/__init__.py @@ -1 +1,6 @@ """FCN Model Zoo.""" +from . import fcn_resnet_coco + +AVAILABLE_MODELS = { + "fcn_resnet_coco": fcn_resnet_coco, +} diff --git a/vis4d/zoo/mask_rcnn/__init__.py b/vis4d/zoo/mask_rcnn/__init__.py index a56b0141f..6b09f158f 100644 --- a/vis4d/zoo/mask_rcnn/__init__.py +++ b/vis4d/zoo/mask_rcnn/__init__.py @@ -1 +1,7 @@ """Mask-RCNN Model Zoo.""" + +from . import mask_rcnn_coco + +AVAILABLE_MODELS = { + "mask_rcnn_coco": mask_rcnn_coco, +} diff --git a/vis4d/zoo/qdtrack/__init__.py b/vis4d/zoo/qdtrack/__init__.py index 7812c4d6e..dcc2cf6cc 100644 --- a/vis4d/zoo/qdtrack/__init__.py +++ b/vis4d/zoo/qdtrack/__init__.py @@ -1 +1,8 @@ """QDTrack Model Zoo.""" + +from . import qdtrack_bdd100k, qdtrack_yolox_bdd100k + +AVAILABLE_MODELS = { + "qdtrack_bdd100k": qdtrack_bdd100k, + "qdtrack_yolox_bdd100k": qdtrack_yolox_bdd100k, +} diff --git a/vis4d/zoo/retinanet/__init__.py b/vis4d/zoo/retinanet/__init__.py index d25388afa..f8ea58238 100644 --- a/vis4d/zoo/retinanet/__init__.py +++ b/vis4d/zoo/retinanet/__init__.py @@ -1 +1,7 @@ """RetinaNet Model Zoo.""" + +from . import retinanet_coco + +AVAILABLE_MODELS = { + "retinanet_coco": retinanet_coco, +} diff --git a/vis4d/zoo/run.py b/vis4d/zoo/run.py new file mode 100644 index 000000000..3aa1549bc --- /dev/null +++ b/vis4d/zoo/run.py @@ -0,0 +1,26 @@ +"""CLI interface.""" +from __future__ import annotations + +from absl import app + +from vis4d.common import ArgsType +from vis4d.zoo import AVAILABLE_MODELS + + +def main(argv: ArgsType) -> None: + """Main entry point for the model zoo.""" + assert len(argv) > 1, "Command must be specified: `list`" + if argv[1] == "list": + for ds, models in AVAILABLE_MODELS.items(): + print(ds) + model_names = list(models.keys()) + for model in model_names[:-1]: + print(" ├─", model) + print(" └─", model_names[-1]) + else: + raise ValueError(f"Invalid command. {argv[1]}") + + +def entrypoint() -> None: + """Entry point for the CLI.""" + app.run(main) diff --git a/vis4d/zoo/shift/__init__.py b/vis4d/zoo/shift/__init__.py index 28f38be3e..498a278c1 100644 --- a/vis4d/zoo/shift/__init__.py +++ b/vis4d/zoo/shift/__init__.py @@ -1 +1,31 @@ """BDD100K Model Zoo.""" + +from .faster_rcnn import ( + faster_rcnn_r50_6e_shift_all_domains, + faster_rcnn_r50_12e_shift, + faster_rcnn_r50_36e_shift, +) +from .mask_rcnn import ( + mask_rcnn_r50_6e_shift_all_domains, + mask_rcnn_r50_12e_shift, + mask_rcnn_r50_36e_shift, +) +from .semantic_fpn import ( + semantic_fpn_r50_40k_shift, + semantic_fpn_r50_40k_shift_all_domains, + semantic_fpn_r50_160k_shift, + semantic_fpn_r50_160k_shift_all_domains, +) + +AVAILABLE_MODELS = { + "faster_rcnn_r50_6e_shift_all_domains": faster_rcnn_r50_6e_shift_all_domains, # pylint: disable=line-too-long + "faster_rcnn_r50_12e_shift": faster_rcnn_r50_12e_shift, + "faster_rcnn_r50_36e_shift": faster_rcnn_r50_36e_shift, + "mask_rcnn_r50_6e_shift_all_domains": mask_rcnn_r50_6e_shift_all_domains, + "mask_rcnn_r50_12e_shift": mask_rcnn_r50_12e_shift, + "mask_rcnn_r50_36e_shift": mask_rcnn_r50_36e_shift, + "semantic_fpn_r50_40k_shift_all_domains": semantic_fpn_r50_40k_shift_all_domains, # pylint: disable=line-too-long + "semantic_fpn_r50_40k_shift": semantic_fpn_r50_40k_shift, + "semantic_fpn_r50_160k_shift_all_domains": semantic_fpn_r50_160k_shift_all_domains, # pylint: disable=line-too-long + "semantic_fpn_r50_160k_shift": semantic_fpn_r50_160k_shift, +} diff --git a/vis4d/zoo/vit/__init__.py b/vis4d/zoo/vit/__init__.py index 9dd667df9..59c0b0ef7 100644 --- a/vis4d/zoo/vit/__init__.py +++ b/vis4d/zoo/vit/__init__.py @@ -1 +1,8 @@ """ViT for image classification configs.""" + +from . import vit_small_imagenet, vit_tiny_imagenet + +AVAILABLE_MODELS = { + "vit_small_imagenet": vit_small_imagenet, + "vit_tiny_imagenet": vit_tiny_imagenet, +} diff --git a/vis4d/zoo/yolox/__init__.py b/vis4d/zoo/yolox/__init__.py index be3d91ce6..17ba54d4d 100644 --- a/vis4d/zoo/yolox/__init__.py +++ b/vis4d/zoo/yolox/__init__.py @@ -1 +1,8 @@ """YOLOX Model Zoo.""" + +from . import yolox_s_300e_coco, yolox_tiny_300e_coco + +AVAILABLE_MODELS = { + "yolox_s_300e_coco": yolox_s_300e_coco, + "yolox_tiny_300e_coco": yolox_tiny_300e_coco, +}