diff --git a/notebooks/workflow_example.ipynb b/notebooks/workflow_example.ipynb index f6fd2ab9..229760f8 100644 --- a/notebooks/workflow_example.ipynb +++ b/notebooks/workflow_example.ipynb @@ -119,8 +119,7 @@ "metadata": {}, "outputs": [], "source": [ - "# pm_node.run()\n", - "pm_node.update()" + "# pm_node.run()" ] }, { @@ -128,9 +127,7 @@ "id": "48b0db5a-548e-4195-8361-76763ddf0474", "metadata": {}, "source": [ - "Using the softer `update()` call checks to make sure the input is `ready` before moving on to `run()`, avoiding this error. In this case, `update()` sees we have no input an aborts by returning `None`.\n", - "\n", - "(Note: If you _do_ swap `update()` to `run()` in this cell, not only will you get the expected error, but `pm_node` will also set its `failed` attribute to `True` -- this will prevent it from being `ready` again until you manually reset `pm_node.failed = False`.)" + "Not only will you get the expected error, but `pm_node` will also set its `failed` attribute to `True` -- this will prevent it from being `ready` again until you manually reset `pm_node.failed = False`." ] }, { @@ -176,7 +173,7 @@ "id": "c54a691e-a075-4d41-bc0f-3a990857a27a", "metadata": {}, "source": [ - "Alternatively, the `run()` command (and `update()` when it proceeds to execution) just return the function's return value:" + "Alternatively, the `run()` command just returns the function's return value:" ] }, { @@ -241,7 +238,7 @@ "id": "58ed9b25-6dde-488d-9582-d49d405793c6", "metadata": {}, "source": [ - "This node also exploits type hinting! `run()` will always force the execution, but `update()` will not only check if the data is there, but also if it is the right type:" + "This node also exploits type hinting! `run()` will check that input values conform to type hints before computing anything. Failing at this stage won't actually cause the node to have a `failed` status, so you can just re-run it once the input is fixed."
] }, { @@ -249,46 +246,27 @@ "execution_count": 10, "id": "ac0fe993-6c82-48c8-a780-cbd0c97fc386", "metadata": {}, - "outputs": [], - "source": [ - "adder_node.inputs.x = \"not an integer\"\n", - "adder_node.inputs.x.type_hint, type(adder_node.inputs.x.value)\n", - "adder_node.update()\n", - "# No error because the update doesn't trigger a run since the type hint is not satisfied" - ] - }, - { - "cell_type": "markdown", - "id": "2737de39-6e75-44e1-b751-6315afe5c676", - "metadata": {}, - "source": [ - "Since the execution never happened, the output is unchanged" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "bcbd17f1-a3e4-44f0-bde1-cbddc51c5d73", - "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "1" + "(int, str)" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "adder_node.outputs.sum_.value" + "adder_node.inputs.x = \"not an integer\"\n", + "adder_node.inputs.x.type_hint, type(adder_node.inputs.x.value)\n", + "# adder_node.run()" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "15742a49-4c23-4d4a-84d9-9bf19677544c", "metadata": {}, "outputs": [ @@ -298,14 +276,14 @@ "3" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adder_node.inputs.x = 2\n", - "adder_node.update()" + "adder_node.run()" ] }, { @@ -318,7 +296,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "0c8f09a7-67c4-4c6c-a021-e3fea1a16576", "metadata": {}, "outputs": [ @@ -328,7 +306,7 @@ "30" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -348,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "69b59737-9e09-4b4b-a0e2-76a09de02c08", "metadata": {}, "outputs": [ @@ -358,7 +336,7 @@ "31" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, 
"output_type": "execute_result" } @@ -391,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "61b43a9b-8dad-48b7-9194-2045e465793b", "metadata": {}, "outputs": [], @@ -401,7 +379,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "647360a9-c971-4272-995c-aa01e5f5bb83", "metadata": {}, "outputs": [ @@ -438,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "b8c845b7-7088-43d7-b106-7a6ba1c571ec", "metadata": {}, "outputs": [ @@ -482,7 +460,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "id": "2e418abf-7059-4e1e-9b9f-b3dc0a4b5e35", "metadata": { "tags": [] @@ -532,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "id": "59c29856-c77e-48a1-9f17-15d4c58be588", "metadata": {}, "outputs": [ @@ -568,7 +546,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "id": "1a4e9693-0980-4435-aecc-3331d8b608dd", "metadata": {}, "outputs": [], @@ -580,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "id": "7c4d314b-33bb-4a67-bfb9-ed77fba3949c", "metadata": {}, "outputs": [ @@ -619,7 +597,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "id": "61ae572f-197b-4a60-8d3e-e19c1b9cc6e2", "metadata": {}, "outputs": [ @@ -659,24 +637,24 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "id": "6569014a-815b-46dd-8b47-4e1cd4584b3b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([0.6816222 , 0.60285251, 0.31984666, 0.38336884, 0.95586544,\n", - " 0.20915899, 0.73614411, 0.67259937, 0.84499503, 0.10539287])" + "array([0.91077351, 0.33860412, 0.59806048, 0.66528464, 0.80125293,\n", + " 0.31981677, 0.54395521, 0.4926537 , 0.52626431, 0.7848854 ])" ] }, - "execution_count": 23, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": 
{ - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -732,7 +710,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "id": "1cd000bd-9b24-4c39-9cac-70a3291d0660", "metadata": {}, "outputs": [], @@ -759,7 +737,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "id": "7964df3c-55af-4c25-afc5-9e07accb606a", "metadata": {}, "outputs": [ @@ -800,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "id": "809178a5-2e6b-471d-89ef-0797db47c5ad", "metadata": {}, "outputs": [ @@ -854,7 +832,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "id": "52c48d19-10a2-4c48-ae81-eceea4129a60", "metadata": {}, "outputs": [ @@ -864,7 +842,7 @@ "{'ay': 3, 'a + b + 2': 7}" ] }, - "execution_count": 27, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -884,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "id": "bb35ba3e-602d-4c9c-b046-32da9401dd1c", "metadata": {}, "outputs": [ @@ -894,7 +872,7 @@ "(7, 3)" ] }, - "execution_count": 28, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -913,7 +891,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "id": "2b0d2c85-9049-417b-8739-8a8432a1efbe", "metadata": {}, "outputs": [ @@ -932,127 +910,127 @@ "clustersimple\n", "\n", "simple: Workflow\n", + "\n", + "clustersimpleInputs\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Inputs\n", + "\n", "\n", "clustersimpleOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Outputs\n", "\n", "\n", "clustersimplea\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "a: AddOne\n", "\n", "\n", "clustersimpleaInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Inputs\n", "\n", "\n", "clustersimpleaOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Outputs\n", "\n", "\n", "clustersimpleb\n", "\n", - "\n", + "\n", 
"\n", "\n", "\n", "\n", - "\n", + "\n", "b: AddOne\n", "\n", "\n", "clustersimplebInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Inputs\n", "\n", "\n", "clustersimplebOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Outputs\n", "\n", "\n", "clustersimplesum\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "sum: AddNode\n", "\n", "\n", "clustersimplesumInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Inputs\n", "\n", "\n", "clustersimplesumOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Outputs\n", "\n", - "\n", - "clustersimpleInputs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Inputs\n", - "\n", "\n", "\n", "clustersimpleInputsrun\n", @@ -1231,10 +1209,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1255,14 +1233,14 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "id": "ae500d5e-e55b-432c-8b5f-d5892193cdf5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "11fa1336d10a42f4936ce22a299f191d", + "model_id": "a289b513c50d41989670c5b4ac9df823", "version_major": 2, "version_minor": 0 }, @@ -1289,10 +1267,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 30, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, @@ -1333,7 +1311,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "id": "2114d0c3-cdad-43c7-9ffa-50c36d56d18f", "metadata": {}, "outputs": [ @@ -1541,10 +1519,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 31, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1565,7 +1543,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 31, "id": "c71a8308-f8a1-4041-bea0-1c841e072a6d", "metadata": {}, "outputs": 
[], @@ -1575,7 +1553,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 32, "id": "2b9bb21a-73cd-444e-84a9-100e202aa422", "metadata": {}, "outputs": [ @@ -1593,7 +1571,7 @@ "13" ] }, - "execution_count": 33, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1632,7 +1610,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 33, "id": "3668f9a9-adca-48a4-84ea-13add965897c", "metadata": {}, "outputs": [ @@ -1642,7 +1620,7 @@ "{'intermediate': 102, 'plus_three': 103}" ] }, - "execution_count": 34, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1680,7 +1658,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 34, "id": "9aaeeec0-5f88-4c94-a6cc-45b56d2f0111", "metadata": {}, "outputs": [], @@ -1710,7 +1688,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 35, "id": "a832e552-b3cc-411a-a258-ef21574fc439", "metadata": {}, "outputs": [], @@ -1737,7 +1715,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "id": "b764a447-236f-4cb7-952a-7cba4855087d", "metadata": {}, "outputs": [ @@ -2961,10 +2939,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 37, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -2975,7 +2953,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "id": "b51bef25-86c5-4d57-80c1-ab733e703caf", "metadata": {}, "outputs": [ @@ -2996,7 +2974,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 38, "id": "091e2386-0081-436c-a736-23d019bd9b91", "metadata": {}, "outputs": [ @@ -3037,7 +3015,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 39, "id": "4cdffdca-48d3-4486-9045-48102c7e5f31", "metadata": {}, "outputs": [ @@ -3075,7 +3053,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 40, "id": 
"ed4a3a22-fc3a-44c9-9d4f-c65bc1288889", "metadata": {}, "outputs": [ @@ -3097,7 +3075,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 41, "id": "5a985cbf-c308-4369-9223-b8a37edb8ab1", "metadata": {}, "outputs": [ @@ -3187,7 +3165,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 42, "id": "0b373764-b389-4c24-8086-f3d33a4f7fd7", "metadata": {}, "outputs": [ @@ -3201,7 +3179,7 @@ " 17.230249999999995]" ] }, - "execution_count": 43, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -3238,7 +3216,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 43, "id": "0dd04b4c-e3e7-4072-ad34-58f2c1e4f596", "metadata": {}, "outputs": [ @@ -3297,7 +3275,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 44, "id": "2dfb967b-41ac-4463-b606-3e315e617f2a", "metadata": {}, "outputs": [ @@ -3321,7 +3299,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 45, "id": "2e87f858-b327-4f6b-9237-c8a557f29aeb", "metadata": {}, "outputs": [ @@ -3329,12 +3307,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.406 > 0.2\n", - "0.999 > 0.2\n", - "0.827 > 0.2\n", - "0.417 > 0.2\n", - "0.120 <= 0.2\n", - "Finally 0.120\n" + "0.064 <= 0.2\n", + "Finally 0.064\n" ] } ], diff --git a/pyiron_workflow/composite.py b/pyiron_workflow/composite.py index 1bb1abc2..270992e0 100644 --- a/pyiron_workflow/composite.py +++ b/pyiron_workflow/composite.py @@ -10,12 +10,12 @@ from typing import Literal, Optional, TYPE_CHECKING from bidict import bidict -from toposort import toposort_flatten, CircularDependencyError from pyiron_workflow.interfaces import Creator, Wrappers from pyiron_workflow.io import Outputs, Inputs from pyiron_workflow.node import Node from pyiron_workflow.node_package import NodePackage +from pyiron_workflow.topology import set_run_connections_according_to_linear_dag from pyiron_workflow.util import logger, DotDict, 
SeabornColors if TYPE_CHECKING: @@ -189,74 +189,11 @@ def disconnect_run(self) -> list[tuple[Channel, Channel]]: def set_run_signals_to_dag_execution(self): """ Disconnects all `signals.input.run` connections among children and attempts to - reconnect these according to the DAG flow of the data. - - Raises: - ValueError: When the data connections do not form a DAG. - """ - self.disconnect_run() - self._set_run_connections_and_starting_nodes_according_to_linear_dag() - # TODO: Replace this linear setup with something more powerful - - def _set_run_connections_and_starting_nodes_according_to_linear_dag(self): - # This is the most primitive sort of topological exploitation we can do - # It is not efficient if the nodes have executors and can run in parallel - try: - # Topological sorting ensures that all input dependencies have been - # executed before the node depending on them gets run - # The flattened part is just that we don't care about topological - # generations that are mutually independent (inefficient but easier for now) - execution_order = toposort_flatten(self.get_data_digraph()) - except CircularDependencyError as e: - raise ValueError( - f"Detected a cycle in the data flow topology, unable to automate the " - f"execution of non-DAGs: cycles found among {e.data}" - ) - - for i, label in enumerate(execution_order[:-1]): - next_node = execution_order[i + 1] - self.nodes[label] > self.nodes[next_node] - self.starting_nodes = [self.nodes[execution_order[0]]] - - def get_data_digraph(self) -> dict[str, set[str]]: - """ - Builds a directed graph of node labels based on data connections between nodes - directly owned by this composite -- i.e. does not worry about data connections - which are entirely internal to an owned sub-graph. - - Returns: - dict[str, set[str]]: A dictionary of nodes and the nodes they depend on for - data. - - Raises: - ValueError: When a node appears in its own input. + reconnect these according to the DAG flow of the data. 
On success, sets the + starting nodes to just be the upstream-most node in this linear DAG flow. """ - digraph = {} - - for node in self.nodes.values(): - node_dependencies = [] - for channel in node.inputs: - locally_scoped_dependencies = [] - for upstream in channel.connections: - if upstream.node.parent is self: - locally_scoped_dependencies.append(upstream.node.label) - elif channel.node.get_first_shared_parent(upstream.node) is self: - locally_scoped_dependencies.append( - upstream.node.get_parent_proximate_to(self).label - ) - node_dependencies.extend(locally_scoped_dependencies) - node_dependencies = set(node_dependencies) - if node.label in node_dependencies: - # the toposort library has a - # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) - # That self-dependency isn't caught, so we catch it manually here. - raise ValueError( - f"Detected a cycle in the data flow topology, unable to automate " - f"the execution of non-DAGs: {node.label} appears in its own input." - ) - digraph[node.label] = node_dependencies - - return digraph + _, upstream_most_node = set_run_connections_according_to_linear_dag(self.nodes) + self.starting_nodes = [upstream_most_node] def _build_io( self, diff --git a/pyiron_workflow/function.py b/pyiron_workflow/function.py index 34b85c91..aaea64b7 100644 --- a/pyiron_workflow/function.py +++ b/pyiron_workflow/function.py @@ -59,9 +59,9 @@ class Function(Node): Further, functions with multiple return branches that return different types or numbers of return values may or may not work smoothly, depending on the details. - Output is updated in the `process_run_result` inside the parent class `finish_run` - call, such that output data gets pushed after the node stops running but before - then `ran` signal fires: run, process and push result, ran. 
+ Output is updated according to `process_run_result` -- which gets invoked by the + post-run callbacks defined in `Node` -- such that run results are used to populate + the output channels. After a node is instantiated, its input can be updated as `*args` and/or `**kwargs` on call. @@ -103,7 +103,7 @@ class Function(Node): run: Parse and process the input, execute the engine, process the results and update the output. disconnect: Disconnect all data and signal IO connections. - update_input: Allows input channels' values to be updated without any running. + set_input_values: Allows input channels' values to be updated without any running. Examples: At the most basic level, to use nodes all we need to do is provide the @@ -173,9 +173,7 @@ class Function(Node): using good variable names and returning those variables instead of using `output_labels`. If we force the node to `run()` (or call it) with bad types, it will raise an - error. - But, if we use the gentler `update()`, it will check types first and simply - return `None` if the input is not all `ready`. + error: >>> from typing import Union >>> >>> def hinted_example( @@ -186,13 +184,17 @@ class Function(Node): ... return p1, m1 >>> >>> plus_minus_1 = Function(hinted_example, x="not an int") - >>> plus_minus_1.update() - >>> plus_minus_1.outputs.to_value_dict() - {'p1': , - 'm1': } + >>> plus_minus_1.run() + ValueError: hinted_example received a run command but is not ready. The node + should be neither running nor failed, and all input values should conform to + type hints: + running: False + failed: False + x ready: False + y ready: True Here, even though all the input has data, the node sees that some of it is the - wrong type and so the automatic updates don't proceed all the way to a run. + wrong type and so (by default) the run raises an error right away. 
Note that the type hinting doesn't actually prevent us from assigning bad values directly to the channel (although it will, by default, prevent connections _between_ type-hinted channels with incompatible hints), but it _does_ stop the @@ -333,7 +335,7 @@ def __init__( # TODO: Parse output labels from the node function in case output_labels is None self.signals = self._build_signal_channels() - self.update_input(*args, **kwargs) + self.set_input_values(*args, **kwargs) def _get_output_labels(self, output_labels: str | list[str] | tuple[str] | None): """ @@ -516,7 +518,7 @@ def _convert_input_args_and_kwargs_to_input_kwargs(self, *args, **kwargs): return kwargs - def update_input(self, *args, **kwargs) -> None: + def set_input_values(self, *args, **kwargs) -> None: """ Match positional and keyword arguments to input channels and update input values. @@ -527,7 +529,7 @@ def update_input(self, *args, **kwargs) -> None: pairs. """ kwargs = self._convert_input_args_and_kwargs_to_input_kwargs(*args, **kwargs) - return super().update_input(**kwargs) + return super().set_input_values(**kwargs) def __call__(self, *args, **kwargs) -> None: kwargs = self._convert_input_args_and_kwargs_to_input_kwargs(*args, **kwargs) diff --git a/pyiron_workflow/macro.py b/pyiron_workflow/macro.py index 1fb56756..e81d4975 100644 --- a/pyiron_workflow/macro.py +++ b/pyiron_workflow/macro.py @@ -183,7 +183,7 @@ def __init__( self._inputs: Inputs = self._build_inputs() self._outputs: Outputs = self._build_outputs() - self.update_input(**kwargs) + self.set_input_values(**kwargs) def _get_linking_channel( self, diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 1199d380..d2f7a78b 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -16,7 +16,7 @@ from pyiron_workflow.files import DirectoryObject from pyiron_workflow.has_to_dict import HasToDict from pyiron_workflow.io import Signals, InputSignal, OutputSignal -from pyiron_workflow.type_hinting import valid_value 
+from pyiron_workflow.topology import set_run_connections_according_to_linear_dag from pyiron_workflow.util import SeabornColors if TYPE_CHECKING: @@ -24,7 +24,7 @@ from pyiron_workflow.channels import Channel from pyiron_workflow.composite import Composite - from pyiron_workflow.io import IO, Inputs, Outputs + from pyiron_workflow.io import Inputs, Outputs def manage_status(node_method): @@ -80,8 +80,11 @@ class Node(HasToDict, ABC): These labels also help to identify nodes in the wider context of (potentially nested) computational graphs. - By default, nodes' signals input comes with `run` and `ran` IO ports which force - the `run()` method and which emit after `finish_run()` is completed, respectfully. + By default, nodes' signals input comes with `run` and `ran` IO ports, which invoke + the `run()` method and emit after running the node, respectively. + (Whether we get all the way to emitting the `ran` signal depends on how the node + was invoked -- it is possible to compute things with the node without sending + any more signals downstream.) These signal connections can be made manually by reference to the node signals channel, or with the `>` symbol to indicate a flow of execution. This syntactic sugar can be mixed between actual signal channels (output signal > input signal), @@ -101,8 +104,8 @@ class Node(HasToDict, ABC): Nodes have a status, which is currently represented by the `running` and `failed` boolean flag attributes. - Their value is controlled automatically in the defined `run` and `finish_run` - methods. + These are updated automatically when the node's operation is invoked, e.g. with + `run`, `execute`, `pull`, or by calling the node instance. Nodes can be run on the main python process that owns them, or by setting their `executor` attribute to `True`, in which case a @@ -140,6 +143,8 @@ class Node(HasToDict, ABC): owning this, if any. ready (bool): Whether the inputs are all ready and the node is neither already running nor already failed.
+ run_args (dict): **Abstract.** The arguments to use for actually running the + node. Must be specified in child classes. running (bool): Whether the node has called `run` and has not yet received output from this call. (Default is False.) signals (pyiron_workflow.io.Signals): A container for input and output @@ -152,11 +157,20 @@ class Node(HasToDict, ABC): initialized. Methods: + __call__: Update input values (optional) then run the node (without firing off + the `ran` signal, so nothing happens farther downstream). disconnect: Remove all connections, including signals. draw: Use graphviz to visualize the node, its IO and, if composite in nature, its internal structure. - on_run: **Abstract.** Do the thing. - run: A wrapper to handle all the infrastructure around executing `on_run`. + execute: Run the node, but right here, right now, and with the input it + currently has. + on_run: **Abstract.** Do the thing. What thing must be specified by child + classes. + pull: Run everything upstream, then run this node (but don't fire off the `ran` + signal, so nothing happens farther downstream). + run: Run the node function from `on_run`. Handles status, whether to run on an + executor, firing the `ran` signal, and callbacks (if an executor is used). + set_input_values: Allows input channels' values to be updated without any running. """ def __init__( @@ -182,8 +196,6 @@ def __init__( parent.add(self) self.running = False self.failed = False - # TODO: Move from a traditional "sever" to a tinybase "executor" - # TODO: Provide support for actually computing stuff with the executor self.signals = self._build_signal_channels() self._working_directory = None self.executor = False @@ -231,74 +243,68 @@ def process_run_result(self, run_output): run_output: The results of a `self.on_run(self.run_args)` call. """ - @manage_status - def execute(self): - """ - Perform the node's operation with its current data. - - Execution happens directly on this python process.
- """ - return self.process_run_result(self.on_run(**self.run_args)) - - def run(self): + def run( + self, + first_fetch_input: bool = True, + then_emit_output_signals: bool = True, + force_local_execution: bool = False, + check_readiness: bool = True, + ): """ Update the input (with whatever is currently available -- does _not_ trigger - any other nodes to run) and use it to perform the node's operation. + any other nodes to run) and use it to perform the node's operation. After, + emit all output signals. If executor information is specified, execution happens on that process, a callback is registered, and futures object is returned. - Once complete, fire `ran` signal to propagate execution in the computation graph - that owns this node (if any). - """ - self.update_input() - return self._run(finished_callback=self.finish_run_and_emit_ran) - - def pull(self): - raise NotImplementedError - # Need to implement everything for on-the-fly construction of the upstream - # graph and its execution - # Then, - self.update_input() - return self._run(finished_callback=self.finish_run) - - def update_input(self, **kwargs) -> None: - """ - Fetch the latest and highest-priority input values from connections, then - overwrite values with keywords arguments matching input channel labels. - - Any channel that has neither a connection nor a kwarg update at time of call is - left unchanged. - - Throws a warning if a keyword is provided that cannot be found among the input - keys. - - If you really want to update just a single value without any other side-effects, - this can always be accomplished by following the full semantic path to the - channel's value: `my_node.input.my_channel.value = "foo"`. - Args: - **kwargs: input key - input value (including channels for connection) pairs. - """ - self.inputs.fetch() - for k, v in kwargs.items(): - if k in self.inputs.labels: - self.inputs[k] = v - else: - warnings.warn( - f"The keyword '{k}' was not found among input labels. 
If you are " - f"trying to update a node keyword, please use attribute assignment " - f"directly instead of calling" - ) + first_fetch_input (bool): Whether to first update inputs with the + highest-priority connections holding data. (Default is True.) + then_emit_output_signals (bool): Whether to fire off all output signals + (e.g. `ran`) afterwards. (Default is True.) + force_local_execution (bool): Whether to ignore any executor settings and + force the computation to run locally. (Default is False.) + check_readiness (bool): Whether to raise an exception if the node is not + `ready` to run after fetching new input. (Default is True.) + + Returns: + (Any | Future): The result of running the node, or a futures object (if + running on an executor). + """ + if first_fetch_input: + self.inputs.fetch() + if check_readiness and not self.ready: + input_readiness = "\n".join( + [f"{k} ready: {v.ready}" for k, v in self.inputs.items()] + ) + raise ValueError( + f"{self.label} received a run command but is not ready. The node " + f"should be neither running nor failed, and all input values should" + f" conform to type hints:\n" + f"running: {self.running}\n" + f"failed: {self.failed}\n" + input_readiness + ) + return self._run( + finished_callback=self._finish_run_and_emit_ran + if then_emit_output_signals + else self._finish_run, + force_local_execution=force_local_execution, + ) @manage_status - def _run(self, finished_callback: callable) -> Any | tuple | Future: + def _run( + self, + finished_callback: callable, + force_local_execution: bool, + ) -> Any | tuple | Future: """ Executes the functionality of the node defined in `on_run`. Handles the status of the node, and communicating with any remote computing resources. 
""" - if not self.executor: + if force_local_execution or not self.executor: + # Run locally run_output = self.on_run(**self.run_args) return finished_callback(run_output) else: @@ -309,7 +315,7 @@ def _run(self, finished_callback: callable) -> Any | tuple | Future: self.future.add_done_callback(finished_callback) return self.future - def finish_run(self, run_output: tuple | Future) -> Any | tuple: + def _finish_run(self, run_output: tuple | Future) -> Any | tuple: """ Switch the node status, then process and return the run result. @@ -326,29 +332,118 @@ def finish_run(self, run_output: tuple | Future) -> Any | tuple: self.failed = True raise e - def finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple: - processed_output = self.finish_run(run_output) + def _finish_run_and_emit_ran(self, run_output: tuple | Future) -> Any | tuple: + processed_output = self._finish_run(run_output) self.signals.output.ran() return processed_output - finish_run_and_emit_ran.__doc__ = ( - finish_run.__doc__ + _finish_run_and_emit_ran.__doc__ = ( + _finish_run.__doc__ + """ - + Finally, fire the `ran` signal. """ ) + def execute(self): + """ + Run the node with whatever input it currently has, run it on this python + process, and don't emit the `ran` signal afterwards. + + Intended to be useful for debugging by just forcing the node to do its thing + right here, right now, and as-is. + """ + return self.run( + first_fetch_input=False, + then_emit_output_signals=False, + force_local_execution=True, + check_readiness=False, + ) + + def pull(self): + """ + Use topological analysis to build a tree of all upstream dependencies; run them + first, then run this node to get an up-to-date result. Does not trigger any + downstream executions. 
+ """ + label_map = {} + nodes = {} + for node in self.get_nodes_in_data_tree(): + modified_label = node.label + str(id(node)) + label_map[modified_label] = node.label + node.label = modified_label # Ensure each node has a unique label + # This is necessary when the nodes do not have a workflow and may thus have + # arbitrary labels. + # This is pretty ugly; it would be nice to not depend so heavily on labels. + # Maybe we could switch a bunch of stuff to rely on the unique ID? + nodes[modified_label] = node + disconnected_pairs, starter = set_run_connections_according_to_linear_dag(nodes) + try: + self.signals.disconnect_run() # Don't let anything upstream trigger this + starter.run() # Now push from the top + return self.run() # Finally, run here and return the result + # Emitting won't matter since we already disconnected this one + finally: + # No matter what, restore the original connections and labels afterwards + for modified_label, node in nodes.items(): + node.label = label_map[modified_label] + node.signals.disconnect_run() + for c1, c2 in disconnected_pairs: + c1.connect(c2) + + def get_nodes_in_data_tree(self) -> set[Node]: + """ + Get a set of all nodes from this one and upstream through data connections. + """ + nodes = set([self]) + for channel in self.inputs: + for connection in channel.connections: + nodes = nodes.union(connection.node.get_nodes_in_data_tree()) + return nodes + + def __call__(self, **kwargs) -> None: + """ + Update the input, then run without firing the `ran` signal. + + Note that since input fetching happens _after_ the input values are updated, + if there is a connected data value it will get used instead of what is specified + here. If you really want to set a particular state and then run this can be + accomplished with `.inputs.fetch()` then `.set_input_values(...)` then + `.execute()` (or `.run(...)` with the flags you want). 
+ + Args: + **kwargs: Keyword arguments matching input channel labels; used to update + the input before running. + """ + self.set_input_values(**kwargs) + return self.run() + + def set_input_values(self, **kwargs) -> None: + """ + Match keywords to input channels and update their values. + + Throws a warning if a keyword is provided that cannot be found among the input + keys. + + Args: + **kwargs: input key - input value (including channels for connection) pairs. + """ + for k, v in kwargs.items(): + if k in self.inputs.labels: + self.inputs[k] = v + else: + warnings.warn( + f"The keyword '{k}' was not found among input labels. If you are " + f"trying to update a node keyword, please use attribute assignment " + f"directly instead of calling" + ) + def _build_signal_channels(self) -> Signals: signals = Signals() signals.input.run = InputSignal("run", self, self.run) signals.output.ran = OutputSignal("ran", self) return signals - def update(self) -> Any | tuple | Future | None: - if self.ready: - return self.run() - @property def working_directory(self): if self._working_directory is None: @@ -389,10 +484,6 @@ def fully_connected(self): and self.signals.fully_connected ) - def __call__(self, **kwargs) -> None: - self.update_input(**kwargs) - return self.run() - @property def color(self) -> str: """A hex code color for use in drawing.""" diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py new file mode 100644 index 00000000..06d53631 --- /dev/null +++ b/pyiron_workflow/topology.py @@ -0,0 +1,156 @@ +""" +A submodule for getting our node classes talking nicely with an external tool for +topological analysis. Such analyses are useful for automating execution flows based on +data flow dependencies. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from toposort import toposort_flatten, CircularDependencyError + +if TYPE_CHECKING: + from pyiron_workflow.channels import SignalChannel + from pyiron_workflow.node import Node + + +def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: + """ + Maps a set of nodes to a digraph of their data dependency in the format of label + keys and set of label values for upstream nodes. + + Args: + nodes (dict[str, Node]): A label-keyed dictionary of nodes to convert into a + string-based dictionary of digraph connections based on data flow. + + Returns: + dict[str, set[str]]: A dictionary of nodes and the nodes they depend on for + data. + + Raises: + ValueError: When a node appears in its own input. + ValueError: If the nodes do not all have the same parent. + ValueError: If one of the nodes has an upstream data connection whose node has + a different parent. + """ + digraph = {} + + parent = next(iter(nodes.values())).parent # Just grab any one + if not all(n.parent is parent for n in nodes.values()): + raise ValueError( + "Nodes in a data digraph must all be siblings -- i.e. have the same " + "`parent` attribute." + ) + + for node in nodes.values(): + node_dependencies = [] + for channel in node.inputs: + locally_scoped_dependencies = [] + for upstream in channel.connections: + try: + upstream_node = nodes[upstream.node.label] + except KeyError as e: + raise KeyError( + f"The {channel.label} channel of {node.label} has a connection " + f"to {upstream.label} channel of {upstream.node.label}, but " + f"{upstream.node.label} was not found among nodes. All nodes " + f"in the data flow dependency tree must be included." 
+ ) + if upstream_node is not upstream.node: + raise ValueError( + f"The {channel.label} channel of {node.label} has a connection " + f"to {upstream.label} channel of {upstream.node.label}, but " + f"that channel's node is not the same as the nodes passed " + f"here. All nodes in the data flow dependency tree must be " + f"included." + ) + locally_scoped_dependencies.append(upstream.node.label) + node_dependencies.extend(locally_scoped_dependencies) + node_dependencies = set(node_dependencies) + if node.label in node_dependencies: + # the toposort library has a + # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) + # That self-dependency isn't caught, so we catch it manually here. + raise ValueError( + f"Detected a cycle in the data flow topology, unable to automate " + f"the execution of non-DAGs: {node.label} appears in its own input." + ) + digraph[node.label] = node_dependencies + + return digraph + + +def nodes_to_execution_order(nodes: dict[str, Node]) -> list[str]: + """ + Given a set of nodes that all have the same parent, returns a list of corresponding + node labels giving an execution order that guarantees the executing node always has + data from all its upstream nodes. + + Args: + nodes (dict[str, Node]): A label-keyed dictionary of nodes from whom to build + an execution order based on topological analysis of data flow. + + Returns: + (list[str]): The labels in safe execution order. 
+ + Raises: + ValueError: If the data dependency is not a Directed Acyclic Graph + """ + try: + # Topological sorting ensures that all input dependencies have been + # executed before the node depending on them gets run + # The flattened part is just that we don't care about topological + # generations that are mutually independent (inefficient but easier for now) + execution_order = toposort_flatten(nodes_to_data_digraph(nodes)) + except CircularDependencyError as e: + raise ValueError( + f"Detected a cycle in the data flow topology, unable to automate the " + f"execution of non-DAGs: cycles found among {e.data}" + ) + return execution_order + + +def set_run_connections_according_to_linear_dag( + nodes: dict[str, Node] +) -> tuple[list[tuple[SignalChannel, SignalChannel]], Node]: + """ + Given a set of nodes that all have the same parent, have no upstream data + connections outside the nodes provided, and have acyclic data flow, disconnects all + their `run` and `ran` signals, then sets these signals to a linear execution that + guarantees downstream nodes are always executed after upstream nodes. Returns one + of the upstream-most nodes. + + In the event an exception is encountered, any disconnected connections are repaired + before it is raised. + + Args: + nodes (dict[str, Node]): A dictionary of node labels and the node the label is + from, whose connections will be set according to data flow. + + Returns: + (list[tuple[SignalChannel, SignalChannel]]): Any `run`/`ran` pairs that were + disconnected. + (Node): The 0th node in the execution order, i.e. one that has no + dependencies.
+ """ + disconnected_pairs = [] + for node in nodes.values(): + disconnected_pairs.extend(node.signals.disconnect_run()) + disconnected_pairs.extend(node.signals.output.ran.disconnect_all()) + + try: + # This is the most primitive sort of topological exploitation we can do + # It is not efficient if the nodes have executors and can run in parallel + execution_order = nodes_to_execution_order(nodes) + + for i, label in enumerate(execution_order[:-1]): + next_node = execution_order[i + 1] + nodes[label] > nodes[next_node] + + return disconnected_pairs, nodes[execution_order[0]] + except Exception as e: + # Restore whatever you broke + for c1, c2 in disconnected_pairs: + c1.connect(c2) + raise e diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 5817fa7e..36eee604 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -203,7 +203,13 @@ def inputs(self) -> Inputs: def outputs(self) -> Outputs: return self._build_outputs() - def run(self): + def run( + self, + first_fetch_input: bool = True, + then_emit_output_signals: bool = True, + force_local_execution: bool = False, + check_readiness: bool = True, + ): if self.automate_execution: self.set_run_signals_to_dag_execution() return super().run() diff --git a/tests/integration/test_pull.py b/tests/integration/test_pull.py new file mode 100644 index 00000000..75e85aca --- /dev/null +++ b/tests/integration/test_pull.py @@ -0,0 +1,84 @@ +import unittest + +from pyiron_workflow.workflow import Workflow + + +class TestPullingOutput(unittest.TestCase): + def test_without_workflow(self): + from pyiron_workflow import Workflow + + @Workflow.wrap_as.single_value_node("sum") + def x_plus_y(x: int = 0, y: int = 0) -> int: + return x + y + + node = x_plus_y( + x=x_plus_y(0, 1), + y=x_plus_y(2, 3) + ) + self.assertEqual(6, node.pull()) + + for n in [ + node, + node.inputs.x.connections[0].node, + node.inputs.y.connections[0].node, + ]: + self.assertFalse( + n.signals.connected, + 
msg="Connections should be unwound after the pull is done" + ) + self.assertEqual( + "x_plus_y", + n.label, + msg="Original labels should be restored after the pull is done" + ) + + def test_pulling_from_inside_a_macro(self): + @Workflow.wrap_as.single_value_node("sum") + def x_plus_y(x: int = 0, y: int = 0) -> int: + # print("EXECUTING") + return x + y + + @Workflow.wrap_as.macro_node() + def b2_leaves_a1_alone(macro): + macro.a1 = x_plus_y(0, 0) + macro.a2 = x_plus_y(0, 1) + macro.b1 = x_plus_y(macro.a1, macro.a2) + macro.b2 = x_plus_y(macro.a2, 10) + + wf = Workflow("demo") + wf.upstream = x_plus_y() + wf.macro = b2_leaves_a1_alone(a2__x=wf.upstream) + + # Pulling b1 -- executes a1, a2, b1 + self.assertEqual(1, wf.macro.b1.pull()) + # >>> EXECUTING + # >>> EXECUTING + # >>> EXECUTING + # >>> 1 + + # Pulling b2 -- executes a2, b2 + self.assertEqual(11, wf.macro.b2.pull()) + # >>> EXECUTING + # >>> EXECUTING + # >>> 11 + + # Updated inputs get reflected in the pull + wf.macro.set_input_values(a1__x=100, a2__x=-100) + self.assertEqual(-89, wf.macro.b2.pull()) + # >>> EXECUTING + # >>> EXECUTING + # >>> -89 + + # Connections are restored after a pull + # Crazy negative value of a2 gets written over by pulling in the upstream + # connection value + # Running wf -- executes upstream, macro (is silent), a1, a2, b1, b2 + out = wf() + self.assertEqual(101, out.macro__b1__sum) + self.assertEqual(11, out.macro__b2__sum) + # >>> EXECUTING + # >>> EXECUTING + # >>> EXECUTING + # >>> EXECUTING + # >>> EXECUTING + # >>> {'macro__b1__sum': 101, 'macro__b2__sum': 11} \ No newline at end of file diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 4f295899..28b8fbb8 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -286,7 +286,7 @@ def with_self(self, x: float) -> float: msg="Expected 'self' to be filtered out of node input, but found it in the " "input labels" ) - node.inputs.x = 1 + node.inputs.x = 1.0 node.run()
self.assertEqual( node.outputs.output.value, @@ -489,7 +489,12 @@ def all_floats(x=1.1, y=1.1, z=1.1, omega=NotData, extra_there=None) -> float: ref = reference() floats = all_floats() ref() - floats() + floats.run( + check_readiness=False, + # We force-skip the readiness check since we are explicitly _trying_ to + # have one of the inputs be `NotData` -- a value which triggers the channel + # to be "not ready" + ) ref._copy_values(floats) self.assertEqual(