From a004dbacb84b9399da557e426c90ef4d9d83dded Mon Sep 17 00:00:00 2001 From: Sveinung Gundersen Date: Sun, 22 Oct 2023 01:28:15 +0200 Subject: [PATCH] TMP: Fixed representation of Models with ForwardRef type arguments + dataclass -> BaseModel --- .../protocols/private/compute/job_creator.py | 4 +- src/omnipy/api/protocols/private/log.py | 3 +- src/omnipy/api/protocols/public/config.py | 6 ++- src/omnipy/api/protocols/public/hub.py | 5 ++- src/omnipy/config/engine.py | 12 +++--- src/omnipy/config/job.py | 10 +++-- src/omnipy/config/root_log.py | 10 +++-- src/omnipy/data/model.py | 41 +++++++++++++------ src/omnipy/hub/entry.py | 8 ++-- src/omnipy/hub/root_log.py | 4 +- src/omnipy/hub/runtime.py | 35 +++++++++------- src/omnipy/util/publisher.py | 13 ++++-- tests/data/test_model.py | 39 +++++++++++++++++- 13 files changed, 135 insertions(+), 55 deletions(-) diff --git a/src/omnipy/api/protocols/private/compute/job_creator.py b/src/omnipy/api/protocols/private/compute/job_creator.py index 8136063a..c88d1be8 100644 --- a/src/omnipy/api/protocols/private/compute/job_creator.py +++ b/src/omnipy/api/protocols/private/compute/job_creator.py @@ -1,13 +1,14 @@ from __future__ import annotations from datetime import datetime -from typing import Optional, Protocol +from typing import Optional, Protocol, runtime_checkable from omnipy.api.protocols.private.compute.mixins import IsNestedContext from omnipy.api.protocols.private.config import IsJobConfigBase from omnipy.api.protocols.private.engine import IsEngine +@runtime_checkable class IsJobConfigHolder(Protocol): """""" @property @@ -25,6 +26,7 @@ def set_engine(self, engine: IsEngine) -> None: ... +@runtime_checkable class IsJobCreator(IsNestedContext, IsJobConfigHolder, Protocol): """""" @property diff --git a/src/omnipy/api/protocols/private/log.py b/src/omnipy/api/protocols/private/log.py index cc29932e..aacf5552 100644 --- a/src/omnipy/api/protocols/private/log.py +++ b/src/omnipy/api/protocols/private/log.py @@ -2,7 +2,7 @@ from datetime import datetime from logging import INFO, Logger -from typing import Optional, Protocol, Tuple +from typing import Optional, Protocol, runtime_checkable, Tuple from omnipy.api.enums import RunState from omnipy.api.protocols.private.compute.mixins import IsUniquelyNamedJob @@ -18,6 +18,7 @@ def log(self, log_msg: str, level: int = INFO, datetime_obj: Optional[datetime] ... +@runtime_checkable class IsRunStateRegistry(Protocol): """""" def __init__(self) -> None: diff --git a/src/omnipy/api/protocols/public/config.py b/src/omnipy/api/protocols/public/config.py index a7b67b54..035f13b5 100644 --- a/src/omnipy/api/protocols/public/config.py +++ b/src/omnipy/api/protocols/public/config.py @@ -1,27 +1,31 @@ from __future__ import annotations -from typing import Protocol +from typing import Protocol, runtime_checkable from omnipy.api.protocols.private.config import IsJobConfigBase from omnipy.api.protocols.private.engine import IsEngineConfig from omnipy.api.types import LocaleType +@runtime_checkable class IsLocalRunnerConfig(IsEngineConfig, Protocol): """""" ... +@runtime_checkable class IsPrefectEngineConfig(IsEngineConfig, Protocol): """""" use_cached_results: int = False +@runtime_checkable class IsJobConfig(IsJobConfigBase, Protocol): """""" ... +@runtime_checkable class IsRootLogConfig(Protocol): """""" log_format_str: str diff --git a/src/omnipy/api/protocols/public/hub.py b/src/omnipy/api/protocols/public/hub.py index c4329eac..1b6e2274 100644 --- a/src/omnipy/api/protocols/public/hub.py +++ b/src/omnipy/api/protocols/public/hub.py @@ -2,7 +2,7 @@ import logging from logging.handlers import TimedRotatingFileHandler -from typing import Optional, Protocol +from typing import Optional, Protocol, runtime_checkable from omnipy.api.enums import EngineChoice from omnipy.api.protocols.private.compute.job_creator import IsJobConfigHolder @@ -14,6 +14,7 @@ IsRootLogConfig) +@runtime_checkable class IsRootLogObjects(Protocol): """""" formatter: Optional[logging.Formatter] = None @@ -25,6 +26,7 @@ def set_config(self, config: IsRootLogConfig) -> None: ... +@runtime_checkable class IsRuntimeConfig(Protocol): """""" job: IsJobConfig @@ -45,6 +47,7 @@ def __init__( ... +@runtime_checkable class IsRuntimeObjects(Protocol): """""" diff --git a/src/omnipy/config/engine.py b/src/omnipy/config/engine.py index 4a3eb32e..fafd1407 100644 --- a/src/omnipy/config/engine.py +++ b/src/omnipy/config/engine.py @@ -1,16 +1,18 @@ -from dataclasses import dataclass +# from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class EngineConfig: + +# @dataclass +class EngineConfig(BaseModel): ... -@dataclass +# @dataclass class LocalRunnerConfig(EngineConfig): ... -@dataclass +# @dataclass class PrefectEngineConfig(EngineConfig): use_cached_results: int = False diff --git a/src/omnipy/config/job.py b/src/omnipy/config/job.py index cbe0ae3d..82ad4a47 100644 --- a/src/omnipy/config/job.py +++ b/src/omnipy/config/job.py @@ -1,7 +1,9 @@ -from dataclasses import dataclass, field +# from dataclasses import field # from datetime import datetime from pathlib import Path +from pydantic import BaseModel, Field + from omnipy.api.enums import ConfigPersistOutputsOptions, ConfigRestoreOutputsOptions # from typing import Optional @@ -11,10 +13,10 @@ def _get_persist_data_dir_path() -> str: return str(Path.cwd().joinpath(Path('data'))) -@dataclass -class JobConfig: +# @dataclass +class JobConfig(BaseModel): persist_outputs: ConfigPersistOutputsOptions = \ ConfigPersistOutputsOptions.ENABLE_FLOW_AND_TASK_OUTPUTS restore_outputs: ConfigRestoreOutputsOptions = \ ConfigRestoreOutputsOptions.DISABLED - persist_data_dir_path: str = field(default_factory=_get_persist_data_dir_path) + persist_data_dir_path: str = Field(default_factory=_get_persist_data_dir_path) diff --git a/src/omnipy/config/root_log.py b/src/omnipy/config/root_log.py index c7f410db..e87e5105 100644 --- a/src/omnipy/config/root_log.py +++ b/src/omnipy/config/root_log.py @@ -1,8 +1,10 @@ -from dataclasses import dataclass, field +# from dataclasses import field import locale as pkg_locale import logging from pathlib import Path +from pydantic import BaseModel, Field + from omnipy.api.types import LocaleType @@ -10,8 +12,8 @@ def _get_log_dir_path() -> str: return str(Path.cwd().joinpath(Path('logs'))) -@dataclass -class RootLogConfig: +# @dataclass +class RootLogConfig(BaseModel): log_format_str: str = '{engine} {asctime} - {levelname}: {message} [{name}]' locale: LocaleType = pkg_locale.getlocale() log_to_stdout: bool = True @@ -20,4 +22,4 @@ class RootLogConfig: stdout_log_min_level: int = logging.INFO stderr_log_min_level: int = logging.ERROR file_log_min_level: int = logging.WARNING - file_log_dir_path: str = field(default_factory=_get_log_dir_path) + file_log_dir_path: str = Field(default_factory=_get_log_dir_path) diff --git a/src/omnipy/data/model.py b/src/omnipy/data/model.py index 08ea625d..3ed7fbfb 100644 --- a/src/omnipy/data/model.py +++ b/src/omnipy/data/model.py @@ -5,28 +5,37 @@ from isort import place_module from isort.sections import STDLIB # from orjson import orjson -from pydantic import Protocol, root_validator -from pydantic.fields import ModelField, Undefined, UndefinedType -from pydantic.generics import GenericModel +from pydantic import Protocol as pydantic_protocol +from pydantic import root_validator as pydantic_root_validator +from pydantic.fields import ModelField as PydanticModelField +from pydantic.fields import Undefined as PydanticUndefined +from pydantic.fields import UndefinedType as PydanticUndefinedType +from pydantic.generics import GenericModel as PydanticGenericModel from pydantic.typing import display_as_type RootT = TypeVar('RootT') ROOT_KEY = '__root__' +Undefined = PydanticUndefined +UndefinedType = PydanticUndefinedType + # def orjson_dumps(v, *, default): # # orjson.dumps returns bytes, to match standard json.dumps we need to decode # return orjson.dumps(v, default=default).decode() def generate_qualname(cls_name: str, model: Any) -> str: - m_module = model.__module__ if hasattr(model, '__module__') else '' - m_module_prefix = f'{m_module}.' \ - if m_module and place_module(m_module) != STDLIB else '' - fully_qual_model_name = f"{m_module_prefix}{display_as_type(model)}" + if isinstance(model, str): # ForwardRef + fully_qual_model_name = model + else: + m_module = model.__module__ if hasattr(model, '__module__') else '' + m_module_prefix = f'{m_module}.' \ + if m_module and place_module(m_module) != STDLIB else '' + fully_qual_model_name = f'{m_module_prefix}{display_as_type(model)}' return f'{cls_name}[{fully_qual_model_name}]' -class Model(GenericModel, Generic[RootT]): +class Model(PydanticGenericModel, Generic[RootT]): """ A data model containing a value parsed according to the model. @@ -110,7 +119,7 @@ def get_default_val() -> RootT: else: cls.__config__.fields[ROOT_KEY] = {'default_factory': get_default_val} - data_field = ModelField.infer( + data_field = PydanticModelField.infer( name=ROOT_KEY, value=Undefined, annotation=model, @@ -153,6 +162,9 @@ def __class_getitem__(cls, model: Union[Type[RootT], TypeVar]) -> Union[Type[Roo cls._depopulate_root_field() created_model.__qualname__ = generate_qualname(cls.__name__, model) + if isinstance(model, str): # ForwardRef + created_model.__name__ = f'{cls.__name__}[{model}]' + created_model.__module__ = cls.__module__ return created_model @@ -212,7 +224,7 @@ def _get_standard_field_description(cls) -> str: def _parse_data(cls, data: RootT) -> Any: return data - @root_validator + @pydantic_root_validator def _parse_root_object(cls, root_obj: RootT) -> Any: # noqa if ROOT_KEY not in root_obj: return root_obj @@ -241,7 +253,7 @@ def to_json(self, pretty=False) -> str: return json_content def from_json(self, json_contents: str) -> None: - new_model = self.parse_raw(json_contents, proto=Protocol.json) + new_model = self.parse_raw(json_contents, proto=pydantic_protocol.json) self._set_contents_without_validation(new_model) # @classmethod @@ -288,7 +300,7 @@ def _check_for_root_key(self) -> None: '\t"class MyNumberList(Model[List[int]]): ..."') def __setattr__(self, attr: str, value: Any) -> None: - if attr in self.__dict__ and attr not in [ROOT_KEY]: + if attr in ['__module__'] + list(self.__dict__.keys()) and attr not in [ROOT_KEY]: super().__setattr__(attr, value) else: if attr in ['contents']: @@ -297,6 +309,9 @@ def __setattr__(self, attr: str, value: Any) -> None: else: raise RuntimeError('Model does not allow setting of extra attributes') - # TODO: Update Dataset.__eq__ similarly, with tests def __eq__(self, other: object) -> bool: return self.__class__ == other.__class__ and super().__eq__(other) + + def __repr__(self) -> str: + super_repr = super().__repr__() + return super_repr.replace(f'{ROOT_KEY}=', '') diff --git a/src/omnipy/hub/entry.py b/src/omnipy/hub/entry.py index 9ec091e2..89df747d 100644 --- a/src/omnipy/hub/entry.py +++ b/src/omnipy/hub/entry.py @@ -1,13 +1,15 @@ -from dataclasses import dataclass, field +# from dataclasses import dataclass, field from typing import Optional +from pydantic import Field + from omnipy.api.protocols.public.hub import IsRuntime from omnipy.util.publisher import DataPublisher -@dataclass +# @dataclass class RuntimeEntryPublisher(DataPublisher): - _back: Optional[IsRuntime] = field(default=None, init=False, repr=False) + _back: Optional[IsRuntime] = Field(default=None, init=False, repr=False) def __setattr__(self, key, value): super().__setattr__(key, value) diff --git a/src/omnipy/hub/root_log.py b/src/omnipy/hub/root_log.py index 44bd87dc..a5d0a62e 100644 --- a/src/omnipy/hub/root_log.py +++ b/src/omnipy/hub/root_log.py @@ -13,12 +13,12 @@ from omnipy.util.helpers import get_datetime_format -@dataclass +# @dataclass class RootLogConfigEntryPublisher(RootLogConfig, RuntimeEntryPublisher): ... -@dataclass +# @dataclass class RootLogObjects: _config: IsRootLogConfig = field( init=False, repr=False, default_factory=RootLogConfigEntryPublisher) diff --git a/src/omnipy/hub/runtime.py b/src/omnipy/hub/runtime.py index e87e5d0f..e793c27b 100644 --- a/src/omnipy/hub/runtime.py +++ b/src/omnipy/hub/runtime.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass, field +# from dataclasses import dataclass, field from typing import Any +from pydantic import Field + from omnipy.api.enums import EngineChoice from omnipy.api.protocols.private.compute.job_creator import IsJobConfigHolder from omnipy.api.protocols.private.engine import IsEngine, IsEngineConfig @@ -28,28 +30,31 @@ def _job_creator_factory(): return JobBase.job_creator -@dataclass +# @dataclass class RuntimeConfig(RuntimeEntryPublisher): - job: IsJobConfig = field(default_factory=JobConfig) + job: IsJobConfig = Field(default_factory=JobConfig) engine: EngineChoice = EngineChoice.LOCAL - local: IsLocalRunnerConfig = field(default_factory=LocalRunnerConfig) - prefect: IsPrefectEngineConfig = field(default_factory=PrefectEngineConfig) - root_log: IsRootLogConfig = field(default_factory=RootLogConfigEntryPublisher) + local: IsLocalRunnerConfig = Field(default_factory=LocalRunnerConfig) + prefect: IsPrefectEngineConfig = Field(default_factory=PrefectEngineConfig) + root_log: IsRootLogConfig = Field(default_factory=RootLogConfigEntryPublisher) -@dataclass +# @dataclass class RuntimeObjects(RuntimeEntryPublisher): - job_creator: IsJobConfigHolder = field(default_factory=_job_creator_factory) - local: IsEngine = field(default_factory=LocalRunner) - prefect: IsEngine = field(default_factory=PrefectEngine) - registry: IsRunStateRegistry = field(default_factory=RunStateRegistry) - root_log: IsRootLogObjects = field(default_factory=RootLogObjects) + job_creator: IsJobConfigHolder = Field(default_factory=_job_creator_factory) + local: IsEngine = Field(default_factory=LocalRunner) + prefect: IsEngine = Field(default_factory=PrefectEngine) + registry: IsRunStateRegistry = Field(default_factory=RunStateRegistry) + root_log: IsRootLogObjects = Field(default_factory=RootLogObjects) + + +# TODO: Add automatic parsing of config values when setting to string values -@dataclass +# @dataclass class Runtime(DataPublisher): - config: IsRuntimeConfig = field(default_factory=RuntimeConfig) - objects: IsRuntimeObjects = field(default_factory=RuntimeObjects) + config: IsRuntimeConfig = Field(default_factory=RuntimeConfig) + objects: IsRuntimeObjects = Field(default_factory=RuntimeObjects) def __post_init__(self): super().__init__() diff --git a/src/omnipy/util/publisher.py b/src/omnipy/util/publisher.py index 74f973fd..68eed27f 100644 --- a/src/omnipy/util/publisher.py +++ b/src/omnipy/util/publisher.py @@ -1,16 +1,21 @@ from collections import defaultdict -from dataclasses import dataclass, field +# from dataclasses import dataclass, field from typing import Any, Callable, DefaultDict, List +from pydantic import BaseModel, Field + def _subscribers_factory(): return defaultdict(list) -@dataclass -class DataPublisher: +# @dataclass +class DataPublisher(BaseModel): _subscriptions: DefaultDict[str, List[Callable[[Any], None]]] = \ - field(default_factory=_subscribers_factory, init=False, repr=False) + Field(default_factory=_subscribers_factory, init=False, repr=False) + + class Config: + arbitrary_types_allowed = True def subscribe(self, config_item: str, callback_fun: Callable[[Any], None]): if not hasattr(self, config_item): diff --git a/tests/data/test_model.py b/tests/data/test_model.py index 5d817989..015f490b 100644 --- a/tests/data/test_model.py +++ b/tests/data/test_model.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import os from types import NoneType -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeAlias, TypeVar, Union from pydantic import BaseModel, PositiveInt, StrictInt, ValidationError import pytest @@ -120,6 +122,41 @@ class D(A): assert A(a=1) != D(a=1) +ChildT = TypeVar('ChildT', bound=object) + + +class ParentGenericModel(Model[Optional[ChildT]], Generic[ChildT]): + ... + + +ParentModel: TypeAlias = ParentGenericModel['NumberModel'] +ParentModelNested: TypeAlias = ParentGenericModel[Union[str, 'NumberModel']] + + +class NumberModel(Model[int]): + ... + + +ParentModel.update_forward_refs(NumberModel=NumberModel) +ParentModelNested.update_forward_refs(NumberModel=NumberModel) + + +def test_repr(): + assert repr(Model[int]) == "" + assert repr(Model[int](5)) == 'Model[int](5)' + + assert repr( + Model[Model[int]]) == "" + assert repr(Model[Model[int]](Model[int](5))) == 'Model[Model[int]](Model[int](5))' + + assert repr(ParentModel) == "" + assert repr(ParentModel(NumberModel(5))) == 'ParentGenericModel[NumberModel](NumberModel(5))' + + assert repr(ParentModelNested + ) == "" + assert repr(ParentModelNested('abc')) == "ParentGenericModel[Union[str, NumberModel]]('abc')" + + def _issubclass_and_isinstance(model_cls_a: Type[Model], model_cls_b: Type[Model]) -> bool: is_subclass = issubclass(model_cls_a, model_cls_b)