Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ir): add graph delegate and iterators #389

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
52 changes: 52 additions & 0 deletions elasticai/creator/ir/graph_delegate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections.abc import Hashable, Iterable, Iterator
from typing import Generic, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)


class GraphDelegate(Generic[HashableT]):
def __init__(self) -> None:
"""We keep successor and predecessor nodes just to allow for easier implementation.
Currently, this implementation is not optimized for performance.
"""
self.successors: dict[HashableT, dict[HashableT, None]] = dict()
self.predecessors: dict[HashableT, dict[HashableT, None]] = dict()

@staticmethod
def from_dict(d: dict[HashableT, Iterable[HashableT]]):
g = GraphDelegate()
for node, successors in d.items():
for s in successors:
g.add_edge(node, s)
return g

def as_dict(self) -> dict[HashableT, set[HashableT]]:
return self.successors.copy()

def add_edge(self, _from: HashableT, _to: HashableT):
self.add_node(_from)
self.add_node(_to)
self.predecessors[_to][_from] = None
self.successors[_from][_to] = None
return self

def add_node(self, node: HashableT):
if node not in self.predecessors:
self.predecessors[node] = dict()
if node not in self.successors:
self.successors[node] = dict()
return self

def iter_nodes(self) -> Iterator[HashableT]:
yield from self.predecessors.keys()

def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]:
for _from, _tos in self.successors.items():
for _to in _tos:
yield _from, _to

def get_successors(self, node: HashableT) -> Iterator[HashableT]:
yield from self.successors[node]

def get_predecessors(self, node: HashableT) -> Iterator[HashableT]:
yield from self.predecessors[node]
66 changes: 66 additions & 0 deletions elasticai/creator/ir/graph_delegate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from .graph_delegate import GraphDelegate
from .graph_iterators import bfs_iter_up, dfs_pre_order


def test_iterating_breadth_first_upwards():
g = GraphDelegate()
"""
0
|
/-----\
| |
1 2
| /---+
|/ |
3 4
| |
| 6
|/----+
5
"""
g = GraphDelegate.from_dict(
{
"0": ["1", "2"],
"1": ["3"],
"2": ["3", "4"],
"3": ["5"],
"4": ["6"],
"6": ["5"],
}
)

actual = tuple(bfs_iter_up(g.get_predecessors, "5"))
assert actual == ("3", "6", "1", "2", "4", "0")


def test_iterating_depth_first_preorder():
g = GraphDelegate()
"""
0
|
/-----\
| |
1 2
| /---+
|/ |
3 4
| |
| 6
|/----+
5
"""
g = GraphDelegate.from_dict(
{
"0": ["1", "2"],
"1": ["3"],
"2": ["3", "4"],
"3": ["5"],
"4": ["6"],
"6": ["5"],
}
)

actual = tuple(dfs_pre_order(g.get_successors, "0"))
print(tuple(g.get_successors("0")))
expected = ("0", "1", "3", "5", "2", "4", "6")
assert actual == expected
36 changes: 36 additions & 0 deletions elasticai/creator/ir/graph_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import Callable, Iterable, Iterator
from typing import Hashable, TypeAlias, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)

NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]]


def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()

def visit(nodes: tuple[HashableT, ...]):
for n in nodes:
if n not in visited:
yield n
visited.add(n)
yield from visit(tuple(successors(n)))

yield from visit((start,))
mokouMonday marked this conversation as resolved.
Show resolved Hide resolved


def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()
visit_next = [start]
while len(visit_next) > 0:
current = visit_next.pop(0)
for p in successors(current):
if p not in visited:
yield p
visited.add(p)
visit_next.append(p)
visited.add(current)


def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
return bfs_iter_down(predecessors, start)
Loading