diff --git a/src/omnipy/data/dataset.py b/src/omnipy/data/dataset.py index 6673a380..80bf9ee4 100644 --- a/src/omnipy/data/dataset.py +++ b/src/omnipy/data/dataset.py @@ -15,7 +15,7 @@ # from orjson import orjson from pydantic import Field, PrivateAttr, root_validator, ValidationError -from pydantic.fields import Undefined +from pydantic.fields import Undefined, UndefinedType from pydantic.generics import GenericModel from pydantic.typing import display_as_type from pydantic.utils import lenient_issubclass @@ -120,9 +120,13 @@ def __class_getitem__(cls, model: ModelT) -> ModelT: return created_dataset - def __init__(self, - value: Union[Dict[str, Any], Iterator[Tuple[str, Any]]] = Undefined, - **input_data: Any) -> None: + def __init__( + self, + value: dict[str, object] | Iterator[tuple[str, object]] | UndefinedType = Undefined, + *, + data: dict[str, object] | UndefinedType = Undefined, + **input_data: object, + ) -> None: # TODO: Error message when forgetting parenthesis when creating Dataset should be improved. # Unclear where this can be done, if anywhere? E.g.: # a = Dataset[Model[int]] @@ -136,13 +140,30 @@ def __init__(self, # Dataset[Model[str]](Model[int](5)) == Dataset[Model[str]](data=Model[int](5)) # == Dataset[Model[str]](data={'__root__': Model[str]('5')}) + super_kwargs = {} + + assert DATA_KEY not in input_data, \ + ('Not allowed with"data" as input_data key. Not sure how you managed this? Are you ' + 'trying to break Dataset init on purpose?') + if value != Undefined: - input_data[DATA_KEY] = value + assert data == Undefined, \ + 'Not allowed to combine positional and "data" keyword argument' + assert len(input_data) == 0, 'Not allowed to combine positional and keyword arguments' + super_kwargs[DATA_KEY] = value + + if data != Undefined: + assert len(input_data) == 0, \ + "Not allowed to combine 'data' with other keyword arguments" + super_kwargs[DATA_KEY] = data + + if DATA_KEY not in super_kwargs and len(input_data) > 0: + super_kwargs[DATA_KEY] = input_data if self.get_model_class() == ModelT: self._raise_no_model_exception() - GenericModel.__init__(self, **input_data) + GenericModel.__init__(self, **super_kwargs) UserDict.__init__(self, self.data) # noqa if not self.__doc__: self._set_standard_field_description() diff --git a/src/omnipy/data/model.py b/src/omnipy/data/model.py index 668967c6..e10711cb 100644 --- a/src/omnipy/data/model.py +++ b/src/omnipy/data/model.py @@ -244,7 +244,7 @@ def __new__(cls, value: Union[Any, UndefinedType] = Undefined, **kwargs): def __init__( self, value: Union[Any, UndefinedType] = Undefined, - /, + *, __root__: Union[Any, UndefinedType] = Undefined, **data: Any, ) -> None: @@ -263,7 +263,7 @@ def __init__( super_data[ROOT_KEY] = cast(RootT, data) num_root_vals += 1 - assert num_root_vals <= 1 + assert num_root_vals <= 1, 'Not allowed to provide root data in more than one argument' super().__init__(**super_data) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 61028ce6..9562052a 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -61,6 +61,29 @@ def test_init_with_basic_parsing(): assert dataset_3['obj_type_2'] == '123' assert dataset_3['obj_type_3'] == 'True' + dataset_4 = Dataset[Model[dict[int, int]]](file_1={1: 1234, 2: 2345}, file_2={2: 2345, 3: 3456}) + + assert len(dataset_4) == 2 + assert dataset_4['file_1'] == {1: 1234, 2: 2345} + assert dataset_4['file_2'] == {2: 2345, 3: 3456} + + +def test_init_errors(): + with pytest.raises(TypeError): + Dataset[Model[int]]({'file_1': 123}, {'file_2': 234}) + + with pytest.raises(AssertionError): + Dataset[Model[int]]({'file_1': 123}, data={'file_2': 234}) + + with pytest.raises(AssertionError): + Dataset[Model[int]]({'file_1': 123}, file_2=234) + + with pytest.raises(AssertionError): + Dataset[Model[int]](data={'file_1': 123}, file_2=234) + + with pytest.raises(ValidationError): + Dataset[Model[int]](data=123) + def test_parsing_none_allowed(): class NoneModel(Model[NoneType]): diff --git a/tests/data/test_model.py b/tests/data/test_model.py index 6475b854..7672058d 100644 --- a/tests/data/test_model.py +++ b/tests/data/test_model.py @@ -64,6 +64,21 @@ class StrModel(Model[str]): assert Model[Dict]((('a', 2), ('b', True))).to_data() == {'a': 2, 'b': True} +def test_error_init(): + with pytest.raises(TypeError): + assert Model[tuple[int, ...]](12, 2, 4).to_data() == 12 + assert Model[tuple[int, ...]]((12, 2, 4)).to_data() == (12, 2, 4) + + with pytest.raises(AssertionError): + Model[int](123, __root__=234) + + with pytest.raises(AssertionError): + Model[int](123, other=234) + + with pytest.raises(AssertionError): + Model[int](__root__=123, other=234) + + def test_load(): model = Model[int]() model.from_data(12)