Skip to content

Commit

Permalink
Parse electron function strings at Electron construction (#1926)
Browse files Browse the repository at this point in the history
* Parse function string in Electron constructor

Waiting until `build_graph()` won't work if the graph is being built
in a remote executor without the original source code.

* Changelog

---------

Co-authored-by: Santosh kumar <[email protected]>
  • Loading branch information
cjao and santoshkumarradha authored Feb 28, 2024
1 parent 170f4d6 commit e6ad647
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 48 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### Fixed

- Sublattice electron function strings are now parsed correctly

### Operations

- Fixed nightly workflow's calling of other workflows.
Expand Down
31 changes: 3 additions & 28 deletions covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
self.metadata = metadata
self.task_group_id = task_group_id
self._packing_tasks = packing_tasks
self._function_string = get_serialized_function_str(function)

@property
def packing_tasks(self) -> bool:
Expand Down Expand Up @@ -442,7 +443,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]:
)

name = sublattice_prefix + self.function.__name__
function_string = get_serialized_function_str(self.function)
function_string = self._function_string
bound_electron = sub_electron(
self.function, json.dumps(parent_metadata), *args, **kwargs
)
Expand All @@ -463,7 +464,7 @@ def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]:
name=self.function.__name__,
function=self.function,
metadata=self.metadata.copy(),
function_string=get_serialized_function_str(self.function),
function_string=self._function_string,
task_group_id=self.task_group_id if self.packing_tasks else None,
)
self.task_group_id = self.task_group_id if self.packing_tasks else self.node_id
Expand Down Expand Up @@ -608,32 +609,6 @@ def _auto_dict_node(*args, **kwargs):
arg_index=arg_index,
)

def add_collection_node_to_graph(self, graph: "_TransportGraph", prefix: str) -> int:
"""
Adds the node to lattice's transport graph in the case
where a collection of electrons is passed as an argument
to another electron.
Args:
graph: Transport graph of the lattice
prefix: Prefix of the node
Returns:
node_id: Node id of the added node
"""

new_metadata = encode_metadata(DEFAULT_METADATA_VALUES.copy())
if "executor" in self.metadata:
new_metadata["executor"] = self.metadata["executor"]
new_metadata["executor_data"] = self.metadata["executor_data"]

node_id = graph.add_node(
name=prefix,
function=to_decoded_electron_collection,
metadata=new_metadata,
function_string=get_serialized_function_str(to_decoded_electron_collection),
)

return node_id

def wait_for(self, electrons: Union["Electron", Iterable["Electron"]]):
Expand Down
21 changes: 1 addition & 20 deletions tests/covalent_tests/workflow/electron_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
to_decoded_electron_collection,
)
from covalent._workflow.lattice import Lattice
from covalent._workflow.transport import TransportableObject, _TransportGraph, encode_metadata
from covalent._workflow.transport import TransportableObject, encode_metadata
from covalent.executor.executor_plugins.local import LocalExecutor


Expand Down Expand Up @@ -252,25 +252,6 @@ def test_collection_node_helper_electron():
assert to_decoded_electron_collection(x=dict_collection) == {"a": 1, "b": 2}


def test_electron_add_collection_node():
"""Test `to_decoded_electron_collection` in `Electron.add_collection_node`"""

def f(x):
return x

e = Electron(f)
tg = _TransportGraph()
node_id = e.add_collection_node_to_graph(tg, prefix=":")
collection_fn = tg.get_node_value(node_id, "function").get_deserialized()

collection = [
TransportableObject.make_transportable(1),
TransportableObject.make_transportable(2),
]

assert collection_fn(x=collection) == [1, 2]


def test_injected_inputs_are_not_in_tg():
"""Test that arguments to electrons injected by calldeps aren't
added to the transport graph"""
Expand Down

0 comments on commit e6ad647

Please sign in to comment.