diff --git a/elasticai/creator/ir/__init__.py b/elasticai/creator/ir/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/elasticai/creator/ir/graph_delegate.py b/elasticai/creator/ir/graph_delegate.py new file mode 100644 index 00000000..88a3b849 --- /dev/null +++ b/elasticai/creator/ir/graph_delegate.py @@ -0,0 +1,52 @@ +from collections.abc import Hashable, Iterable, Iterator +from typing import Generic, TypeVar + +HashableT = TypeVar("HashableT", bound=Hashable) + + +class GraphDelegate(Generic[HashableT]): + def __init__(self) -> None: + """We keep successor and predecessor nodes just to allow for easier implementation. + Currently, this implementation is not optimized for performance. + """ + self.successors: dict[HashableT, dict[HashableT, None]] = dict() + self.predecessors: dict[HashableT, dict[HashableT, None]] = dict() + + @staticmethod + def from_dict(d: dict[HashableT, Iterable[HashableT]]): + g = GraphDelegate() + for node, successors in d.items(): + for s in successors: + g.add_edge(node, s) + return g + + def as_dict(self) -> dict[HashableT, set[HashableT]]: + return self.successors.copy() + + def add_edge(self, _from: HashableT, _to: HashableT): + self.add_node(_from) + self.add_node(_to) + self.predecessors[_to][_from] = None + self.successors[_from][_to] = None + return self + + def add_node(self, node: HashableT): + if node not in self.predecessors: + self.predecessors[node] = dict() + if node not in self.successors: + self.successors[node] = dict() + return self + + def iter_nodes(self) -> Iterator[HashableT]: + yield from self.predecessors.keys() + + def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]: + for _from, _tos in self.successors.items(): + for _to in _tos: + yield _from, _to + + def get_successors(self, node: HashableT) -> Iterator[HashableT]: + yield from self.successors[node] + + def get_predecessors(self, node: HashableT) -> Iterator[HashableT]: + yield from self.predecessors[node] diff --git a/elasticai/creator/ir/graph_delegate_test.py b/elasticai/creator/ir/graph_delegate_test.py new file mode 100644 index 00000000..f7c6e822 --- /dev/null +++ b/elasticai/creator/ir/graph_delegate_test.py @@ -0,0 +1,66 @@ +from .graph_delegate import GraphDelegate +from .graph_iterators import bfs_iter_up, dfs_pre_order + + +def test_iterating_breadth_first_upwards(): + g = GraphDelegate() + """ + 0 + | + /-----\ + | | + 1 2 + | /---+ + |/ | + 3 4 + | | + | 6 + |/----+ + 5 + """ + g = GraphDelegate.from_dict( + { + "0": ["1", "2"], + "1": ["3"], + "2": ["3", "4"], + "3": ["5"], + "4": ["6"], + "6": ["5"], + } + ) + + actual = tuple(bfs_iter_up(g.get_predecessors, "5")) + assert actual == ("3", "6", "1", "2", "4", "0") + + +def test_iterating_depth_first_preorder(): + g = GraphDelegate() + """ + 0 + | + /-----\ + | | + 1 2 + | /---+ + |/ | + 3 4 + | | + | 6 + |/----+ + 5 + """ + g = GraphDelegate.from_dict( + { + "0": ["1", "2"], + "1": ["3"], + "2": ["3", "4"], + "3": ["5"], + "4": ["6"], + "6": ["5"], + } + ) + + actual = tuple(dfs_pre_order(g.get_successors, "0")) + print(tuple(g.get_successors("0"))) + expected = ("0", "1", "3", "5", "2", "4", "6") + assert actual == expected diff --git a/elasticai/creator/ir/graph_iterators.py b/elasticai/creator/ir/graph_iterators.py new file mode 100644 index 00000000..c405462a --- /dev/null +++ b/elasticai/creator/ir/graph_iterators.py @@ -0,0 +1,36 @@ +from collections.abc import Callable, Iterable, Iterator +from typing import Hashable, TypeAlias, TypeVar + +HashableT = TypeVar("HashableT", bound=Hashable) + +NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]] + + +def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: + visited: set[HashableT] = set() + + def visit(nodes: tuple[HashableT, ...]): + for n in nodes: + if n not in visited: + yield n + visited.add(n) + yield from visit(tuple(successors(n))) + + yield from visit((start,)) + + +def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: + visited: set[HashableT] = set() + visit_next = [start] + while len(visit_next) > 0: + current = visit_next.pop(0) + for p in successors(current): + if p not in visited: + yield p + visited.add(p) + visit_next.append(p) + visited.add(current) + + +def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: + return bfs_iter_down(predecessors, start)