Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intermediate representation draft #385

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions elasticai/creator/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .attribute import AttributeT, SizeT
from .edge import Edge
from .graph import Graph
from .node import Node
6 changes: 6 additions & 0 deletions elasticai/creator/ir/attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import TypeAlias

# Fixed-rank size tuple: 1, 2, or 3 integer dimensions.
# NOTE(review): presumably describes tensor/feature-map shapes — not used
# elsewhere in this package view; confirm against callers.
SizeT: TypeAlias = tuple[int] | tuple[int, int] | tuple[int, int, int]
# Recursive, JSON-like value type for node/edge attributes: scalars,
# tuples of attributes, or string-keyed mappings of further attributes.
AttributeT: TypeAlias = (
    int | float | str | tuple["AttributeT", ...] | dict[str, "AttributeT"]
)
24 changes: 24 additions & 0 deletions elasticai/creator/ir/edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import dataclasses
from dataclasses import dataclass

from .attribute import AttributeT


@dataclass(eq=True, frozen=True)
class Edge:
    """A directed edge from node ``src`` to node ``sink`` with free-form attributes.

    NOTE(review): ``frozen=True`` + ``eq=True`` generates a ``__hash__`` over
    all fields, but ``attributes`` is a dict, so hashing an ``Edge`` raises
    ``TypeError`` — confirm edges are never used as set members or dict keys.
    """

    src: str
    sink: str
    attributes: dict[str, AttributeT] = dataclasses.field(default_factory=dict)

    @classmethod
    def _filter_attributes(cls, d: dict[str, AttributeT]) -> dict[str, AttributeT]:
        """Drop the structural keys, keeping only genuine attributes.

        Bug fix: the original iterated ``d`` directly (keys only) and tried to
        unpack each key string into ``(k, v)``, raising ``ValueError`` for any
        key whose length is not exactly 2.
        """
        return {k: v for k, v in d.items() if k not in ("src", "sink")}

    def as_dict(self) -> dict[str, AttributeT]:
        """Flatten the edge into one dict: ``src``, ``sink`` and all attributes."""
        return dict(src=self.src, sink=self.sink) | self.attributes

    @classmethod
    def from_dict(cls, data: dict[str, AttributeT]) -> "Edge":
        """Inverse of ``as_dict``: every key except src/sink becomes an attribute."""
        return cls(
            src=data["src"], sink=data["sink"], attributes=cls._filter_attributes(data)
        )
81 changes: 81 additions & 0 deletions elasticai/creator/ir/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from collections.abc import Iterator

from .attribute import AttributeT
from .edge import Edge
from .humble_base_graph import HumbleBaseGraph
from .node import Node


class Graph:
    """Directed graph of named ``Node`` objects connected by ``Edge`` objects.

    The pure topology (which names connect to which) is delegated to a
    ``HumbleBaseGraph`` over node names; the rich ``Node``/``Edge`` payloads
    are kept in side tables keyed by name and by ``(src, sink)`` pair.
    Every graph starts out with an "input" and an "output" node.
    """

    def __init__(self):
        self._hbg: HumbleBaseGraph[str] = HumbleBaseGraph()
        self._nodes: dict[str, Node] = {}
        # Fix: the key is a pair of names; ``dict[(str, str), Edge]`` was not
        # valid typing syntax.
        self._edges: dict[tuple[str, str], Edge] = dict()
        self.add_node(Node("input", "input"))
        self.add_node(Node("output", "output"))

    def get_node(self, name: str) -> Node:
        """Return the node called ``name`` (raises ``KeyError`` if unknown)."""
        return self._nodes[name]

    def has_node(self, name: str) -> bool:
        return name in self._nodes

    def add_node(self, n: Node) -> "Graph":
        """Insert (or overwrite) a node; returns ``self`` for chaining."""
        self._nodes[n.name] = n
        self._hbg.add_node(n.name)
        return self

    def add_edge(self, e: Edge) -> "Graph":
        """Insert an edge; returns ``self`` for chaining.

        NOTE(review): this registers ``e.src``/``e.sink`` in the topology but
        not in ``self._nodes``; iterating nodes before the matching
        ``add_node`` call raises ``KeyError`` — confirm this is intended.
        """
        self._hbg.add_edge(e.src, e.sink)
        self._edges[(e.src, e.sink)] = e
        return self

    def iter_edges(self) -> Iterator[Edge]:
        yield from self._edges.values()

    def iter_src_sink_pairs(self) -> Iterator[tuple[Node, Node]]:
        """Yield the (source node, sink node) pair of every edge."""
        for src, sink in self._edges:
            yield self.get_node(src), self.get_node(sink)

    def iter_src_sink_name_pairs(self) -> Iterator[tuple[str, str]]:
        yield from self._edges

    def iter_nodes(self) -> Iterator[Node]:
        for node_name in self._hbg.iter_nodes():
            yield self._nodes[node_name]

    def iter_node_names(self) -> Iterator[str]:
        yield from self._hbg.iter_nodes()

    def get_nodes_by_type(self, type: str) -> Iterator[Node]:
        # ``type`` shadows the builtin, but it is part of the public
        # signature (keyword callers), so it is kept.
        for n in self.iter_nodes():
            if n.type == type:
                yield n

    def get_successors(self, name: str) -> Iterator[Node]:
        # Fix: loop variable renamed; the original shadowed the parameter.
        for successor_name in self._hbg.get_successors(name):
            yield self._nodes[successor_name]

    def get_successor_names(self, name: str) -> Iterator[str]:
        yield from self._hbg.get_successors(name)

    def get_predecessors(self, name: str) -> Iterator[Node]:
        # Fix: loop variable renamed; the original shadowed the parameter.
        for predecessor_name in self._hbg.get_predecessors(name):
            yield self._nodes[predecessor_name]

    def as_dict(self) -> dict[str, AttributeT]:
        """Serialize to ``{"nodes": [...], "edges": [...]}`` of plain dicts."""

        def make_dict(x):
            return x.as_dict()

        return dict(
            nodes=list(map(make_dict, self.iter_nodes())),
            edges=list(map(make_dict, self.iter_edges())),
        )

    @classmethod
    def from_dict(cls, data: dict[str, AttributeT]) -> "Graph":
        """Inverse of ``as_dict``.

        Bug fix: the original built the graph but never returned it, so
        ``Graph.from_dict(...)`` always evaluated to ``None``.
        """
        g = cls()
        for n in data["nodes"]:
            g.add_node(Node.from_dict(n))
        for e in data["edges"]:
            g.add_edge(Edge.from_dict(e))
        return g
35 changes: 35 additions & 0 deletions elasticai/creator/ir/graph_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from collections.abc import Callable, Iterable, Iterator
from typing import Hashable, Protocol, TypeAlias, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)

NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]]


def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()

def visit(nodes: tuple[HashableT, ...]):
for n in nodes:
if n not in visited:
yield n
visited.add(n)
yield from visit(tuple(successors(n)))

yield from visit((start,))


def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()
visit_next = [start]
while len(visit_next) > 0:
current = visit_next.pop(0)
for p in successors(current):
if p not in visited:
yield p
visit_next.append(p)
visited.add(current)


def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
    """Yield all nodes strictly above `start` in breadth-first order.

    Same traversal as `bfs_iter_down`, just following predecessor edges;
    `start` itself is not yielded.
    """
    return bfs_iter_down(predecessors, start)
53 changes: 53 additions & 0 deletions elasticai/creator/ir/graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from .edge import Edge
from .graph import Graph
from .node import Node


def test_has_input_node():
    """A freshly constructed graph already contains the "input" node."""
    graph = Graph()
    expected = Node(name="input", type="input")
    assert expected == graph.get_node("input")


def test_can_test_for_node_existence():
    """`has_node` reports membership per graph instance, not globally."""
    with_node = Graph()
    with_node.add_node(Node(name="a", type="b"))
    without_node = Graph()
    assert with_node.has_node("a") and not without_node.has_node("a")


def test_can_get_successors():
    """Adding an edge makes the sink a successor of the source."""
    graph = Graph()
    graph.add_node(Node(name="a", type="b"))
    graph.add_edge(Edge(src="input", sink="a"))
    first_successor = next(iter(graph.get_successors("input")))
    assert first_successor == Node(name="a", type="b")


def test_can_get_predecessors():
    """Adding an edge makes the source a predecessor of the sink."""
    graph = Graph()
    graph.add_node(Node(name="a", type="b"))
    graph.add_edge(Edge(src="input", sink="a"))
    first_predecessor = next(iter(graph.get_predecessors("a")))
    assert first_predecessor == Node(name="input", type="input")


def test_can_serialize_graph():
    """`as_dict` flattens nodes and edges, including later attribute mutations."""
    g = Graph()
    edge = Edge(src="input", sink="a", attributes=dict(indices=(2, 3)))
    input_node = Node(name="input", type="input", attributes=dict(input_shape=(2, 4, 8)))
    # Same call order as production code would use: the edge may be added
    # before the node objects it refers to.
    g.add_edge(edge)
    g.add_node(input_node)
    g.add_node(Node("a", "b"))
    g.get_node("input").attributes["output_shape"] = (1, 1, 1)

    expected = {
        "nodes": [
            {
                "name": "input",
                "type": "input",
                "input_shape": (2, 4, 8),
                "output_shape": (1, 1, 1),
            },
            {"name": "output", "type": "output"},
            {"name": "a", "type": "b"},
        ],
        "edges": [{"src": "input", "sink": "a", "indices": (2, 3)}],
    }
    assert g.as_dict() == expected
52 changes: 52 additions & 0 deletions elasticai/creator/ir/humble_base_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections.abc import Hashable, Iterable, Iterator
from typing import Generic, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)


class HumbleBaseGraph(Generic[HashableT]):
    """Minimal directed graph storing only adjacency between hashable nodes.

    Successor and predecessor mappings are kept redundantly to allow cheap
    lookups in both directions.  Plain dicts with ``None`` values serve as
    insertion-ordered sets.  Not optimized for performance.
    """

    def __init__(self) -> None:
        # dict-of-dicts with None values == insertion-ordered adjacency sets
        self.successors: dict[HashableT, dict[HashableT, None]] = dict()
        self.predecessors: dict[HashableT, dict[HashableT, None]] = dict()

    @classmethod
    def from_dict(
        cls, d: dict[HashableT, Iterable[HashableT]]
    ) -> "HumbleBaseGraph[HashableT]":
        """Build a graph from an adjacency mapping ``{node: successors}``.

        Fixes vs. the original: (a) constructed via ``cls`` so subclasses get
        instances of their own type, (b) nodes mapped to an empty successor
        iterable are kept as isolated nodes instead of being dropped.
        """
        g = cls()
        for node, successors in d.items():
            g.add_node(node)
            for s in successors:
                g.add_edge(node, s)
        return g

    def as_dict(self) -> dict[HashableT, set[HashableT]]:
        """Return a fresh ``{node: set_of_successors}`` mapping.

        Bug fix: the original returned a shallow copy of the internal
        dict-of-dicts, which did not match the annotated return type and
        exposed internal adjacency state to mutation by callers.
        """
        return {node: set(succs) for node, succs in self.successors.items()}

    def add_edge(self, _from: HashableT, _to: HashableT):
        """Insert a directed edge, creating missing nodes; returns ``self``."""
        self.add_node(_from)
        self.add_node(_to)
        self.predecessors[_to][_from] = None
        self.successors[_from][_to] = None
        return self

    def add_node(self, node: HashableT):
        """Register a node with empty adjacency if unknown; returns ``self``."""
        if node not in self.predecessors:
            self.predecessors[node] = dict()
        if node not in self.successors:
            self.successors[node] = dict()
        return self

    def iter_nodes(self) -> Iterator[HashableT]:
        yield from self.predecessors.keys()

    def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]:
        """Yield ``(src, sink)`` pairs in insertion order."""
        for _from, _tos in self.successors.items():
            for _to in _tos:
                yield _from, _to

    def get_successors(self, node: HashableT) -> Iterator[HashableT]:
        yield from self.successors[node]

    def get_predecessors(self, node: HashableT) -> Iterator[HashableT]:
        yield from self.predecessors[node]
69 changes: 69 additions & 0 deletions elasticai/creator/ir/humble_base_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .graph_iterators import bfs_iter_up, dfs_pre_order
from .humble_base_graph import HumbleBaseGraph


def test_iterating_breadth_first_upwards():
    """BFS from "5" upwards visits nearest predecessors first.

    Fixes vs. the original: removed a dead ``g = HumbleBaseGraph()``
    assignment (immediately overwritten) and turned the graph diagram —
    previously a no-op string expression — into a comment.

    Test graph (edges point downwards):
          0
          |
      /-----\\
      |     |
      1     2
      |  /---+
      |/     |
      3      4
      |      |
      |      6
      |/----+
      5
    """
    g = HumbleBaseGraph.from_dict(
        {
            "0": ["1", "2"],
            "1": ["3"],
            "2": ["3", "4"],
            "3": ["5"],
            "4": ["6"],
            "6": ["5"],
        }
    )

    actual = tuple(bfs_iter_up(g.get_predecessors, "5"))
    # Nodes at the same depth may be yielded in either order.
    assert set(actual[0:2]) == {"3", "6"}
    assert (set(actual[2:4]) == {"1", "2"} and actual[4] == "4") or (
        set(actual[3:5]) == {"1", "2"} and actual[2] == "4"
    )


def test_iterating_depth_first_preorder():
    """DFS pre-order from "0" follows each branch fully before backtracking.

    Fixes vs. the original: removed a dead ``g = HumbleBaseGraph()``
    assignment, a leftover debug ``print``, and turned the graph diagram —
    previously a no-op string expression — into a comment.

    Test graph (edges point downwards):
          0
          |
      /-----\\
      |     |
      1     2
      |  /---+
      |/     |
      3      4
      |      |
      |      6
      |/----+
      5
    """
    g = HumbleBaseGraph.from_dict(
        {
            "0": ["1", "2"],
            "1": ["3"],
            "2": ["3", "4"],
            "3": ["5"],
            "4": ["6"],
            "6": ["5"],
        }
    )

    actual = tuple(dfs_pre_order(g.get_successors, "0"))
    expected = ("0", "1", "3", "5", "2", "4", "6")
    assert actual == expected
24 changes: 24 additions & 0 deletions elasticai/creator/ir/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import dataclasses
from dataclasses import dataclass

from .attribute import AttributeT


@dataclass
class Node:
    """A named, typed graph node carrying free-form attributes."""

    name: str
    type: str
    attributes: dict[str, AttributeT] = dataclasses.field(default_factory=dict)

    def as_dict(self) -> dict[str, AttributeT]:
        """Flatten the node into one dict: ``name``, ``type`` and all attributes."""
        return dict(name=self.name, type=self.type) | self.attributes

    @classmethod
    def from_dict(cls, data: dict[str, AttributeT]) -> "Node":
        """Inverse of ``as_dict``: every key except name/type becomes an attribute.

        Fix: construct via ``cls`` (the original hard-coded ``Node``) so that
        subclasses deserialize to their own type, consistent with
        ``Edge.from_dict``.
        """
        return cls(
            name=data["name"],
            type=data["type"],
            attributes={k: v for k, v in data.items() if k not in ("name", "type")},
        )
19 changes: 19 additions & 0 deletions elasticai/creator/ir/node_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from elasticai.creator.ir.node import Node as OtherNode

from .node import Node


def test_can_serialize_node_with_attributes():
    """`as_dict` merges name/type with the attribute entries."""
    node = Node(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
    expected = dict(name="my_node", type="my_type", a="b", c=(1, 2))
    assert expected == node.as_dict()


def test_can_deserialize_node_with_attributes():
    """`from_dict` routes every non-structural key into `attributes`."""
    data = dict(name="my_node", type="my_type", a="b", c=(1, 2))
    expected = Node(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
    assert expected == Node.from_dict(data)


def test_import_path_does_not_matter_for_equality():
    """The same class reached via relative and absolute imports compares equal."""
    via_relative_import = Node(name="a", type="a_type", attributes=dict(a="b"))
    via_absolute_import = OtherNode(name="a", type="a_type", attributes=dict(a="b"))
    assert via_relative_import == via_absolute_import
Loading