Skip to content

Commit

Permalink
Merge pull request #393 from es-ude/390-add-node-data-structure
Browse files Browse the repository at this point in the history
feat(ir): add abstract ir data class and nodes
  • Loading branch information
LeoBuron authored Nov 26, 2024
2 parents 07f5160 + bb81f0d commit e95aab7
Show file tree
Hide file tree
Showing 12 changed files with 546 additions and 15 deletions.
30 changes: 15 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,26 @@
default_install_hook_types: [pre-commit, commit-msg]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
stages: [commit, manual ]
stages: [pre-commit, manual ]
- id: end-of-file-fixer
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: check-yaml
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: check-added-large-files
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: check-merge-conflict
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: check-toml
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: check-vcs-permalinks
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- id: no-commit-to-branch
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.13.0
hooks:
- id: mypy
stages: [ manual ]
Expand All @@ -32,18 +32,18 @@ repos:
- id: dead
stages: [ manual ]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
args: ["--profile=black"]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 24.10.0
hooks:
- id: black
stages: [ commit, manual ]
stages: [ pre-commit, manual ]
- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v9.5.0
rev: v9.18.0
hooks:
- id: commitlint
stages: [commit-msg, manual]
6 changes: 6 additions & 0 deletions elasticai/creator/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .abstract_ir_data import AbstractIRData, MandatoryField

__all__ = [
"MandatoryField",
"AbstractIRData",
]
7 changes: 7 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .abstract_ir_data import AbstractIRData
from .mandatory_field import MandatoryField

__all__ = [
"AbstractIRData",
"MandatoryField",
]
16 changes: 16 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_attributes_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from collections.abc import MutableMapping

from elasticai.creator.ir.attribute import AttributeT

from ._has_data import HasData
from ._hiding_dict import _HidingDict


class _AttributesDescriptor:
    """Descriptor exposing an instance's ``data`` dict as a mapping view.

    The returned view hides the keys in ``hidden_names`` from all read
    operations (see ``_HidingDict``); writes pass through to ``instance.data``.
    """

    def __init__(self, hidden_names: set[str]):
        # Stored by reference: later mutations of the set by the owner are visible here.
        self._hidden_names = hidden_names

    def __get__(
        self, instance: HasData, owner: type[HasData]
    ) -> MutableMapping[str, AttributeT]:
        # A fresh view is built on every attribute access; it shares `instance.data`.
        # NOTE(review): class-level access (instance is None) would raise
        # AttributeError on `instance.data` -- confirm that is acceptable.
        return _HidingDict(self._hidden_names, instance.data)
9 changes: 9 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_has_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Protocol, runtime_checkable

from elasticai.creator.ir.attribute import AttributeT


@runtime_checkable
class HasData(Protocol):
    """Structural type for objects that expose their state as a ``data`` dict."""

    @property
    def data(self) -> dict[str, AttributeT]: ...
83 changes: 83 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_hiding_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import collections
from collections.abc import Iterable
from itertools import filterfalse
from typing import MutableMapping, TypeVar

T = TypeVar("T")


class _HidingDict(MutableMapping[str, T]):
"""Allows to hide keys with `hidden_names` for all read operations.
We use this to implement an attributes field for Nodes that looks like a dictionary, but hides
all mandatory fields.
You can still write to `HidingDict`, e.g.,
>>> d = dict(a="a", b="b")
>>> h = _HidingDict({"a"}, d)
>>> "b", == tuple(h.keys())
True
>>> "a" in h
False
>>> h["a"] = "c"
>>> h.data["a"]
'c'
>>> d["a"]
'c'
>>> "a" in d and "a" in h.data
True
"""

def __init__(self, hidden_names: Iterable[str], data: dict) -> None:
self.data = data
self._hidden_names = set(hidden_names)

def __setitem__(self, key: str, value: T):
self.data[key] = value

def _is_hidden(self, name: str) -> bool:
return name in self._hidden_names

def __delitem__(self, key: str):
del self.data[key]

def __iter__(self):
return filterfalse(self._is_hidden, iter(self.data))

def __contains__(self, item):
# overriding this should also make class behave correctly for getting items
return item not in self._hidden_names and item in self.data

def __getitem__(self, item: str) -> T:
return self.data[item]

def __len__(self):
return len(self.data)

def get(self, key: str, default=None) -> T:
if key in self:
return self[key]
return default

def __copy__(self) -> MutableMapping[str, T]:
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
# Create a copy and avoid triggering descriptors
inst.__dict__["data"] = self.__dict__["data"].copy()
return inst

def copy(self) -> MutableMapping[str, T]:
if self.__class__ is collections.UserDict:
return _HidingDict(self._hidden_names.copy(), self.data.copy())
import copy

data = self.data
try:
self.data = {}
c = copy.copy(self)
finally:
self.data = data
c.update(self)
return c

def __repr__(self) -> str:
return f"HidingDict({', '.join(self._hidden_names)}, data={self.data})"
171 changes: 171 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/abstract_ir_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import inspect
import sys
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Callable, TypeVar

from elasticai.creator.ir.attribute import AttributeT

from ._attributes_descriptor import _AttributesDescriptor
from .mandatory_field import MandatoryField, TransformableMandatoryField

# On 3.11+ use the real typing.Self; otherwise fall back to a bound TypeVar.
# Compare the full version tuple (the original `minor > 10` check would also
# misfire on any future major-version bump).
if sys.version_info >= (3, 11):
    from typing import Self
else:
    # Forward reference must name the actual class, `AbstractIRData`.
    Self = TypeVar("Self", bound="AbstractIRData")


class AbstractIRData:
    """Base class for thin, typed wrappers around a shared ``dict``.

    This class should provide a way to easily create new wrappers around
    dictionaries.  It is supposed to be used together with the
    `MandatoryField` class.  Every child of `AbstractIRData` is expected to
    have a constructor that takes a dictionary.  That dictionary is not
    copied, but instead can be shared with other Node classes.

    Most of the private methods in this class deal with handling arguments of
    the classmethod `new`, which is used to create new nodes from scratch.
    The `attributes` attribute provides a dict-like object that hides all
    keys associated with mandatory fields.

    The purpose of this class is to allow customized access to the wrapped
    dictionary while still supporting static type annotations.
    """

    __slots__ = ("data",)
    # NOTE(review): constructed with an empty hidden set here; presumably the
    # mandatory-field machinery arranges for mandatory keys to be hidden in
    # subclasses -- confirm against MandatoryField's implementation.
    attributes = _AttributesDescriptor(set())

    def __init__(self: Self, data: dict[str, AttributeT]):
        """Wrap `data` (shared, not copied) after checking mandatory fields.

        IMPORTANT: Do not override this. If you want to create a function
        that creates new nodes of your subtype, override the `new` method
        instead.

        Raises:
            ValueError: if a mandatory field is missing from `data`.
        """
        for k in self._mandatory_fields():
            if k not in data:
                raise ValueError(f"Missing mandatory field {k}")
        self.data = data

    @classmethod
    def _do_new(cls, *args, **kwargs) -> Self:
        """This is here for your convenience to be called in `new`."""
        cls.__validate_arguments(args, kwargs)
        data = cls.__turn_arguments_into_data_dict(args, kwargs)
        return cls(data)

    @classmethod
    @abstractmethod
    def new(cls, *args, **kwargs) -> Self:
        """Create a new node by creating a new dictionary from args and kwargs.

        Use this for creation of new nodes from inline code. This is typically
        also where you want to provide type hints for users via `@overload`.
        You can delegate to the `_do_new()` method of `BaseNode`.

        NOTE(review): the class does not use `ABCMeta`, so `@abstractmethod`
        is not enforced at instantiation time -- confirm this is intended.
        """
        ...

    def as_dict(self: Self) -> dict[str, AttributeT]:
        """Return the wrapped dict itself (no copy is made)."""
        return self.data

    @classmethod
    def from_dict(cls: type[Self], data: dict[str, AttributeT]) -> Self:
        """Alternate constructor wrapping an existing dict (shared, not copied)."""
        return cls(data)

    def __eq__(self: Self, other: object) -> bool:
        # Two IR objects are equal iff their underlying dicts are equal.
        if hasattr(other, "data") and isinstance(other.data, dict):
            return self.data == other.data
        # Defer to the other operand instead of a hard False; Python falls
        # back to identity (and thus False) if the other side also declines.
        return NotImplemented

    @classmethod
    def __turn_arguments_into_data_dict(
        cls, args: tuple[Any], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        # Precedence: explicit kwargs override attributes, positional
        # mandatory-field values override both; transforms run last.
        data = cls.__extract_attributes_from_args(args, kwargs)
        data.update(cls.__get_kwargs_without_attributes(kwargs))
        data.update(cls.__get_args_as_kwargs(args))
        cls.__transform_args_with_mandatory_fields(data)
        return data

    @classmethod
    def __get_mandatory_field_descriptors(
        cls,
    ) -> Iterable[tuple[str, TransformableMandatoryField]]:
        # Walk the MRO base-first so fields of ancestors come before fields
        # added by subclasses.
        for c in reversed(inspect.getmro(cls)):
            for a in c.__dict__:
                if (
                    not a.startswith("__")
                    and not a.endswith("__")
                    and isinstance(c.__dict__[a], TransformableMandatoryField)
                ):
                    yield a, c.__dict__[a]

    @classmethod
    def _mandatory_fields(cls) -> tuple[str, ...]:
        """Names of all mandatory fields declared on this class and its bases."""
        return tuple(name for name, _ in cls.__get_mandatory_field_descriptors())

    def __attribute_keys(self: Self):
        # Keys of `data` that are free-form attributes, i.e. not mandatory fields.
        return tuple(k for k in self.data.keys() if k not in self._mandatory_fields())

    def __repr__(self: Self):
        mandatory_fields_repr = ", ".join(
            f"{k}={self.data[k]}" for k in self._mandatory_fields()
        )
        attributes = ", ".join(
            f"'{k}': '{self.data[k]}'" for k in self.__attribute_keys()
        )
        # Wrap the attribute items in braces so the output reads as a dict
        # (previously rendered as `attributes='k': 'v'`).
        return (
            f"{self.__class__.__name__}({mandatory_fields_repr},"
            f" attributes={{{attributes}}})"
        )

    @classmethod
    def __validate_arguments(cls, args: tuple[Any], kwargs: dict[str, Any]):
        # `new` accepts exactly the mandatory fields, optionally plus one
        # trailing `attributes` argument.
        num_total_args = len(args) + len(kwargs)
        if num_total_args not in (
            len(cls._mandatory_fields()),
            len(cls._mandatory_fields()) + 1,
        ):
            raise ValueError(
                f"allowed arguments are {cls._mandatory_fields()} and attributes, but"
                f" passed args: {args} and kwargs: {kwargs}"
            )
        argument_names_in_args = set(k for k, _ in zip(cls._mandatory_fields(), args))
        arguments_specified_twice = argument_names_in_args.intersection(kwargs.keys())
        if len(arguments_specified_twice) > 0:
            raise ValueError(f"arguments specified twice {arguments_specified_twice}")

    @classmethod
    def __extract_attributes_from_args(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        # `attributes` may be passed as a keyword (dict or zero-arg factory)
        # or as one extra positional argument after the mandatory fields.
        data = dict()
        if "attributes" in kwargs:
            if callable(kwargs["attributes"]):
                data.update(kwargs["attributes"]())
            else:
                data.update(kwargs["attributes"])
        elif len(args) + len(kwargs) == len(cls._mandatory_fields()) + 1:
            attributes = args[-1]
            data.update(attributes)
        return data

    @classmethod
    def __get_kwargs_without_attributes(
        cls, kwargs: dict[str, Any]
    ) -> dict[str, AttributeT]:
        kwargs = {k: v for k, v in kwargs.items() if k != "attributes"}
        return kwargs

    @classmethod
    def __transform_args_with_mandatory_fields(cls, args: dict[str, Any]) -> None:
        # Apply each mandatory field's set_transform in place before storage.
        set_transforms = cls.__get_field_transforms()
        for k in args:
            if k in set_transforms:
                args[k] = set_transforms[k](args[k])

    @classmethod
    def __get_field_transforms(cls) -> dict[str, Callable[[Any], AttributeT]]:
        return {k: v.set_transform for k, v in cls.__get_mandatory_field_descriptors()}

    @classmethod
    def __get_args_as_kwargs(cls, args: tuple[Any]) -> Iterable[tuple[str, Any]]:
        # Pair positional args with mandatory field names; zip stops at the
        # shorter sequence, so a trailing attributes dict is ignored here.
        return zip(cls._mandatory_fields(), args)
Loading

0 comments on commit e95aab7

Please sign in to comment.