From d3bbcd47b9e3c765ba85319715bc8a02e6c99d68 Mon Sep 17 00:00:00 2001
From: Lukas Einhaus
Date: Tue, 5 Nov 2024 18:49:31 +0100
Subject: [PATCH] feat(ir): add draft for intermediate representation data structures

The data structures are quite minimalistic and include:
- Graph
- Node
- Edge

The supported features are:
- filtering nodes by type
- DFS preorder graph traversal (starting at specified node)
- BFS graph traversal (starting at specified node)
- Graph/Node/Edge serialization and deserialization
- accessing successors/predecessors in graph
---
Note: this patch was edited by hand after export (bug fixes, docstrings and
additional tests); the blob ids in the `index` lines were not regenerated.

 elasticai/creator/ir/__init__.py              |   4 +
 elasticai/creator/ir/attribute.py             |   6 ++
 elasticai/creator/ir/edge.py                  |  30 ++++++
 elasticai/creator/ir/graph.py                 | 100 ++++++++++++++++++
 elasticai/creator/ir/graph_iterators.py       |  47 ++++++++
 elasticai/creator/ir/graph_test.py            |  65 ++++++++++++
 elasticai/creator/ir/humble_base_graph.py     |  67 ++++++++++++
 .../creator/ir/humble_base_graph_test.py      |  80 ++++++++++++++
 elasticai/creator/ir/node.py                  |  29 +++++
 elasticai/creator/ir/node_test.py             |  19 ++++
 10 files changed, 447 insertions(+)
 create mode 100644 elasticai/creator/ir/__init__.py
 create mode 100644 elasticai/creator/ir/attribute.py
 create mode 100644 elasticai/creator/ir/edge.py
 create mode 100644 elasticai/creator/ir/graph.py
 create mode 100644 elasticai/creator/ir/graph_iterators.py
 create mode 100644 elasticai/creator/ir/graph_test.py
 create mode 100644 elasticai/creator/ir/humble_base_graph.py
 create mode 100644 elasticai/creator/ir/humble_base_graph_test.py
 create mode 100644 elasticai/creator/ir/node.py
 create mode 100644 elasticai/creator/ir/node_test.py

diff --git a/elasticai/creator/ir/__init__.py b/elasticai/creator/ir/__init__.py
new file mode 100644
index 00000000..a287528b
--- /dev/null
+++ b/elasticai/creator/ir/__init__.py
@@ -0,0 +1,4 @@
+from .attribute import AttributeT, SizeT
+from .edge import Edge
+from .graph import Graph
+from .node import Node
diff --git a/elasticai/creator/ir/attribute.py
b/elasticai/creator/ir/attribute.py
new file mode 100644
index 00000000..2441d7e3
--- /dev/null
+++ b/elasticai/creator/ir/attribute.py
@@ -0,0 +1,6 @@
+from typing import TypeAlias
+
+SizeT: TypeAlias = tuple[int] | tuple[int, int] | tuple[int, int, int]
+AttributeT: TypeAlias = (
+    int | float | str | tuple["AttributeT", ...] | dict[str, "AttributeT"]
+)
diff --git a/elasticai/creator/ir/edge.py b/elasticai/creator/ir/edge.py
new file mode 100644
index 00000000..fd14866d
--- /dev/null
+++ b/elasticai/creator/ir/edge.py
@@ -0,0 +1,30 @@
+import dataclasses
+from dataclasses import dataclass
+
+from .attribute import AttributeT
+
+
+@dataclass(eq=True, frozen=True)
+class Edge:
+    """Directed connection from node `src` to node `sink` with free-form attributes."""
+
+    src: str
+    sink: str
+    attributes: dict[str, AttributeT] = dataclasses.field(default_factory=dict)
+
+    @classmethod
+    def _filter_attributes(cls, d: dict[str, AttributeT]) -> dict[str, AttributeT]:
+        """Return all entries of `d` except the reserved keys `src` and `sink`."""
+        # iterate over items(): iterating the dict itself only yields its keys
+        return {k: v for k, v in d.items() if k not in ("src", "sink")}
+
+    def as_dict(self) -> dict[str, AttributeT]:
+        """Return `src`, `sink` and all attributes merged into one flat dict."""
+        return dict(src=self.src, sink=self.sink) | self.attributes
+
+    @classmethod
+    def from_dict(cls, data: dict[str, AttributeT]) -> "Edge":
+        """Counterpart of `as_dict`; raises `KeyError` if `src` or `sink` is missing."""
+        return cls(
+            src=data["src"], sink=data["sink"], attributes=cls._filter_attributes(data)
+        )
diff --git a/elasticai/creator/ir/graph.py b/elasticai/creator/ir/graph.py
new file mode 100644
index 00000000..98b68852
--- /dev/null
+++ b/elasticai/creator/ir/graph.py
@@ -0,0 +1,100 @@
+from collections.abc import Iterator
+
+from .attribute import AttributeT
+from .edge import Edge
+from .humble_base_graph import HumbleBaseGraph
+from .node import Node
+
+
+class Graph:
+    """Graph of named `Node`s that are connected by `Edge`s.
+
+    A new graph already contains the nodes `input` and `output`. The topology
+    is tracked by node name in a `HumbleBaseGraph`; the `Node` and `Edge`
+    objects are stored in dicts keyed by name and by `(src, sink)`.
+    """
+
+    def __init__(self):
+        self._hbg: HumbleBaseGraph[str] = HumbleBaseGraph()
+        self._nodes: dict[str, Node] = {}
+        self._edges: dict[tuple[str, str], Edge] = dict()
+        self.add_node(Node("input", "input"))
+        self.add_node(Node("output", "output"))
+
+    def get_node(self, name: str) -> Node:
+        """Return the node called `name`; raises `KeyError` for unknown names."""
+        return self._nodes[name]
+
+    def has_node(self, name: str) -> bool:
+        return name in self._nodes
+
+    def add_node(self, n: Node) -> "Graph":
+        """Add `n`, replacing a node of the same name; returns `self` for chaining."""
+        self._nodes[n.name] = n
+        self._hbg.add_node(n.name)
+        return self
+
+    def add_edge(self, e: Edge) -> "Graph":
+        """Add `e`, replacing an edge with the same end points; returns `self`.
+
+        Only the names of the end points are registered. Nodes that were not
+        added via `add_node` make `iter_nodes` fail with a `KeyError`.
+        """
+        self._hbg.add_edge(e.src, e.sink)
+        self._edges[(e.src, e.sink)] = e
+        return self
+
+    def iter_edges(self) -> Iterator[Edge]:
+        yield from self._edges.values()
+
+    def iter_src_sink_pairs(self) -> Iterator[tuple[Node, Node]]:
+        for src, sink in self._edges:
+            yield self.get_node(src), self.get_node(sink)
+
+    def iter_src_sink_name_pairs(self) -> Iterator[tuple[str, str]]:
+        yield from self._edges
+
+    def iter_nodes(self) -> Iterator[Node]:
+        for n in self._hbg.iter_nodes():
+            yield self._nodes[n]
+
+    def iter_node_names(self) -> Iterator[str]:
+        yield from self._hbg.iter_nodes()
+
+    def get_nodes_by_type(self, type: str) -> Iterator[Node]:
+        for n in self.iter_nodes():
+            if n.type == type:
+                yield n
+
+    def get_successors(self, name: str) -> Iterator[Node]:
+        # dedicated loop variable: do not rebind the `name` parameter
+        for successor in self._hbg.get_successors(name):
+            yield self._nodes[successor]
+
+    def get_successor_names(self, name: str) -> Iterator[str]:
+        yield from self._hbg.get_successors(name)
+
+    def get_predecessors(self, name: str) -> Iterator[Node]:
+        for predecessor in self._hbg.get_predecessors(name):
+            yield self._nodes[predecessor]
+
+    def as_dict(self) -> dict[str, AttributeT]:
+        """Serialize to a dict holding the `nodes` and `edges` as lists of dicts."""
+
+        def make_dict(x):
+            return x.as_dict()
+
+        return dict(
+            nodes=list(map(make_dict, self.iter_nodes())),
+            edges=list(map(make_dict, self.iter_edges())),
+        )
+
+    @classmethod
+    def from_dict(cls, data: dict[str, AttributeT]) -> "Graph":
+        """Build a graph from the structure that `as_dict` produces."""
+        g = cls()
+        for n in data["nodes"]:
+            g.add_node(Node.from_dict(n))
+        for e in data["edges"]:
+            g.add_edge(Edge.from_dict(e))
+        return g
diff --git a/elasticai/creator/ir/graph_iterators.py b/elasticai/creator/ir/graph_iterators.py
new file mode 100644
index 00000000..cc25e045
--- /dev/null
+++ b/elasticai/creator/ir/graph_iterators.py
@@ -0,0 +1,47 @@
+from collections import deque
+from collections.abc import Callable, Iterable, Iterator
+from typing import Hashable, Protocol, TypeAlias, TypeVar
+
+HashableT = TypeVar("HashableT", bound=Hashable)
+
+NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]]
+
+
+def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
+    """Yield `start` and every node reachable from it in depth-first pre-order.
+
+    `successors` maps a node to its direct neighbours; each node is yielded once.
+    """
+    visited: set[HashableT] = set()
+
+    def visit(nodes: tuple[HashableT, ...]):
+        for n in nodes:
+            if n not in visited:
+                yield n
+                visited.add(n)
+                yield from visit(tuple(successors(n)))
+
+    yield from visit((start,))
+
+
+def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
+    """Yield every node reachable from `start` in breadth-first order.
+
+    `start` itself is not yielded and each reachable node is yielded exactly once.
+    """
+    # Mark nodes as visited when they are discovered instead of when they are
+    # expanded, otherwise a node reachable via several paths is yielded repeatedly.
+    visited: set[HashableT] = {start}
+    visit_next = deque([start])
+    while visit_next:
+        current = visit_next.popleft()
+        for p in successors(current):
+            if p not in visited:
+                visited.add(p)
+                yield p
+                visit_next.append(p)
+
+
+def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
+    """Breadth-first traversal against the edge direction, see `bfs_iter_down`."""
+    return bfs_iter_down(predecessors, start)
diff --git a/elasticai/creator/ir/graph_test.py b/elasticai/creator/ir/graph_test.py
new file mode 100644
index 00000000..016a892e
--- /dev/null
+++ b/elasticai/creator/ir/graph_test.py
@@ -0,0 +1,65 @@
+from .edge import Edge
+from .graph import Graph
+from .node import Node
+
+
+def test_has_input_node():
+    g = Graph()
+    assert g.get_node("input") == Node(name="input", type="input")
+
+
+def test_can_test_for_node_existence():
+    g = Graph()
+    g.add_node(Node(name="a", type="b"))
+    f = Graph()
+    assert g.has_node("a") and not f.has_node("a")
+
+
+def test_can_get_successors():
+    g = Graph()
+    g.add_node(Node(name="a", type="b")).add_edge(Edge(src="input", sink="a"))
+    s = tuple(g.get_successors("input"))[0]
+    assert Node(name="a", type="b") == s
+
+
+def test_can_get_predecessors():
+    g = Graph()
+    g.add_node(Node(name="a", type="b")).add_edge(Edge(src="input", sink="a"))
+    p = tuple(g.get_predecessors("a"))[0]
+    assert Node(name="input", type="input") == p
+
+
+def test_can_serialize_graph():
+    g = Graph()
+    g.add_edge(Edge(src="input", sink="a", attributes=dict(indices=(2, 3)))).add_node(
+        Node(name="input", type="input", attributes=dict(input_shape=(2, 4, 8)))
+    ).add_node(Node("a", "b"))
+    g.get_node("input").attributes["output_shape"] = (1, 1, 1)
+
+    serialized = g.as_dict()
+    expected = {
+        "nodes": [
+            {
+                "name": "input",
+                "type": "input",
+                "input_shape": (2, 4, 8),
+                "output_shape": (1, 1, 1),
+            },
+            {"name": "output", "type": "output"},
+            {"name": "a", "type": "b"},
+        ],
+        "edges": [{"src": "input", "sink": "a", "indices": (2, 3)}],
+    }
+    assert serialized == expected
+
+
+def test_can_deserialize_graph():
+    data = {
+        "nodes": [{"name": "a", "type": "b"}],
+        "edges": [{"src": "input", "sink": "a", "indices": (2, 3)}],
+    }
+    g = Graph.from_dict(data)
+    assert g.get_node("a") == Node(name="a", type="b")
+    assert list(g.iter_edges()) == [
+        Edge(src="input", sink="a", attributes=dict(indices=(2, 3)))
+    ]
diff --git a/elasticai/creator/ir/humble_base_graph.py b/elasticai/creator/ir/humble_base_graph.py
new file mode 100644
index 00000000..63751bbe
--- /dev/null
+++ b/elasticai/creator/ir/humble_base_graph.py
@@ -0,0 +1,67 @@
+from collections.abc import Hashable, Iterable, Iterator
+from typing import Generic, TypeVar
+
+HashableT = TypeVar("HashableT", bound=Hashable)
+
+
+class HumbleBaseGraph(Generic[HashableT]):
+    """Minimal directed graph over hashable nodes.
+
+    Neighbours are stored in dicts (used as ordered sets), so nodes and edges
+    are iterated in insertion order.
+    """
+
+    def __init__(self) -> None:
+        """We keep successor and predecessor nodes just to allow for easier implementation.
+        Currently, this implementation is not optimized for performance.
+        """
+        self.successors: dict[HashableT, dict[HashableT, None]] = dict()
+        self.predecessors: dict[HashableT, dict[HashableT, None]] = dict()
+
+    @staticmethod
+    def from_dict(d: dict[HashableT, Iterable[HashableT]]):
+        """Build a graph from a mapping of each node to its successors."""
+        g = HumbleBaseGraph()
+        for node, successors in d.items():
+            # register the node first so that nodes without successors are kept
+            g.add_node(node)
+            for s in successors:
+                g.add_edge(node, s)
+        return g
+
+    def as_dict(self) -> dict[HashableT, set[HashableT]]:
+        """Return a mapping of each node to a fresh set of its successors.
+
+        The sets are copies, so changing the result does not alter the graph.
+        """
+        return {node: set(tos) for node, tos in self.successors.items()}
+
+    def add_edge(self, _from: HashableT, _to: HashableT):
+        """Add the edge `_from -> _to`, creating missing end points; returns `self`."""
+        self.add_node(_from)
+        self.add_node(_to)
+        self.predecessors[_to][_from] = None
+        self.successors[_from][_to] = None
+        return self
+
+    def add_node(self, node: HashableT):
+        """Register `node`; existing nodes keep their edges. Returns `self`."""
+        if node not in self.predecessors:
+            self.predecessors[node] = dict()
+        if node not in self.successors:
+            self.successors[node] = dict()
+        return self
+
+    def iter_nodes(self) -> Iterator[HashableT]:
+        yield from self.predecessors.keys()
+
+    def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]:
+        for _from, _tos in self.successors.items():
+            for _to in _tos:
+                yield _from, _to
+
+    def get_successors(self, node: HashableT) -> Iterator[HashableT]:
+        yield from self.successors[node]
+
+    def get_predecessors(self, node: HashableT) -> Iterator[HashableT]:
+        yield from self.predecessors[node]
diff --git a/elasticai/creator/ir/humble_base_graph_test.py b/elasticai/creator/ir/humble_base_graph_test.py
new file mode 100644
index 00000000..7e0a35de
--- /dev/null
+++ b/elasticai/creator/ir/humble_base_graph_test.py
@@ -0,0 +1,80 @@
+from .graph_iterators import bfs_iter_up, dfs_pre_order
+from .humble_base_graph import HumbleBaseGraph
+
+
+def test_iterating_breadth_first_upwards():
+    g = HumbleBaseGraph()
+    """
+       0
+       |
+    /-----\
+    |     |
+    1     2
+    | /---+
+    |/    |
+    3     4
+    |     |
+    |     6
+    |/----+
+    5
+    """
+    g = HumbleBaseGraph.from_dict(
+        {
+            "0": ["1", "2"],
+            "1": ["3"],
+            "2": ["3", "4"],
+            "3": ["5"],
+            "4": ["6"],
+            "6": ["5"],
+        }
+    )
+
+    actual = tuple(bfs_iter_up(g.get_predecessors, "5"))
+    assert set(actual[0:2]) == {"3", "6"}
+    assert (set(actual[2:4]) == {"1", "2"} and actual[4] == "4") or (
+        set(actual[3:5]) == {"1", "2"} and actual[2] == "4"
+    )
+
+
+def test_iterating_depth_first_preorder():
+    g = HumbleBaseGraph()
+    """
+       0
+       |
+    /-----\
+    |     |
+    1     2
+    | /---+
+    |/    |
+    3     4
+    |     |
+    |     6
+    |/----+
+    5
+    """
+    g = HumbleBaseGraph.from_dict(
+        {
+            "0": ["1", "2"],
+            "1": ["3"],
+            "2": ["3", "4"],
+            "3": ["5"],
+            "4": ["6"],
+            "6": ["5"],
+        }
+    )
+
+    actual = tuple(dfs_pre_order(g.get_successors, "0"))
+    print(tuple(g.get_successors("0")))
+    expected = ("0", "1", "3", "5", "2", "4", "6")
+    assert actual == expected
+
+
+def test_breadth_first_iteration_yields_each_node_only_once():
+    g = HumbleBaseGraph.from_dict({"0": ["1", "2"], "1": ["3"], "2": ["3"]})
+    actual = tuple(bfs_iter_up(g.get_predecessors, "3"))
+    assert actual == ("1", "2", "0")
+
+
+def test_from_dict_keeps_nodes_without_successors():
+    g = HumbleBaseGraph.from_dict({"0": []})
+    assert tuple(g.iter_nodes()) == ("0",)
diff --git a/elasticai/creator/ir/node.py b/elasticai/creator/ir/node.py
new file mode 100644
index 00000000..354ade7c
--- /dev/null
+++ b/elasticai/creator/ir/node.py
@@ -0,0 +1,29 @@
+import dataclasses
+from dataclasses import dataclass
+
+from .attribute import AttributeT
+
+
+@dataclass
+class Node:
+    """Named and typed graph node with free-form attributes."""
+
+    name: str
+    type: str
+    attributes: dict[str, AttributeT] = dataclasses.field(default_factory=dict)
+
+    def as_dict(self) -> dict[str, AttributeT]:
+        """Return `name`, `type` and all attributes merged into one flat dict."""
+        return dict(name=self.name, type=self.type) | self.attributes
+
+    @classmethod
+    def from_dict(cls, data: dict[str, AttributeT]) -> "Node":
+        """Counterpart of `as_dict`; raises `KeyError` if `name` or `type` is missing."""
+        # instantiate `cls` instead of `Node` so that subclasses get their own type
+        return cls(
+            name=data["name"],
+            type=data["type"],
+            attributes={
+                k: v for k, v in data.items() if k not in ("name", "type")
+            },
+        )
diff --git a/elasticai/creator/ir/node_test.py b/elasticai/creator/ir/node_test.py
new file mode 100644
index 00000000..3c00405d
--- /dev/null
+++ b/elasticai/creator/ir/node_test.py
@@ -0,0 +1,19 @@
+from elasticai.creator.ir.node import Node as OtherNode
+
+from .node import Node
+
+
+def test_can_serialize_node_with_attributes():
+    n = Node(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
+    assert dict(name="my_node", type="my_type", a="b", c=(1, 2)) == n.as_dict()
+
+
+def test_can_deserialize_node_with_attributes():
+    n = Node(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
+    assert n ==
Node.from_dict(dict(name="my_node", type="my_type", a="b", c=(1, 2)))
+
+
+def test_import_path_does_not_matter_for_equality():
+    n = Node(name="a", type="a_type", attributes=dict(a="b"))
+    other = OtherNode(name="a", type="a_type", attributes=dict(a="b"))
+    assert n == other