Skip to content

Commit

Permalink
TMP: Fixed representation of Models with ForwardRef type arguments + …
Browse files Browse the repository at this point in the history
…dataclass -> BaseModel
  • Loading branch information
sveinugu committed Oct 21, 2023
1 parent 0078f88 commit a004dba
Show file tree
Hide file tree
Showing 13 changed files with 135 additions and 55 deletions.
4 changes: 3 additions & 1 deletion src/omnipy/api/protocols/private/compute/job_creator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from datetime import datetime
from typing import Optional, Protocol
from typing import Optional, Protocol, runtime_checkable

from omnipy.api.protocols.private.compute.mixins import IsNestedContext
from omnipy.api.protocols.private.config import IsJobConfigBase
from omnipy.api.protocols.private.engine import IsEngine


@runtime_checkable
class IsJobConfigHolder(Protocol):
""""""
@property
Expand All @@ -25,6 +26,7 @@ def set_engine(self, engine: IsEngine) -> None:
...


@runtime_checkable
class IsJobCreator(IsNestedContext, IsJobConfigHolder, Protocol):
""""""
@property
Expand Down
3 changes: 2 additions & 1 deletion src/omnipy/api/protocols/private/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import datetime
from logging import INFO, Logger
from typing import Optional, Protocol, Tuple
from typing import Optional, Protocol, runtime_checkable, Tuple

from omnipy.api.enums import RunState
from omnipy.api.protocols.private.compute.mixins import IsUniquelyNamedJob
Expand All @@ -18,6 +18,7 @@ def log(self, log_msg: str, level: int = INFO, datetime_obj: Optional[datetime]
...


@runtime_checkable
class IsRunStateRegistry(Protocol):
""""""
def __init__(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/omnipy/api/protocols/public/config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
from __future__ import annotations

from typing import Protocol
from typing import Protocol, runtime_checkable

from omnipy.api.protocols.private.config import IsJobConfigBase
from omnipy.api.protocols.private.engine import IsEngineConfig
from omnipy.api.types import LocaleType


@runtime_checkable
class IsLocalRunnerConfig(IsEngineConfig, Protocol):
""""""
...


@runtime_checkable
class IsPrefectEngineConfig(IsEngineConfig, Protocol):
""""""
use_cached_results: int = False


@runtime_checkable
class IsJobConfig(IsJobConfigBase, Protocol):
""""""
...


@runtime_checkable
class IsRootLogConfig(Protocol):
""""""
log_format_str: str
Expand Down
5 changes: 4 additions & 1 deletion src/omnipy/api/protocols/public/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from logging.handlers import TimedRotatingFileHandler
from typing import Optional, Protocol
from typing import Optional, Protocol, runtime_checkable

from omnipy.api.enums import EngineChoice
from omnipy.api.protocols.private.compute.job_creator import IsJobConfigHolder
Expand All @@ -14,6 +14,7 @@
IsRootLogConfig)


@runtime_checkable
class IsRootLogObjects(Protocol):
""""""
formatter: Optional[logging.Formatter] = None
Expand All @@ -25,6 +26,7 @@ def set_config(self, config: IsRootLogConfig) -> None:
...


@runtime_checkable
class IsRuntimeConfig(Protocol):
""""""
job: IsJobConfig
Expand All @@ -45,6 +47,7 @@ def __init__(
...


@runtime_checkable
class IsRuntimeObjects(Protocol):
""""""

Expand Down
12 changes: 7 additions & 5 deletions src/omnipy/config/engine.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from dataclasses import dataclass
# from dataclasses import dataclass

from pydantic import BaseModel

@dataclass
class EngineConfig:

# @dataclass
class EngineConfig(BaseModel):
...


@dataclass
# @dataclass
class LocalRunnerConfig(EngineConfig):
...


@dataclass
# @dataclass
class PrefectEngineConfig(EngineConfig):
use_cached_results: int = False
10 changes: 6 additions & 4 deletions src/omnipy/config/job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
# from dataclasses import field
# from datetime import datetime
from pathlib import Path

from pydantic import BaseModel, Field

from omnipy.api.enums import ConfigPersistOutputsOptions, ConfigRestoreOutputsOptions

# from typing import Optional
Expand All @@ -11,10 +13,10 @@ def _get_persist_data_dir_path() -> str:
return str(Path.cwd().joinpath(Path('data')))


@dataclass
class JobConfig:
# @dataclass
class JobConfig(BaseModel):
persist_outputs: ConfigPersistOutputsOptions = \
ConfigPersistOutputsOptions.ENABLE_FLOW_AND_TASK_OUTPUTS
restore_outputs: ConfigRestoreOutputsOptions = \
ConfigRestoreOutputsOptions.DISABLED
persist_data_dir_path: str = field(default_factory=_get_persist_data_dir_path)
persist_data_dir_path: str = Field(default_factory=_get_persist_data_dir_path)
10 changes: 6 additions & 4 deletions src/omnipy/config/root_log.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from dataclasses import dataclass, field
# from dataclasses import field
import locale as pkg_locale
import logging
from pathlib import Path

from pydantic import BaseModel, Field

from omnipy.api.types import LocaleType


def _get_log_dir_path() -> str:
return str(Path.cwd().joinpath(Path('logs')))


@dataclass
class RootLogConfig:
# @dataclass
class RootLogConfig(BaseModel):
log_format_str: str = '{engine} {asctime} - {levelname}: {message} [{name}]'
locale: LocaleType = pkg_locale.getlocale()
log_to_stdout: bool = True
Expand All @@ -20,4 +22,4 @@ class RootLogConfig:
stdout_log_min_level: int = logging.INFO
stderr_log_min_level: int = logging.ERROR
file_log_min_level: int = logging.WARNING
file_log_dir_path: str = field(default_factory=_get_log_dir_path)
file_log_dir_path: str = Field(default_factory=_get_log_dir_path)
41 changes: 28 additions & 13 deletions src/omnipy/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,37 @@
from isort import place_module
from isort.sections import STDLIB
# from orjson import orjson
from pydantic import Protocol, root_validator
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.generics import GenericModel
from pydantic import Protocol as pydantic_protocol
from pydantic import root_validator as pydantic_root_validator
from pydantic.fields import ModelField as PydanticModelField
from pydantic.fields import Undefined as PydanticUndefined
from pydantic.fields import UndefinedType as PydanticUndefinedType
from pydantic.generics import GenericModel as PydanticGenericModel
from pydantic.typing import display_as_type

RootT = TypeVar('RootT')
ROOT_KEY = '__root__'

Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType

# def orjson_dumps(v, *, default):
# # orjson.dumps returns bytes, to match standard json.dumps we need to decode
# return orjson.dumps(v, default=default).decode()


def generate_qualname(cls_name: str, model: Any) -> str:
m_module = model.__module__ if hasattr(model, '__module__') else ''
m_module_prefix = f'{m_module}.' \
if m_module and place_module(m_module) != STDLIB else ''
fully_qual_model_name = f"{m_module_prefix}{display_as_type(model)}"
if isinstance(model, str): # ForwardRef
fully_qual_model_name = model
else:
m_module = model.__module__ if hasattr(model, '__module__') else ''
m_module_prefix = f'{m_module}.' \
if m_module and place_module(m_module) != STDLIB else ''
fully_qual_model_name = f'{m_module_prefix}{display_as_type(model)}'
return f'{cls_name}[{fully_qual_model_name}]'


class Model(GenericModel, Generic[RootT]):
class Model(PydanticGenericModel, Generic[RootT]):
"""
A data model containing a value parsed according to the model.
Expand Down Expand Up @@ -110,7 +119,7 @@ def get_default_val() -> RootT:
else:
cls.__config__.fields[ROOT_KEY] = {'default_factory': get_default_val}

data_field = ModelField.infer(
data_field = PydanticModelField.infer(
name=ROOT_KEY,
value=Undefined,
annotation=model,
Expand Down Expand Up @@ -153,6 +162,9 @@ def __class_getitem__(cls, model: Union[Type[RootT], TypeVar]) -> Union[Type[Roo
cls._depopulate_root_field()

created_model.__qualname__ = generate_qualname(cls.__name__, model)
if isinstance(model, str): # ForwardRef
created_model.__name__ = f'{cls.__name__}[{model}]'
created_model.__module__ = cls.__module__

return created_model

Expand Down Expand Up @@ -212,7 +224,7 @@ def _get_standard_field_description(cls) -> str:
def _parse_data(cls, data: RootT) -> Any:
return data

@root_validator
@pydantic_root_validator
def _parse_root_object(cls, root_obj: RootT) -> Any: # noqa
if ROOT_KEY not in root_obj:
return root_obj
Expand Down Expand Up @@ -241,7 +253,7 @@ def to_json(self, pretty=False) -> str:
return json_content

def from_json(self, json_contents: str) -> None:
new_model = self.parse_raw(json_contents, proto=Protocol.json)
new_model = self.parse_raw(json_contents, proto=pydantic_protocol.json)
self._set_contents_without_validation(new_model)

# @classmethod
Expand Down Expand Up @@ -288,7 +300,7 @@ def _check_for_root_key(self) -> None:
'\t"class MyNumberList(Model[List[int]]): ..."')

def __setattr__(self, attr: str, value: Any) -> None:
if attr in self.__dict__ and attr not in [ROOT_KEY]:
if attr in ['__module__'] + list(self.__dict__.keys()) and attr not in [ROOT_KEY]:
super().__setattr__(attr, value)
else:
if attr in ['contents']:
Expand All @@ -297,6 +309,9 @@ def __setattr__(self, attr: str, value: Any) -> None:
else:
raise RuntimeError('Model does not allow setting of extra attributes')

# TODO: Update Dataset.__eq__ similarly, with tests
def __eq__(self, other: object) -> bool:
return self.__class__ == other.__class__ and super().__eq__(other)

def __repr__(self) -> str:
super_repr = super().__repr__()
return super_repr.replace(f'{ROOT_KEY}=', '')
8 changes: 5 additions & 3 deletions src/omnipy/hub/entry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from dataclasses import dataclass, field
# from dataclasses import dataclass, field
from typing import Optional

from pydantic import Field

from omnipy.api.protocols.public.hub import IsRuntime
from omnipy.util.publisher import DataPublisher


@dataclass
# @dataclass
class RuntimeEntryPublisher(DataPublisher):
_back: Optional[IsRuntime] = field(default=None, init=False, repr=False)
_back: Optional[IsRuntime] = Field(default=None, init=False, repr=False)

def __setattr__(self, key, value):
super().__setattr__(key, value)
Expand Down
4 changes: 2 additions & 2 deletions src/omnipy/hub/root_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from omnipy.util.helpers import get_datetime_format


@dataclass
# @dataclass
class RootLogConfigEntryPublisher(RootLogConfig, RuntimeEntryPublisher):
...


@dataclass
# @dataclass
class RootLogObjects:
_config: IsRootLogConfig = field(
init=False, repr=False, default_factory=RootLogConfigEntryPublisher)
Expand Down
35 changes: 20 additions & 15 deletions src/omnipy/hub/runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
# from dataclasses import dataclass, field
from typing import Any

from pydantic import Field

from omnipy.api.enums import EngineChoice
from omnipy.api.protocols.private.compute.job_creator import IsJobConfigHolder
from omnipy.api.protocols.private.engine import IsEngine, IsEngineConfig
Expand Down Expand Up @@ -28,28 +30,31 @@ def _job_creator_factory():
return JobBase.job_creator


@dataclass
# @dataclass
class RuntimeConfig(RuntimeEntryPublisher):
job: IsJobConfig = field(default_factory=JobConfig)
job: IsJobConfig = Field(default_factory=JobConfig)
engine: EngineChoice = EngineChoice.LOCAL
local: IsLocalRunnerConfig = field(default_factory=LocalRunnerConfig)
prefect: IsPrefectEngineConfig = field(default_factory=PrefectEngineConfig)
root_log: IsRootLogConfig = field(default_factory=RootLogConfigEntryPublisher)
local: IsLocalRunnerConfig = Field(default_factory=LocalRunnerConfig)
prefect: IsPrefectEngineConfig = Field(default_factory=PrefectEngineConfig)
root_log: IsRootLogConfig = Field(default_factory=RootLogConfigEntryPublisher)


@dataclass
# @dataclass
class RuntimeObjects(RuntimeEntryPublisher):
job_creator: IsJobConfigHolder = field(default_factory=_job_creator_factory)
local: IsEngine = field(default_factory=LocalRunner)
prefect: IsEngine = field(default_factory=PrefectEngine)
registry: IsRunStateRegistry = field(default_factory=RunStateRegistry)
root_log: IsRootLogObjects = field(default_factory=RootLogObjects)
job_creator: IsJobConfigHolder = Field(default_factory=_job_creator_factory)
local: IsEngine = Field(default_factory=LocalRunner)
prefect: IsEngine = Field(default_factory=PrefectEngine)
registry: IsRunStateRegistry = Field(default_factory=RunStateRegistry)
root_log: IsRootLogObjects = Field(default_factory=RootLogObjects)


# TODO: Add automatic parsing of config values when setting to string values


@dataclass
# @dataclass
class Runtime(DataPublisher):
config: IsRuntimeConfig = field(default_factory=RuntimeConfig)
objects: IsRuntimeObjects = field(default_factory=RuntimeObjects)
config: IsRuntimeConfig = Field(default_factory=RuntimeConfig)
objects: IsRuntimeObjects = Field(default_factory=RuntimeObjects)

def __post_init__(self):
super().__init__()
Expand Down
Loading

0 comments on commit a004dba

Please sign in to comment.