Skip to content

Commit

Permalink
Add registry decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
renezurbruegg committed Aug 30, 2023
1 parent 1c31af6 commit 889bca3
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tests/config/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tests.util import get_test_data
from vis4d.config.util.registry import get_config_by_name
from vis4d.zoo import register_config


class TestRegistry(unittest.TestCase):
Expand Down Expand Up @@ -57,3 +58,15 @@ def test_zoo(self) -> None:
with pytest.raises(ValueError) as err:
config = get_config_by_name("faster_rcnn_r90_1x_bdd100k")
self.assertTrue("faster_rcnn_r50_1x_bdd100k" in str(err.value))

def test_decorator(self) -> None:
"""Test registering a config."""

@register_config("cat", "test") # type: ignore
def get_config() -> dict[str, str]:
"""Test config."""
return {"test": "test"}

config = get_config_by_name("cat/test")
self.assertTrue(config is not None)
self.assertEqual(config["test"], "test")
3 changes: 2 additions & 1 deletion vis4d/config/util/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ def _get_registered_configs(
models = flatten_dict(AVAILABLE_MODELS, os.path.sep)
# check if there is an absolute match for the config
if config_name in models:
return get_dict_nested(
module = get_dict_nested(
AVAILABLE_MODELS, config_name.split(os.path.sep)
)
return getattr(module, method_name)(*args)
# check if there is a partial match for the config
matches = {}
for model in models:
Expand Down
46 changes: 46 additions & 0 deletions vis4d/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Model Zoo."""
from __future__ import annotations

from typing import Callable

from vis4d.common.typing import ArgsType

from .bdd100k import AVAILABLE_MODELS as BDD100K_MODELS
from .bevformer import AVAILABLE_MODELS as BEVFORMER_MODELS
from .cc_3dt import AVAILABLE_MODELS as CC_3DT_MODELS
Expand All @@ -11,6 +17,8 @@
from .vit import AVAILABLE_MODELS as VIT_MODELS
from .yolox import AVAILABLE_MODELS as YOLOX_MODELS

TFunc = Callable[..., ArgsType]

AVAILABLE_MODELS = {
"bdd100k": BDD100K_MODELS,
"cc_3dt": CC_3DT_MODELS,
Expand All @@ -24,3 +32,41 @@
"vit": VIT_MODELS,
"yolox": YOLOX_MODELS,
}


def register_config(category: str, name: str) -> Callable[TFunc, None]:
"""Register a config in the model zoo for the given name and category.
The config will then be available via `get_config_by_name` utilities and
located in the AVAILABLE_MODELS dictionary located at
[category][name].
Args:
category: Category of the config.
name: Name of the config.
Returns:
The decorator.
"""

def decorator(fnc_or_clazz: TFunc) -> None:
module = fnc_or_clazz

if callable(fnc_or_clazz):
# Directly annotated get_config function. Wrap it and register it.
class Wrapper:
"""Wrapper class."""

def get_config(self, *args, **kwargs) -> ArgsType:
"""Resolves the get_config function."""
return fnc_or_clazz(*args, **kwargs)

module = Wrapper()

# Register the config
if category not in AVAILABLE_MODELS:
AVAILABLE_MODELS[category] = {}

AVAILABLE_MODELS[category][name] = module

return decorator

0 comments on commit 889bca3

Please sign in to comment.