Skip to content

Commit

Permalink
Merge pull request #134 from pyiron/macro_interface
Browse files Browse the repository at this point in the history
Macro interface
  • Loading branch information
liamhuber authored Dec 14, 2023
2 parents 503f02f + d9c12c4 commit 43cf8f4
Show file tree
Hide file tree
Showing 6 changed files with 1,565 additions and 1,235 deletions.
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,15 @@ But the intent is to collect them together into a workflow and leverage existing
... import numpy as np
... return np.arange(n)
>>>
>>> @Workflow.wrap_as.macro_node()
... def PlotShiftedSquare(macro):
... macro.shift = macro.create.standard.UserInput(0)
>>> @Workflow.wrap_as.macro_node("fig")
... def PlotShiftedSquare(macro, shift: int = 0):
... macro.arange = Arange()
... macro.plot = macro.create.plotting.Scatter(
... x=macro.arange + macro.shift,
... x=macro.arange + shift,
... y=macro.arange**2
... )
... macro.inputs_map = {
... "shift__user_input": "shift",
... "arange__n": "n",
... }
... macro.outputs_map = {"plot__fig": "fig"}
... macro.inputs_map = {"arange__n": "n"} # Expose arange input
... return macro.plot
>>>
>>> wf = Workflow("plot_with_and_without_shift")
>>> wf.n = wf.create.standard.UserInput()
Expand All @@ -91,7 +87,7 @@ Which gives the workflow `diagram`

![](docs/_static/readme_diagram.png)

And the resulting `fig`
And the resulting figure (when axes are not cleared)

![](docs/_static/readme_fig.png)

Expand Down
1,635 changes: 442 additions & 1,193 deletions notebooks/deepdive.ipynb

Large diffs are not rendered by default.

788 changes: 778 additions & 10 deletions notebooks/quickstart.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyiron_workflow/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def __init__(

if depth > 0:
from pyiron_workflow.composite import Composite

# Janky in-line import to avoid circular imports but only look for children
# where they exist (since SingleValue nodes now actually do something on
# failed attribute access, i.e. use it as delayed access on their output)
Expand Down
258 changes: 237 additions & 21 deletions pyiron_workflow/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from __future__ import annotations

from functools import partialmethod
from typing import Optional, TYPE_CHECKING
import inspect
from typing import get_type_hints, Literal, Optional

from pyiron_workflow.channels import InputData, OutputData
from bidict import bidict

from pyiron_workflow.channels import InputData, OutputData, NotData
from pyiron_workflow.composite import Composite
from pyiron_workflow.has_channel import HasChannel
from pyiron_workflow.io import Outputs, Inputs

if TYPE_CHECKING:
from bidict import bidict
from pyiron_workflow.output_parser import ParseOutput


class Macro(Composite):
Expand All @@ -22,21 +24,43 @@ class Macro(Composite):
pre-populated workflow that is the same every time you instantiate it.
At instantiation, the macro uses a provided callable to build and wire the graph,
then builds a static IO interface for this graph. (See the parent class docstring
for more details, but by default and as with workflows, unconnected IO is
represented by combining node and channel names, but can be controlled in more
detail with maps.)
This IO is _value linked_ to the child IO, so that their values stay synchronized,
then builds a static IO interface for this graph.
This callable must use the macro object itself as the first argument (e.g. adding
nodes to it).
As with `Workflow` objects, macros leverage `inputs_map` and `outputs_map` to
control macro-level IO access to child IO.
As with `Workflow`, default behaviour is to expose all unconnected child IO.
The provided callable may optionally specify further args and kwargs, which are used
to pre-populate the macro with `UserInput` nodes;
This can be especially helpful when more than one child node needs access to the
same input value.
Similarly, the callable may return any number of child nodes' output channels (or
the node itself in the case of `SingleValue` nodes) and commensurate
`output_labels` to define macro-level output.
These function-like definitions of the graph creator callable can be used
independently or together.
Each that is used switches its IO map to a "whitelist" paradigm, so any I/O _not_
provided in the callable signature/return values and output labels will be disabled.
Manual modifications of the IO maps inside the callable always take priority over
this whitelisting behaviour, so you always retain full control over what IO is
exposed, and the whitelisting is only for your convenience.
Macro IO is _value linked_ to the child IO, so that their values stay synchronized,
but the child nodes of a macro form an isolated sub-graph.
As with function nodes, sub-classes may define a method for creating the graph.
As with workflows, all DAG macros can determine their execution flow automatically,
if you have cycles in your data flow, or otherwise want more control over the
execution, all you need to do is specify the `node.signals.input.run` connections
and `starting_nodes` list yourself.
As with function nodes, subclasses of `Macro` may define a method for creating the
graph.
As with `Workflow`, all DAG macros can determine their execution flow
automatically, if you have cycles in your data flow, or otherwise want more control
over the execution, all you need to do is specify the `node.signals.input.run`
connections and `starting_nodes` list yourself.
If only _one_ of these is specified, you'll get an error, but if you've provided
both then no further checks of their validity/reasonableness are performed, so be
careful.
Unlike `Workflow`, this execution flow automation is set up once at instantiation;
If the macro is modified post-facto, you may need to manually re-invoke
`configure_graph_execution`.
Promises (in addition to parent class promises):
- IO is...
Expand Down Expand Up @@ -91,12 +115,13 @@ class Macro(Composite):
`partialmethod`:
>>> from functools import partialmethod
>>> class AddThreeMacro(Macro):
... def build_graph(self):
... @staticmethod
... def graph_creator(self):
... add_three_macro(self)
...
... __init__ = partialmethod(
... Macro.__init__,
... build_graph,
... None, # We directly define the graph creator method on the class
... )
>>>
>>> macro = AddThreeMacro()
Expand Down Expand Up @@ -169,6 +194,62 @@ class Macro(Composite):
>>> adds_six_macro.three = add_two
>>> adds_six_macro(one__x=1)
{'three__result': 7}
Instead of controlling the IO interface with dictionary maps, we can instead
provide a more `Function(Node)`-like definition of the `graph_creator` by
adding args and/or kwargs to the signature (under the hood, this dynamically
creates new `UserInput` nodes before running the rest of the graph creation),
and/or returning child channels (or whole children in the case of `SingleValue`
nodes) and providing commensurate `output_labels`.
This process switches us from the `Workflow` default of exposing all
unconnected child IO, to a "whitelist" paradigm of _only_ showing the IO that
we exposed by our function definition.
(Note: any `.inputs_map` or `.outputs_map` explicitly defined in the
`graph_creator` still takes precedence over this whitelisting! So you always
retain full control over what IO gets exposed.)
E.g., these two definitions are perfectly equivalent:
>>> @Macro.wrap_as.macro_node("lout", "n_plus_2")
... def LikeAFunction(macro, lin: list, n: int = 1):
... macro.plus_two = n + 2
... macro.sliced_list = lin[n:macro.plus_two]
... macro.double_fork = 2 * n
... # ^ This is vestigial, just to show we don't need to blacklist it in a
... # whitelist-paradigm
... return macro.sliced_list, macro.plus_two.channel
>>>
>>> like_functions = LikeAFunction(lin=[1,2,3,4,5,6], n=2)
>>> like_functions()
{'n_plus_2': 4, 'lout': [3, 4]}
>>> @Macro.wrap_as.macro_node()
... def WithIOMaps(macro):
... macro.list_in = macro.create.standard.UserInput()
... macro.list_in.inputs.user_input.type_hint = list
... macro.forked = macro.create.standard.UserInput(2)
... macro.forked.inputs.user_input.type_hint = int
... macro.n_plus_2 = macro.forked + 2
... macro.sliced_list = macro.list_in[macro.forked:macro.n_plus_2]
... macro.double_fork = 2 * macro.forked
... macro.inputs_map = {
... "list_in__user_input": "lin",
... macro.forked.inputs.user_input.scoped_label: "n",
... "n_plus_2__other": None,
... "list_in__user_input_Slice_forked__user_input_n_plus_2__add_None__step": None,
... macro.double_fork.inputs.other.scoped_label: None,
... }
... macro.outputs_map = {
... macro.sliced_list.outputs.getitem.scoped_label: "lout",
... macro.n_plus_2.outputs.add.scoped_label: "n_plus_2",
... "double_fork__rmul": None
... }
>>>
>>> with_maps = WithIOMaps(lin=[1,2,3,4,5,6], n=2)
>>> with_maps()
{'n_plus_2': 4, 'lout': [3, 4]}
Here we've leveraged the macro-creating decorator, but this works the same way
using the `Macro` class directly.
"""

def __init__(
Expand All @@ -180,6 +261,7 @@ def __init__(
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
if not callable(graph_creator):
Expand Down Expand Up @@ -207,14 +289,138 @@ def __init__(
inputs_map=inputs_map,
outputs_map=outputs_map,
)
self.graph_creator(self)
output_labels = self._validate_output_labels(output_labels)

ui_nodes = self._prepopulate_ui_nodes_from_graph_creator_signature()
returned_has_channel_objects = self.graph_creator(self, *ui_nodes)
self._configure_graph_execution()

# Update IO map(s) if a function-like graph creator interface was used
if len(ui_nodes) > 0:
self._whitelist_inputs_map(*ui_nodes)
if returned_has_channel_objects is not None:
self._whitelist_outputs_map(
output_labels,
*(
(returned_has_channel_objects,)
if not isinstance(returned_has_channel_objects, tuple)
else returned_has_channel_objects
),
)

self._inputs: Inputs = self._build_inputs()
self._outputs: Outputs = self._build_outputs()

self.set_input_values(**kwargs)

def _validate_output_labels(self, output_labels) -> tuple[str]:
    """
    Ensure that output_labels, if provided, are commensurate with graph creator
    return values, if provided, and return them as a tuple.

    Args:
        output_labels (str | list[str] | tuple[str] | None): The labels to use
            for macro-level output channels; a bare string is treated as a
            single label.

    Returns:
        tuple[str]: The labels as a tuple (empty when `None` was given).

    Raises:
        TypeError: If exactly one of graph-creator return values and output
            labels is present.
        ValueError: If both are present but their counts differ.
    """
    graph_creator_returns = ParseOutput(self.graph_creator).output
    output_labels = (
        (output_labels,) if isinstance(output_labels, str) else output_labels
    )
    if graph_creator_returns is not None or output_labels is not None:
        error_suffix = (
            f"but {self.label} macro got return values: "
            f"{graph_creator_returns} and labels: {output_labels}."
        )
        # Explicit None checks instead of catching TypeError from len(None):
        # the EAFP form caught unrelated TypeErrors (e.g. un-sized labels) and
        # re-raised them with a misleading message and no exception chain.
        if graph_creator_returns is None or output_labels is None:
            raise TypeError(
                f"Output labels and graph creator return values must either both "
                f"or neither be present, " + error_suffix
            )
        if len(output_labels) != len(graph_creator_returns):
            raise ValueError(
                "The number of return values in the graph creator must exactly "
                "match the number of output labels provided, " + error_suffix
            )
    return () if output_labels is None else tuple(output_labels)

def _prepopulate_ui_nodes_from_graph_creator_signature(self):
    """
    Build one `UserInput` child node per argument (beyond the first) in the
    graph creator's signature, carrying over any default value and type hint,
    and return those nodes as a tuple in signature order.
    """
    hints = get_type_hints(self.graph_creator)
    parameters = inspect.signature(self.graph_creator).parameters

    ui_nodes = []
    for position, (arg_name, parameter) in enumerate(parameters.items()):
        if position == 0:
            # The first argument is the macro object itself, like `self`
            continue

        if parameter.default is inspect.Parameter.empty:
            default = NotData
        else:
            default = parameter.default

        ui_node = self.create.standard.UserInput(
            default, label=arg_name, parent=self
        )
        ui_node.inputs.user_input.default = default
        if arg_name in hints:  # Absent hints are simply not set
            ui_node.inputs.user_input.type_hint = hints[arg_name]
        ui_nodes.append(ui_node)

    return tuple(ui_nodes)

def _whitelist_inputs_map(self, *ui_nodes) -> None:
    """
    Expose each UI node's channel under the node's own label in the inputs
    map, then disable every other input channel that the map does not already
    mention.
    """
    new_labels = tuple(ui_node.label for ui_node in ui_nodes)
    whitelisted = self._whitelist_map(self.inputs_map, new_labels, ui_nodes)
    self.inputs_map = self._hide_non_whitelisted_io(whitelisted, "inputs")

def _whitelist_outputs_map(
    self, output_labels: tuple[str], *creator_returns: HasChannel
):
    """
    Expose each object returned by the graph creator under the corresponding
    supplied output label, then disable every other output channel that the
    map does not already mention.
    """
    whitelisted = self._whitelist_map(
        self.outputs_map, output_labels, creator_returns
    )
    self.outputs_map = self._hide_non_whitelisted_io(whitelisted, "outputs")

@staticmethod
def _whitelist_map(
    io_map: bidict, new_labels: tuple[str], has_channel_objects: tuple[HasChannel]
) -> bidict:
    """
    Return the IO map (a fresh empty `bidict` when `None`) extended so that
    each `HasChannel` object's channel gets the corresponding new label,
    leaving any pre-existing entries for those channels untouched.
    """
    if io_map is None:
        io_map = bidict({})
    for new_label, has_channel in zip(new_labels, has_channel_objects):
        scoped = has_channel.channel.scoped_label
        # Only whitelist channels the user hasn't already mapped explicitly
        if scoped not in io_map.keys():
            io_map[scoped] = new_label
    return io_map

def _hide_non_whitelisted_io(
    self, io_map: bidict, i_or_o: Literal["inputs", "outputs"]
) -> dict:
    """
    Return a plain-dict copy of the map with a `None` (i.e. disable) entry for
    every child channel not already present — blacklist whatever was not
    whitelisted.
    """
    # Convert to dict first: a bidict cannot map several keys to the same
    # value, and we need `None` for every non-whitelisted channel. Building
    # the whitelist as a bidict beforehand still gives us its uniqueness
    # guarantees on the real labels.
    full_map = dict(io_map)
    for child in self.nodes.values():
        for channel in getattr(child, i_or_o):
            scoped = channel.scoped_label
            if scoped not in full_map.keys():
                full_map[scoped] = None
    return full_map

def _get_linking_channel(
self,
child_reference_channel: InputData | OutputData,
Expand Down Expand Up @@ -288,7 +494,7 @@ def to_workfow(self):
raise NotImplementedError


def macro_node(**node_class_kwargs):
def macro_node(*output_labels, **node_class_kwargs):
"""
A decorator for dynamically creating macro classes from graph-creating functions.
Expand All @@ -297,15 +503,25 @@ def macro_node(**node_class_kwargs):
graph-creating function, and whose signature is modified to exclude this function
and provided kwargs.
Optionally takes output labels as args in case the node function uses the
like-a-function interface to define its IO. (The number of output labels must match
the number of channel-like objects returned by the graph creating function _exactly_.)
Optionally takes any keyword arguments of `Macro`.
"""
output_labels = None if len(output_labels) == 0 else output_labels

def as_node(graph_creator: callable[[Macro], None]):
def as_node(graph_creator: callable[[Macro, ...], Optional[tuple[HasChannel]]]):
return type(
graph_creator.__name__,
(Macro,), # Define parentage
{
"__init__": partialmethod(Macro.__init__, None, **node_class_kwargs),
"__init__": partialmethod(
Macro.__init__,
None,
output_labels=output_labels,
**node_class_kwargs,
),
"graph_creator": staticmethod(graph_creator),
},
)
Expand Down
Loading

0 comments on commit 43cf8f4

Please sign in to comment.