From e6ad647673171bbf387d37af0e822df00fb7e1aa Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Tue, 27 Feb 2024 21:51:45 -0500 Subject: [PATCH] Parse electron function strings at Electron construction (#1926) * Parse function string in Electron constructor Waiting until `build_graph()` won't work if the graph is being built in a remote executor without the original source code. * Changelog --------- Co-authored-by: Santosh kumar <29346072+santoshkumarradha@users.noreply.github.com> --- CHANGELOG.md | 4 +++ covalent/_workflow/electron.py | 31 ++----------------- .../covalent_tests/workflow/electron_test.py | 21 +------------ 3 files changed, 8 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a928eabf..17344263c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Fixed + +- Sublattice electron function strings are now parsed correctly + ### Operations - Fixed nightly workflow's calling of other workflows. diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 12f18cbf5..a8a9055be 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -96,6 +96,7 @@ def __init__( self.metadata = metadata self.task_group_id = task_group_id self._packing_tasks = packing_tasks + self._function_string = get_serialized_function_str(function) @property def packing_tasks(self) -> bool: @@ -442,7 +443,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]: ) name = sublattice_prefix + self.function.__name__ - function_string = get_serialized_function_str(self.function) + function_string = self._function_string bound_electron = sub_electron( self.function, json.dumps(parent_metadata), *args, **kwargs ) @@ -463,7 +464,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]: name=self.function.__name__, function=self.function, metadata=self.metadata.copy(), - function_string=get_serialized_function_str(self.function), + function_string=self._function_string, task_group_id=self.task_group_id if self.packing_tasks else None, ) self.task_group_id = self.task_group_id if self.packing_tasks else self.node_id @@ -608,32 +609,6 @@ def _auto_dict_node(*args, **kwargs): arg_index=arg_index, ) - def add_collection_node_to_graph(self, graph: "_TransportGraph", prefix: str) -> int: - """ - Adds the node to lattice's transport graph in the case - where a collection of electrons is passed as an argument - to another electron. - - Args: - graph: Transport graph of the lattice - prefix: Prefix of the node - - Returns: - node_id: Node id of the added node - """ - - new_metadata = encode_metadata(DEFAULT_METADATA_VALUES.copy()) - if "executor" in self.metadata: - new_metadata["executor"] = self.metadata["executor"] - new_metadata["executor_data"] = self.metadata["executor_data"] - - node_id = graph.add_node( - name=prefix, - function=to_decoded_electron_collection, - metadata=new_metadata, - function_string=get_serialized_function_str(to_decoded_electron_collection), - ) - return node_id def wait_for(self, electrons: Union["Electron", Iterable["Electron"]]): diff --git a/tests/covalent_tests/workflow/electron_test.py b/tests/covalent_tests/workflow/electron_test.py index a3db76bb7..d7a3e3192 100644 --- a/tests/covalent_tests/workflow/electron_test.py +++ b/tests/covalent_tests/workflow/electron_test.py @@ -36,7 +36,7 @@ to_decoded_electron_collection, ) from covalent._workflow.lattice import Lattice -from covalent._workflow.transport import TransportableObject, _TransportGraph, encode_metadata +from covalent._workflow.transport import TransportableObject, encode_metadata from covalent.executor.executor_plugins.local import LocalExecutor @@ -252,25 +252,6 @@ def test_collection_node_helper_electron(): assert to_decoded_electron_collection(x=dict_collection) == {"a": 1, "b": 2} -def test_electron_add_collection_node(): - """Test `to_decoded_electron_collection` in `Electron.add_collection_node`""" - - def f(x): - return x - - e = Electron(f) - tg = _TransportGraph() - node_id = e.add_collection_node_to_graph(tg, prefix=":") - collection_fn = tg.get_node_value(node_id, "function").get_deserialized() - - collection = [ - TransportableObject.make_transportable(1), - TransportableObject.make_transportable(2), - ] - - assert collection_fn(x=collection) == [1, 2] - - def test_injected_inputs_are_not_in_tg(): """Test that arguments to electrons injected by calldeps aren't added to the transport graph"""