Skip to content

Commit

Permalink
feat(ir): add abstract ir data class and nodes
Browse files Browse the repository at this point in the history
The purpose of this class is to provide a way to
easily write new wrappers around dictionaries,
that let us customize access, while still
supporting static type annotations.
  • Loading branch information
glencoe committed Nov 26, 2024
1 parent 9102843 commit bb81f0d
Show file tree
Hide file tree
Showing 11 changed files with 531 additions and 0 deletions.
6 changes: 6 additions & 0 deletions elasticai/creator/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .abstract_ir_data import AbstractIRData, MandatoryField

# Public names re-exported by this package's `__init__`.
__all__ = [
"MandatoryField",
"AbstractIRData",
]
7 changes: 7 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .abstract_ir_data import AbstractIRData
from .mandatory_field import MandatoryField

# Public names re-exported from the `abstract_ir_data` and `mandatory_field` submodules.
__all__ = [
"AbstractIRData",
"MandatoryField",
]
16 changes: 16 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_attributes_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from collections.abc import MutableMapping

from elasticai.creator.ir.attribute import AttributeT

from ._has_data import HasData
from ._hiding_dict import _HidingDict


class _AttributesDescriptor:
    """Descriptor that exposes an object's `data` dict through a `_HidingDict`.

    Reading the descriptor on an instance builds a new `_HidingDict` around
    `instance.data`, configured with the `hidden_names` given at construction.
    The underlying dictionary is passed on as-is (not copied), so writes through
    the returned mapping end up in `instance.data`.
    """

    def __init__(self, hidden_names: set[str]):
        """Store the set of key names handed to every `_HidingDict` created."""
        self._hidden_names = hidden_names

    def __get__(
        self, instance: HasData, owner: type[HasData]
    ) -> MutableMapping[str, AttributeT]:
        """Return the hiding view for `instance`.

        When the descriptor is looked up on the class itself, Python passes
        `instance=None`; there is no `data` to wrap in that case, so the
        descriptor object is returned instead of failing on `None.data`.
        """
        if instance is None:
            # Class-level access (e.g. introspection); the annotated return
            # type describes the instance-level case only.
            return self  # type: ignore[return-value]
        return _HidingDict(self._hidden_names, instance.data)
9 changes: 9 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_has_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Protocol, runtime_checkable

from elasticai.creator.ir.attribute import AttributeT


@runtime_checkable
class HasData(Protocol):
    """Structural type for objects that expose a `data` dictionary."""

    @property
    def data(self) -> dict[str, AttributeT]:
        """The dictionary with string keys held by the object."""
        ...
83 changes: 83 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/_hiding_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import collections
from collections.abc import Iterable
from itertools import filterfalse
from typing import MutableMapping, TypeVar

T = TypeVar("T")


class _HidingDict(MutableMapping[str, T]):
    """Mutable mapping view over a dict that hides the keys in `hidden_names`.

    Hidden keys are skipped by iteration (and therefore by `keys()`,
    `items()`, `values()`), are not counted by `len()`, are reported as absent
    by `in`, and yield the default from `get()`.
    We use this to implement an attributes field for Nodes that looks like a
    dictionary, but hides all mandatory fields.

    Writes are not filtered and go straight to the wrapped dictionary, e.g.,

    >>> d = dict(a="a", b="b")
    >>> h = _HidingDict({"a"}, d)
    >>> ("b",) == tuple(h.keys())
    True
    >>> len(h)
    1
    >>> "a" in h
    False
    >>> h["a"] = "c"
    >>> h.data["a"]
    'c'
    >>> d["a"]
    'c'
    >>> "a" in d and "a" in h.data
    True
    """

    def __init__(self, hidden_names: Iterable[str], data: dict) -> None:
        """Wrap `data` (kept by reference, not copied) and hide `hidden_names`."""
        self.data = data
        self._hidden_names = set(hidden_names)

    def __setitem__(self, key: str, value: T):
        self.data[key] = value

    def _is_hidden(self, name: str) -> bool:
        return name in self._hidden_names

    def __delitem__(self, key: str):
        del self.data[key]

    def __iter__(self):
        return filterfalse(self._is_hidden, iter(self.data))

    def __contains__(self, item):
        return item not in self._hidden_names and item in self.data

    def __getitem__(self, item: str) -> T:
        # NOTE(review): direct indexing is not filtered, so `h[hidden_key]`
        # still returns the stored value; only `in`, `get`, iteration and
        # `len` hide keys. Confirm whether callers rely on this.
        return self.data[item]

    def __len__(self):
        # Count only what iteration yields, so that `len(h) == len(list(h))`
        # holds and the views returned by keys()/items()/values() report a
        # length consistent with their contents.
        return sum(1 for _ in self)

    def get(self, key: str, default=None) -> T:
        """Return the value for `key`, or `default` if it is missing or hidden."""
        if key in self:
            return self[key]
        return default

    def __copy__(self) -> MutableMapping[str, T]:
        """Return a shallow copy with its own dictionary and hidden-name set.

        The instance is created without calling `__init__`, and all instance
        attributes are carried over before `data` and `_hidden_names` are
        replaced by copies.
        """
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)
        # Create a copy and avoid triggering descriptors
        inst.__dict__["data"] = self.__dict__["data"].copy()
        inst.__dict__["_hidden_names"] = set(self.__dict__["_hidden_names"])
        return inst

    def copy(self) -> MutableMapping[str, T]:
        """Return a shallow copy; equivalent to `copy.copy(self)`.

        The copy keeps the entries stored under hidden keys (they remain
        hidden), so both ways of copying produce the same result.
        """
        return self.__copy__()

    def __repr__(self) -> str:
        return f"HidingDict({', '.join(self._hidden_names)}, data={self.data})"
171 changes: 171 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/abstract_ir_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import inspect
import sys
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Callable, TypeVar

from elasticai.creator.ir.attribute import AttributeT

from ._attributes_descriptor import _AttributesDescriptor
from .mandatory_field import MandatoryField, TransformableMandatoryField

# `typing.Self` is only available from Python 3.11 on; older interpreters fall
# back to a TypeVar bound to the class defined below. The tuple comparison is
# the form static type checkers recognise for version checks.
if sys.version_info >= (3, 11):
    from typing import Self
else:
    Self = TypeVar("Self", bound="AbstractIRData")


class AbstractIRData:
    """Base class for thin, typed wrappers around a plain dictionary.

    It is supposed to be used together with the `MandatoryField` descriptors:
    every class attribute that is a `TransformableMandatoryField` names a key
    that must be present in the wrapped dictionary.

    Every child of `AbstractIRData` is expected to have a constructor that
    takes a dictionary. That dictionary is not copied, but instead can be
    shared with other wrapper objects.

    Most of the private functions in this class deal with handling the
    arguments of the classmethod `new`, which is used to create new objects
    from scratch (see `_do_new`).

    The `attributes` attribute provides a dict-like view on `data`.
    NOTE(review): the descriptor below is created with an empty set of hidden
    names, so on this base class the view hides nothing; presumably subclasses
    provide a descriptor that hides their mandatory fields - confirm.

    The purpose of this class is to provide a way to easily write new wrappers
    around dictionaries that let us customize access, while still allowing
    static type annotations.
    """

    __slots__ = ("data",)
    attributes = _AttributesDescriptor(set())

    def __init__(self: Self, data: dict[str, AttributeT]):
        """Wrap `data` after checking that all mandatory fields are present.

        IMPORTANT: Do not override this. If you want to create a function that
        creates new nodes of your subtype, override the `new` method instead.

        Raises:
            ValueError: if a mandatory field name is not a key of `data`.
        """
        for k in self._mandatory_fields():
            if k not in data:
                raise ValueError(f"Missing mandatory field {k}")
        self.data = data

    @classmethod
    def _do_new(cls, *args, **kwargs) -> Self:
        """Build a data dictionary from `args`/`kwargs` and wrap it.

        This is here for your convenience to be called in `new`.

        Positional arguments are matched to the mandatory fields in
        declaration order; one extra trailing positional argument, or the
        keyword `attributes`, supplies additional entries (a mapping, or a
        callable returning one).

        Raises:
            ValueError: if the number of arguments does not fit, a keyword is
                neither a mandatory field nor `attributes`, an argument is
                given both positionally and by keyword, or a mandatory field
                is missing.
        """
        cls.__validate_arguments(args, kwargs)
        data = cls.__turn_arguments_into_data_dict(args, kwargs)
        return cls(data)

    @classmethod
    @abstractmethod
    def new(cls, *args, **kwargs) -> Self:
        """Create a new node by creating a new dictionary from args and kwargs.

        Use this for creation of new nodes from inline code. This is typically
        also where you want to provide type hints for users via `@overload`.
        You can delegate to the `_do_new()` method.
        """
        ...

    def as_dict(self: Self) -> dict[str, AttributeT]:
        """Return the wrapped dictionary itself (not a copy)."""
        return self.data

    @classmethod
    def from_dict(cls: type[Self], data: dict[str, AttributeT]) -> Self:
        """Wrap an existing dictionary; equivalent to calling `cls(data)`."""
        return cls(data)

    def __eq__(self: Self, other: object) -> bool:
        """Compare by wrapped dictionary with anything that has a dict `data`."""
        if hasattr(other, "data") and isinstance(other.data, dict):
            return self.data == other.data
        else:
            return False

    @classmethod
    def __turn_arguments_into_data_dict(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        # Later updates win: attributes first, then keyword arguments, then
        # positional arguments; finally the set-transforms are applied.
        data = cls.__extract_attributes_from_args(args, kwargs)
        data.update(cls.__get_kwargs_without_attributes(kwargs))
        data.update(cls.__get_args_as_kwargs(args))
        cls.__transform_args_with_mandatory_fields(data)
        return data

    @classmethod
    def __get_mandatory_field_descriptors(
        cls,
    ) -> Iterable[tuple[str, TransformableMandatoryField]]:
        # Walk the MRO from the most basic class to `cls` so base-class fields
        # come first. A name redeclared in a subclass keeps its original
        # position but uses the most derived descriptor; collecting into a
        # dict ensures each name is reported once.
        found: dict[str, TransformableMandatoryField] = {}
        for c in reversed(inspect.getmro(cls)):
            for a, value in c.__dict__.items():
                if (
                    not a.startswith("__")
                    and not a.endswith("__")
                    and isinstance(value, TransformableMandatoryField)
                ):
                    found[a] = value
        yield from found.items()

    @classmethod
    def _mandatory_fields(cls) -> tuple[str, ...]:
        """Return the names of all mandatory fields of `cls` and its bases."""
        return tuple(name for name, _ in cls.__get_mandatory_field_descriptors())

    def __attribute_keys(self: Self):
        mandatory = self._mandatory_fields()
        return tuple(k for k in self.data.keys() if k not in mandatory)

    def __repr__(self: Self):
        mandatory_fields_repr = ", ".join(
            f"{k}={self.data[k]}" for k in self._mandatory_fields()
        )
        attributes = ", ".join(
            f"'{k}': '{self.data[k]}'" for k in self.__attribute_keys()
        )
        return (
            f"{self.__class__.__name__}({mandatory_fields_repr},"
            f" attributes={attributes})"
        )

    @classmethod
    def __validate_arguments(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        mandatory = cls._mandatory_fields()
        num_total_args = len(args) + len(kwargs)
        # Keywords other than the mandatory fields and `attributes` would
        # otherwise slip through the count check and fail later with an
        # unrelated IndexError/TypeError.
        unknown_kwargs = set(kwargs) - set(mandatory) - {"attributes"}
        if unknown_kwargs or num_total_args not in (
            len(mandatory),
            len(mandatory) + 1,
        ):
            raise ValueError(
                f"allowed arguments are {mandatory} and attributes, but"
                f" passed args: {args} and kwargs: {kwargs}"
            )
        argument_names_in_args = set(k for k, _ in zip(mandatory, args))
        arguments_specified_twice = argument_names_in_args.intersection(kwargs.keys())
        if len(arguments_specified_twice) > 0:
            raise ValueError(f"arguments specified twice {arguments_specified_twice}")

    @classmethod
    def __extract_attributes_from_args(
        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        data: dict[str, Any] = dict()
        if "attributes" in kwargs:
            attributes = kwargs["attributes"]
        elif len(args) > len(cls._mandatory_fields()):
            # A positional argument beyond the mandatory fields is the
            # attributes mapping.
            attributes = args[-1]
        else:
            return data
        if callable(attributes):
            attributes = attributes()
        data.update(attributes)
        return data

    @classmethod
    def __get_kwargs_without_attributes(
        cls, kwargs: dict[str, Any]
    ) -> dict[str, AttributeT]:
        kwargs = {k: v for k, v in kwargs.items() if k != "attributes"}
        return kwargs

    @classmethod
    def __transform_args_with_mandatory_fields(cls, args: dict[str, Any]) -> None:
        # Only values are replaced, so iterating the dict meanwhile is safe.
        set_transforms = cls.__get_field_transforms()
        for k in args:
            if k in set_transforms:
                args[k] = set_transforms[k](args[k])

    @classmethod
    def __get_field_transforms(cls) -> dict[str, Callable[[Any], AttributeT]]:
        return {k: v.set_transform for k, v in cls.__get_mandatory_field_descriptors()}

    @classmethod
    def __get_args_as_kwargs(
        cls, args: tuple[Any, ...]
    ) -> Iterable[tuple[str, Any]]:
        # zip stops at the shorter input, so a trailing positional attributes
        # argument is not paired with a field name.
        return zip(cls._mandatory_fields(), args)
72 changes: 72 additions & 0 deletions elasticai/creator/ir/abstract_ir_data/mandatory_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Callable
from typing import cast

from typing_extensions import Generic, Protocol, TypeVar

from elasticai.creator.ir.attribute import AttributeT

from ._has_data import HasData

# Type of the value as it is kept in the `data` dictionary; restricted to `AttributeT`.
T = TypeVar("T", bound=AttributeT) # stored data type

# Type seen by code reading/writing the field; falls back to `T` when not given.
F = TypeVar("F", default=T) # visible data type


class AbstractIR(Protocol):
    """Structural type for objects carrying a `data` dictionary attribute."""

    data: dict[str, AttributeT]


class TransformableMandatoryField(Generic[T, F]):
    """
    A __descriptor__ that designates a mandatory field of an abstract ir data class.

    The descriptor accesses the `data` dictionary of the owning abstract ir data object
    to read and write values. You can use the `set_transform` and `get_transform` functions
    to transform values during read/write accesses. `T` designates the type stored in the
    `data` dictionary, while `F` is the type that the mandatory field receives.

    That allows to keep a dictionary of primitive (serializable) data types in memory,
    while still providing abstract ways to manipulate that data in complex ways.
    This is typically required when working with Nodes and Graphs to create new
    intermediate representations and transform one graph into another.

    E.g.

    ```python
    class A(AbstractIRData):
        number: TransformableMandatoryField[str, int] = TransformableMandatoryField(set_transform=str, get_transform=int)

    a = A({'number': "12"})
    a.number = a.number + 3
    print(a.data) # {'number': "15"}
    ```
    """

    def __init__(
        self,
        set_transform: Callable[[F], T],
        get_transform: Callable[[T], F],
    ):
        """Store the conversions applied on write (`set_transform`) and read (`get_transform`)."""
        self.set_transform = set_transform
        self.get_transform = get_transform

    def __set_name__(self, owner, name: str) -> None:
        """
        Remember the attribute name; it is used as the key into `data`.

        IMPORTANT: do not remove owner even though it's not used
        see https://docs.python.org/3/reference/datamodel.html#descriptors for more information
        """
        self.name = name

    def __get__(self, instance: HasData, owner) -> F:
        """
        Return `get_transform` applied to the value stored under this field's name.

        When accessed on the class instead of an instance, Python passes
        `instance=None`; the descriptor itself is returned then, instead of
        failing on `None.data`.

        IMPORTANT: do not remove owner even though it's not used
        see https://docs.python.org/3/reference/datamodel.html#descriptors for more information
        """
        if instance is None:
            # Class-level access; the annotated return type describes the
            # instance-level case only.
            return self  # type: ignore[return-value]
        return self.get_transform(cast(T, instance.data[self.name]))

    def __set__(self, instance: HasData, value: F) -> None:
        """Store `set_transform(value)` under this field's name in `instance.data`."""
        instance.data[self.name] = self.set_transform(value)


class MandatoryField(TransformableMandatoryField[T, T]):
    """Mandatory field whose stored and visible values are the same object."""

    def __init__(self):
        # Both directions use an identity function: values are written to and
        # read from the data dictionary unchanged.
        super().__init__(
            set_transform=lambda value: value,
            get_transform=lambda value: value,
        )
6 changes: 6 additions & 0 deletions elasticai/creator/ir/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import TypeAlias

# A tuple of one, two, or three ints.
SizeT: TypeAlias = tuple[int] | tuple[int, int] | tuple[int, int, int]
# Recursive alias: a number, a string, or a tuple / str-keyed dict of further
# `AttributeT` values (the self-references are written as forward references).
AttributeT: TypeAlias = (
int | float | str | tuple["AttributeT", ...] | dict[str, "AttributeT"]
)
Loading

0 comments on commit bb81f0d

Please sign in to comment.