Skip to content

Commit

Permalink
Use replacing visitor so we actually descend into batch args. (#499)
Browse files Browse the repository at this point in the history
When server-side ("batch") DAG execution was added, the implementation
used custom code for finding parent nodes, rather than the existing
visitor system. As a result, a Node was only recognized when it was passed
directly as a top-level argument; a Node nested inside another value (for
example, within a list or a dict) did not work.

This change replaces it with a visitor-based system that allows users to
pass in nodes just as they would in client-side execution.
  • Loading branch information
thetorpedodog authored Jan 5, 2024
1 parent 52760b5 commit a82212b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 59 deletions.
95 changes: 36 additions & 59 deletions src/tiledb/cloud/dag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .._results import results
from .._results import stored_params
from .._results import tiledb_json
from .._results import types
from ..rest_api import models
from ..sql import _execution as _sql_exec
from ..taskgraphs import _results as _tg_results
Expand Down Expand Up @@ -1525,67 +1526,26 @@ def _build_batch_taskgraph(self):
node_jsons = []
for node in topo_sorted_nodes:
kwargs = {}
if callable(node.args[0]):
kwargs["executable_code"] = codecs.PickleCodec.encode_base64(
node.args[0]
)
kwargs["source_text"] = functions.getsourcelines(node.args[0])
node_args = list(node.args)
# XXX: This is subtly different from the way functions are handled
# when coordinated locally ("realtime").
if callable(node_args[0]):
func = node_args.pop(0)
kwargs["executable_code"] = codecs.PickleCodec.encode_base64(func)
kwargs["source_text"] = functions.getsourcelines(func)
if type(node.args[0]) == str:
kwargs["registered_udf_name"] = node.args[0]

args = []
i = 0
for arg in node.args:
i += 1
# Skip if first arg is function
if i == 1 and callable(arg):
continue
# Skip if first arg is registered udf name
if i == 1 and type(arg) == str and "registered_udf_name" in kwargs:
continue
if isinstance(arg, Node):
if node._expand_node_output:
args.append(
models.TGUDFArgument(value="{{inputs.parameters.partId}}")
)
else:
args.append(
models.TGUDFArgument(
value={
"__tdbudf__": "node_output",
"client_node_id": str(arg.id),
}
)
)
else:
esc = tiledb_json.Encoder()
args.append(models.TGUDFArgument(value=esc.visit(arg)))

for name, arg in node.kwargs.items():
if name in _SKIP_BATCH_UDF_KWARGS:
continue
elif isinstance(arg, Node):
if node._expand_node_output:
args.append(
models.TGUDFArgument(
name=name, value="{{inputs.parameters.partId}}"
)
)
else:
args.append(
models.TGUDFArgument(
name=name,
value={
"__tdbudf__": "node_output",
"client_node_id": str(arg.id),
},
)
)
else:
esc = tiledb_json.Encoder()
args.append(models.TGUDFArgument(name=name, value=esc.visit(arg)))
func = node_args.pop(0)
kwargs["registered_udf_name"] = func

kwargs["arguments"] = args
filtered_node_kwargs = {
name: val
for name, val in node.kwargs.items()
if name not in _SKIP_BATCH_UDF_KWARGS
}

all_args = types.Arguments(node_args, filtered_node_kwargs)
encoder = _BatchArgEncoder(input_is_expanded=bool(node._expand_node_output))
kwargs["arguments"] = encoder.encode_arguments(all_args)

env_dict = {
"language": models.UDFLanguage.PYTHON,
Expand Down Expand Up @@ -1853,6 +1813,23 @@ def maybe_replace(self, arg) -> Optional[visitor.Replacement]:
return visitor.Replacement(arg.result())


class _BatchArgEncoder(tiledb_json.Encoder):
    """Argument encoder that emits the placeholders used by batch graphs.

    Behaves like :class:`tiledb_json.Encoder`, except that every
    :class:`Node` encountered while visiting the arguments is swapped for a
    placeholder value instead of being encoded:

    * when ``input_is_expanded`` is true, the template string
      ``"{{inputs.parameters.partId}}"``;
    * otherwise, a ``{"__tdbudf__": "node_output", ...}`` dict carrying the
      stringified ID of the referenced node.

    :param input_is_expanded: Selects which of the two placeholder forms
        replaces a :class:`Node`.
    """

    def __init__(self, input_is_expanded: bool) -> None:
        # Stored before delegating to the base initializer, as in the
        # original ordering.
        self._input_is_expanded = input_is_expanded
        super().__init__()

    def maybe_replace(self, arg: object) -> Optional[visitor.Replacement]:
        """Return the batch placeholder for a Node, else defer to the base."""
        # Anything that is not a Node gets the stock encoder treatment.
        if not isinstance(arg, Node):
            return super().maybe_replace(arg)

        if self._input_is_expanded:
            placeholder = "{{inputs.parameters.partId}}"
        else:
            placeholder = {
                "__tdbudf__": "node_output",
                "client_node_id": str(arg.id),
            }
        return visitor.Replacement(placeholder)


class _NodeResultReplacer(visitor.ReplacingVisitor):
"""Replaces :class:`Node`s with their results."""

Expand Down
11 changes: 11 additions & 0 deletions tests/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,17 @@ def print_result(x):
d.wait(300)
self.assertEqual(print_node.result(), [99.0, 299.0, 499.0, 699.0])

def test_param_replacement(self):
    """Nodes nested inside list and dict arguments work in batch mode."""
    graph = dag.DAG(mode=Mode.BATCH)

    # A plain node, then one that receives it wrapped in a list, then one
    # that receives that second node as a dict value.
    in_node = graph.submit(lambda x: "out" + x[2:], "input")
    wrap_node = graph.submit(repr, [in_node])
    dict_node = graph.submit(lambda d: tuple(d.items()), {"wrapped": wrap_node})

    graph.compute()
    graph.wait(300)

    # Checked in submission order.
    expectations = (
        (in_node, "output"),
        (wrap_node, "['output']"),
        (dict_node, (("wrapped", "['output']"),)),
    )
    for node, expected in expectations:
        self.assertEqual(node.result(), expected)

def test_batch_dag_retries(self):
def random_failure():
import random
Expand Down

0 comments on commit a82212b

Please sign in to comment.