Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config Registry #130

Merged
merged 15 commits into from
Aug 31, 2023
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ include = ["vis4d*"]
[project.scripts]
vis4d = "vis4d.engine.run:entrypoint"
vis4d-pl = "vis4d.pl.run:entrypoint"
vis4d-zoo = "vis4d.zoo.run:entrypoint"
10 changes: 9 additions & 1 deletion tests/common/dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import unittest

from vis4d.common.dict import get_dict_nested, set_dict_nested
from vis4d.common.dict import flatten_dict, get_dict_nested, set_dict_nested


class TestDictUtils(unittest.TestCase):
"""Test cases for array conversion ops."""

def test_flatten_dict(self) -> None:
"""Tests the flatten_dict function."""
d = {"a": {"b": {"c": 10}}}
self.assertEqual(flatten_dict(d, "."), ["a.b.c"])

d = {"a": {"b": {"c": 10, "d": 20}}}
self.assertEqual(flatten_dict(d, "/"), ["a/b/c", "a/b/d"])

def test_set_dict_nested(self) -> None:
"""Tests the set_dict_nested function."""
d = {} # type:ignore
Expand Down
71 changes: 71 additions & 0 deletions tests/config/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Test config registry."""
from __future__ import annotations

import unittest

import pytest

from tests.util import get_test_data
from vis4d.config.util.registry import get_config_by_name, register_config


class TestRegistry(unittest.TestCase):
"""Test the config registry."""

def test_yaml(self) -> None:
"""Test reading a yaml config file."""
file = get_test_data(
"config_test/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.yaml"
)

# Config can be resolved
config = get_config_by_name(file)
self.assertTrue(config is not None)

# Config does not exist
with pytest.raises(ValueError) as err:
config = get_config_by_name(file.replace("r50", "r91"))
self.assertTrue("Could not find" in str(err.value))

def test_py(self) -> None:
"""Test reading a py config file from the model zoo."""
file = "/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py"
cfg = get_config_by_name(file)
self.assertTrue(cfg is not None)

# Only by file name
file = "faster_rcnn_r50_1x_bdd100k.py"
cfg = get_config_by_name(file)
self.assertTrue(cfg is not None)

# Check did you mean message
file = "faster_rcnn_r90_1x_bdd100k"
with pytest.raises(ValueError) as err:
cfg = get_config_by_name(file)
self.assertTrue("faster_rcnn_r50_1x_bdd100k" in str(err.value))

def test_zoo(self) -> None:
"""Test reading a registered config from the zoo."""
config = get_config_by_name("faster_rcnn_r50_1x_bdd100k")
self.assertTrue(config is not None)

# Full Qualified Name
config = get_config_by_name("bdd100k/faster_rcnn_r50_1x_bdd100k")
self.assertTrue(config is not None)

# Check did you mean message
with pytest.raises(ValueError) as err:
config = get_config_by_name("faster_rcnn_r90_1x_bdd100k")
self.assertTrue("faster_rcnn_r50_1x_bdd100k" in str(err.value))

def test_decorator(self) -> None:
"""Test registering a config."""

@register_config("cat", "test") # type: ignore
def get_config() -> dict[str, str]:
"""Test config."""
return {"test": "test"}

config = get_config_by_name("cat/test")
self.assertTrue(config is not None)
self.assertEqual(config["test"], "test")
29 changes: 29 additions & 0 deletions vis4d/common/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,35 @@
from vis4d.common import DictStrAny


def flatten_dict(dictionary: DictStrAny, seperator: str) -> list[str]:
"""Flatten a nested dictionary.

Args:
dictionary (DictStrAny): The dictionary to flatten.
seperator (str): The seperator to use between keys.

Returns:
List[str]: A list of flattened keys.

Examples:
>>> d = {'a': {'b': {'c': 10}}}
>>> flatten_dict(d, '.')
['a.b.c']
"""
flattened = []
for key, value in dictionary.items():
if isinstance(value, dict):
flattened.extend(
[
f"{key}{seperator}{subkey}"
for subkey in flatten_dict(value, seperator)
]
)
else:
flattened.append(key)
return flattened


def get_dict_nested( # type: ignore
dictionary: DictStrAny, keys: list[str], allow_missing: bool = False
) -> Any:
Expand Down
3 changes: 3 additions & 0 deletions vis4d/common/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def package_available(package_name: str) -> bool:
OPEN3D_AVAILABLE = package_available("open3d")
PLOTLY_AVAILABLE = package_available("plotly")

# vis4d cuda ops
VIS4D_CUDA_OPS_AVAILABLE = package_available("vis4d_cuda_ops")

# logging
TENSORBOARD_AVAILABLE = package_available("tensorboardX") or package_available(
"tensorboard"
Expand Down
25 changes: 25 additions & 0 deletions vis4d/common/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for common usage."""
import random
from difflib import get_close_matches

import numpy as np
import torch
Expand All @@ -8,6 +9,30 @@
from .logging import rank_zero_warn


def create_did_you_mean_msg(keys: list[str], query: str) -> str:
"""Create a did you mean message.

Args:
keys (list[str]): List of available keys.
query (str): Query.

Returns:
str: Did you mean message.

Examples:
>>> keys = ["foo", "bar", "baz"]
>>> query = "fo"
>>> print(create_did_you_mean_msg(keys, query))
Did you mean:
foo
"""
msg = ""
if len(keys) > 0:
msg = "Did you mean:\n\t"
msg += "\n\t".join(get_close_matches(query, keys, cutoff=0.75))
return msg


def set_tf32(use_tf32: bool = False) -> None: # pragma: no cover
"""Set torch TF32.

Expand Down
Loading
Loading