Skip to content

Commit

Permalink
Merge branch 'tmp_add_pass_args' into 'main'
Browse files Browse the repository at this point in the history
Add pipe features. Improve FunctionalNode object to reduce memory usage

Closes #1

See merge request ricos/machine_learning/dagstream!19
  • Loading branch information
riku-sakamoto committed Apr 30, 2024
2 parents 56b807b + 817e733 commit 6795b60
Show file tree
Hide file tree
Showing 38 changed files with 1,212 additions and 258 deletions.
4 changes: 3 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ ignore =
# https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes
# E203, # whitespace before ':'.
# E231, # whitespace after ','.

per-file-ignores =
# imported but unused
__init__.py: F401
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ addopts = [

[tool.isort]
profile = "black"


[tool.coverage.report]
exclude_also = [
"def __repr__",
"@(abc\\.)?abstractmethod",
"raise NotImplementedError"
]
104 changes: 67 additions & 37 deletions src/dagstream/dagstream.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
from typing import Callable, Iterable, Optional
from typing import Callable, Iterable, Optional, Union

from dagstream.graph_components import FunctionalDag, IDrawableGraph
from dagstream.graph_components.nodes import (
FunctionalNode,
from dagstream import utils
from dagstream.graph_components import (
FunctionalDag,
IDrawableGraph,
IDrawableNode,
IFunctionalNode,
)
from dagstream.graph_components.nodes import FunctionalNode
from dagstream.utils.errors import DagStreamCycleError


class DagStream(IDrawableGraph):
# This counter aims to distinguish
# between nodes which have the same function name
_SAME_NAME_COUNTER: dict[str, int] = {}

def __init__(self) -> None:
self._functions: set[IFunctionalNode] = set()
self._name2node: dict[str, IFunctionalNode] = {}

def check_exists(self, node: Union[str, IFunctionalNode]) -> bool:
if isinstance(node, IFunctionalNode):
return node.mut_name in self._name2node

if isinstance(node, str):
return node in self._name2node

def check_exists(self, node: IFunctionalNode) -> bool:
return node in self._functions
raise NotImplementedError()

def get_drawable_nodes(self) -> Iterable[IDrawableNode]:
return self._functions
return self._name2node.values()

def get_functions(self) -> set[IFunctionalNode]:
return self._functions
def get_functions(self) -> Iterable[IFunctionalNode]:
return self._name2node.values()

def emplace(self, *functions: Callable) -> tuple[IFunctionalNode, ...]:
"""create a functional node corresponding to each function
Expand All @@ -32,12 +44,27 @@ def emplace(self, *functions: Callable) -> tuple[IFunctionalNode, ...]:
"""

# To ensure orders
_functions: list[IFunctionalNode] = []
_nodes: list[IFunctionalNode] = []
for func in functions:
node = FunctionalNode(func)
_functions.append(node)
self._functions.add(node)
return tuple(_functions)
node_name = self._create_node_name(func)
node = FunctionalNode(func, mut_node_name=node_name)
_nodes.append(node)
self._name2node.update({node.mut_name: node})

return tuple(_nodes)

def _create_node_name(self, user_function: Callable) -> str:
function_name = utils.get_function_name(user_function)
if function_name not in self._name2node:
return function_name

_counter = self._SAME_NAME_COUNTER.get(function_name, 0)
_counter += 1

node_name = f"{function_name}_{_counter}"

self._SAME_NAME_COUNTER[function_name] = _counter
return node_name

def construct(
self, mandatory_nodes: Optional[set[IFunctionalNode]] = None
Expand All @@ -60,64 +87,67 @@ def construct(
self._detect_cycle()

if mandatory_nodes is None:
functions = self._functions
functions = self._name2node
else:
functions = self._extract_functions(mandatory_nodes)

return FunctionalDag(functions)

def _extract_functions(
self, mandatory_nodes: set[IFunctionalNode]
) -> set[IFunctionalNode]:
visited: set[IFunctionalNode] = set()
) -> dict[str, IFunctionalNode]:
visited: dict[str, IFunctionalNode] = {}

for node in mandatory_nodes:
self._extract_subdag(node, visited)
return visited

def _extract_subdag(
self, mandatory_node: IFunctionalNode, visited: set[IFunctionalNode]
self, mandatory_node: IFunctionalNode, visited: dict[str, IFunctionalNode]
):
if mandatory_node in visited:
if mandatory_node.mut_name in visited:
return

visited.add(mandatory_node)
predecessors: list[IFunctionalNode] = [v for v in mandatory_node.predecessors]
visited.update({mandatory_node.mut_name: mandatory_node})
predecessors: list[str] = [v for v in mandatory_node.predecessors]

while len(predecessors) != 0:
node = predecessors.pop()
if node in visited:
node_name = predecessors.pop()
node = self._name2node[node_name]
if node.mut_name in visited:
continue
visited.add(node)
visited.update({node.mut_name: node})

for next_node in node.predecessors:
predecessors.append(next_node)

return None

def _detect_cycle(self):
finished = set()
seen = set()
for func in self._functions:
if func in finished:
finished: set[str] = set()
seen: set[str] = set()
for node in self._name2node.values():
if node.mut_name in finished:
continue
self._dfs_detect_cycle(func, finished, seen)
self._dfs_detect_cycle(node, finished, seen)
return None

def _dfs_detect_cycle(
self,
start: IFunctionalNode,
finished: set[IFunctionalNode],
seen: set[IFunctionalNode],
finished: set[str],
seen: set[str],
) -> None:
for node in start.successors:
if node in finished:
for edge in start.successors:
node = self._name2node[edge.to_node]

if node.mut_name in finished:
continue

if (node in seen) and (node not in finished):
if (node.mut_name in seen) and (node.mut_name not in finished):
raise DagStreamCycleError("Detect cycle in your definition of dag.")

seen.add(node)
seen.add(node.mut_name)
self._dfs_detect_cycle(node, finished, seen)

finished.add(start)
finished.add(start.mut_name)
78 changes: 49 additions & 29 deletions src/dagstream/executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import multiprocessing as multi
from typing import Any
from typing import Any, Union

from dagstream.dagstream import IFunctionalNode
from dagstream.graph_components import FunctionalDag


Expand All @@ -26,7 +25,13 @@ def __init__(self, functional_dag: FunctionalDag) -> None:
)
self._dag = functional_dag

def run(self, *args, **kwargs) -> dict[str, Any]:
def run(
self,
*args: Any,
first_args: Union[tuple[Any], None] = None,
save_state: bool = False,
**kwargs,
) -> dict[str, Any]:
"""Run functions sequencially according to static order.
Input parameters are passed to all functions.
Expand All @@ -37,20 +42,28 @@ def run(self, *args, **kwargs) -> dict[str, Any]:
Key is name of function, value is returned objects from each function.
"""
results: dict[str, Any] = {}

while self._dag.is_active:
nodes = self._dag.get_ready()
for node in nodes:
if node.n_predecessors == 0 and first_args is not None:
for arg in first_args:
node.receive_args(arg)

result = node.run(*args, **kwargs)
results.update({node.name: result})
self._dag.done(node)
self._dag.send(node.mut_name, result)
self._dag.done(node.mut_name)

if self._dag.check_last(node) or save_state:
results.update({node.display_name: result})

return results


class StreamParallelExecutor:
"""Parallel Executor for FunctionalDag Object."""

def __init__(self, functional_dag: FunctionalDag, n_processes: int = 1) -> None:
def __init__(self, functional_dag: FunctionalDag, n_process: int = 1) -> None:
"""THIS IS EXPERIMENTAL FEATURE. Parallel Executor for FunctionalDag Object.
Parameters
Expand All @@ -72,16 +85,17 @@ def __init__(self, functional_dag: FunctionalDag, n_processes: int = 1) -> None:
)

self._dag = functional_dag
self._n_processes = n_processes
self._n_processes = n_process
if self._n_processes <= 0:
raise ValueError(f"n_processes must be larger than 0. Input: {n_processes}")

def _worker(self, input_queue: multi.Queue, done_queue: multi.Queue):
for func, args, kwargs in iter(input_queue.get, "STOP"):
result = func.run(*args, **kwargs)
done_queue.put((func, result))

def run(self, *args, **kwargs) -> dict[str, Any]:
raise ValueError(f"n_processes must be larger than 0. Input: {n_process}")

def run(
self,
*args: Any,
first_args: Union[tuple[Any], None] = None,
save_state: bool = False,
**kwargs,
) -> dict[str, Any]:
"""Run functions in parallel.
Parameters are passed to all functions.
Expand All @@ -97,38 +111,44 @@ def run(self, *args, **kwargs) -> dict[str, Any]:
all_processes: list[multi.Process] = []

results: dict[str, Any] = {}
_name2nodes: dict[str, IFunctionalNode] = {
node.name: node for node in self._dag._nodes
}

while self._dag.is_active:
nodes = self._dag.get_ready()

for node_func in nodes:
task_queue.put((node_func, args, kwargs))
for node in nodes:
if node.n_predecessors == 0 and first_args is not None:
for arg in first_args:
node.receive_args(arg)

task_queue.put((node, args, kwargs))

# Start worker processes
n_left_process = self._n_processes - len(all_processes)
for _ in range(n_left_process):
process = multi.Process(
target=self._worker, args=(task_queue, done_queue)
)
process = multi.Process(target=_worker, args=(task_queue, done_queue))
process.start()
all_processes.append(process)

while not done_queue.empty():
_done_node, _result = done_queue.get()

# HACK: When using multiprocessing, id(IFunctionalNode)
# NOTE: When using multiprocessing, id(IFunctionalNode)
# after running is not the same as one before running.
# This operation is incorporated in the Dagstream object,
# after names of all nodes are guranteed to be unique.
done_node = _name2nodes[_done_node.name]
self._dag.done(done_node)
results.update({done_node.name: _result})

self._dag.send(_done_node.mut_name, _result)
self._dag.done(_done_node.mut_name)

if self._dag.check_last(_done_node) or save_state:
results.update({_done_node.mut_name: _result})

if not self._dag.is_active:
for _ in range(self._n_processes):
task_queue.put("STOP")

return results


def _worker(input_queue: multi.Queue, done_queue: multi.Queue):
for func, args, kwargs in iter(input_queue.get, "STOP"):
result = func.run(*args, **kwargs)
done_queue.put((func, result))
6 changes: 6 additions & 0 deletions src/dagstream/graph_components/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from dagstream.graph_components._interface import (
IDagEdge,
IDrawableNode,
IFunctionalNode,
INodeState,
)
from dagstream.graph_components.dags import FunctionalDag, IDrawableGraph # NOQA
Loading

0 comments on commit 6795b60

Please sign in to comment.