Skip to content

Commit

Permalink
Improved __init__ methods of Dataset and Model
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinugu committed Dec 10, 2023
1 parent a348af9 commit 41158fe
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
33 changes: 27 additions & 6 deletions src/omnipy/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# from orjson import orjson
from pydantic import Field, PrivateAttr, root_validator, ValidationError
from pydantic.fields import Undefined
from pydantic.fields import Undefined, UndefinedType
from pydantic.generics import GenericModel
from pydantic.typing import display_as_type
from pydantic.utils import lenient_issubclass
Expand Down Expand Up @@ -120,9 +120,13 @@ def __class_getitem__(cls, model: ModelT) -> ModelT:

return created_dataset

def __init__(self,
value: Union[Dict[str, Any], Iterator[Tuple[str, Any]]] = Undefined,
**input_data: Any) -> None:
def __init__(
self,
value: dict[str, object] | Iterator[tuple[str, object]] | UndefinedType = Undefined,
*,
data: dict[str, object] | UndefinedType = Undefined,
**input_data: object,
) -> None:
# TODO: Error message when forgetting parenthesis when creating Dataset should be improved.
# Unclear where this can be done, if anywhere? E.g.:
# a = Dataset[Model[int]]
Expand All @@ -136,13 +140,30 @@ def __init__(self,
# Dataset[Model[str]](Model[int](5)) == Dataset[Model[str]](data=Model[int](5))
# == Dataset[Model[str]](data={'__root__': Model[str]('5')})

super_kwargs = {}

assert DATA_KEY not in input_data, \
('Not allowed with"data" as input_data key. Not sure how you managed this? Are you '
'trying to break Dataset init on purpose?')

if value != Undefined:
input_data[DATA_KEY] = value
assert data == Undefined, \
'Not allowed to combine positional and "data" keyword argument'
assert len(input_data) == 0, 'Not allowed to combine positional and keyword arguments'
super_kwargs[DATA_KEY] = value

if data != Undefined:
assert len(input_data) == 0, \
"Not allowed to combine 'data' with other keyword arguments"
super_kwargs[DATA_KEY] = data

if DATA_KEY not in super_kwargs and len(input_data) > 0:
super_kwargs[DATA_KEY] = input_data

if self.get_model_class() == ModelT:
self._raise_no_model_exception()

GenericModel.__init__(self, **input_data)
GenericModel.__init__(self, **super_kwargs)
UserDict.__init__(self, self.data) # noqa
if not self.__doc__:
self._set_standard_field_description()
Expand Down
4 changes: 2 additions & 2 deletions src/omnipy/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def __new__(cls, value: Union[Any, UndefinedType] = Undefined, **kwargs):
def __init__(
self,
value: Union[Any, UndefinedType] = Undefined,
/,
*,
__root__: Union[Any, UndefinedType] = Undefined,
**data: Any,
) -> None:
Expand All @@ -263,7 +263,7 @@ def __init__(
super_data[ROOT_KEY] = cast(RootT, data)
num_root_vals += 1

assert num_root_vals <= 1
assert num_root_vals <= 1, 'Not allowed to provide root data in more than one argument'

super().__init__(**super_data)

Expand Down
23 changes: 23 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,29 @@ def test_init_with_basic_parsing():
assert dataset_3['obj_type_2'] == '123'
assert dataset_3['obj_type_3'] == 'True'

dataset_4 = Dataset[Model[dict[int, int]]](file_1={1: 1234, 2: 2345}, file_2={2: 2345, 3: 3456})

assert len(dataset_4) == 2
assert dataset_4['file_1'] == {1: 1234, 2: 2345}
assert dataset_4['file_2'] == {2: 2345, 3: 3456}


def test_init_errors():
with pytest.raises(TypeError):
Dataset[Model[int]]({'file_1': 123}, {'file_2': 234})

with pytest.raises(AssertionError):
Dataset[Model[int]]({'file_1': 123}, data={'file_2': 234})

with pytest.raises(AssertionError):
Dataset[Model[int]]({'file_1': 123}, file_2=234)

with pytest.raises(AssertionError):
Dataset[Model[int]](data={'file_1': 123}, file_2=234)

with pytest.raises(ValidationError):
Dataset[Model[int]](data=123)


def test_parsing_none_allowed():
class NoneModel(Model[NoneType]):
Expand Down
15 changes: 15 additions & 0 deletions tests/data/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ class StrModel(Model[str]):
assert Model[Dict]((('a', 2), ('b', True))).to_data() == {'a': 2, 'b': True}


def test_error_init():
with pytest.raises(TypeError):
assert Model[tuple[int, ...]](12, 2, 4).to_data() == 12
assert Model[tuple[int, ...]]((12, 2, 4)).to_data() == (12, 2, 4)

with pytest.raises(AssertionError):
Model[int](123, __root__=234)

with pytest.raises(AssertionError):
Model[int](123, other=234)

with pytest.raises(AssertionError):
Model[int](__root__=123, other=234)


def test_load():
model = Model[int]()
model.from_data(12)
Expand Down

0 comments on commit 41158fe

Please sign in to comment.