Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an 'official' caller for batch UDFs #632

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ markers = [
"vcf: VCF tests that run on TileDB Cloud",
]
norecursedirs = ["tiledb/cloud"]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)1s] -- %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%d-%b-%y %H:%M:%S"

[tool.setuptools]
zip-safe = false
Expand Down
89 changes: 89 additions & 0 deletions src/tiledb/cloud/dag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import datetime
import itertools
import json
import logging
import numbers
import re
import threading
import time
import uuid
import warnings
import webbrowser
from typing import (
Any,
Callable,
Expand All @@ -17,6 +19,7 @@
FrozenSet,
Hashable,
List,
Mapping,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -46,6 +49,8 @@
from . import visualization as viz
from .mode import Mode

logger = logging.getLogger(__name__)
Copy link
Contributor

@JohnMoutafis JohnMoutafis Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For uniformity reasons, I prefer that we use tiledb.cloud.utilities.get_logger_wrapper for logging everywhere.

The method also allows for a verbosity level (verbose=True/False) that should be set from the UDF's arguments (as is the case with the as_batch method) so the logger can be declared inside the exec_batch_udf and the verbose should be "grabbed" from the kwargs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JohnMoutafis I agree, but we have a circular import problem because run_dag is in the utilities module.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spencerseale @JohnMoutafis are we still stuck here? Do I understand correctly that we don't have a circular import yet, but will when run_dag() calls this new function?

If we refactored and moved get_logger_wrapper() to, for example, tiledb.cloud.logging, that would eliminate the potential circular import, yes? I'm willing to do that work.


Status = st.Status # Re-export for compabitility.
_T = TypeVar("_T")
# Special string included in server errors when there is a problem loading
Expand Down Expand Up @@ -2082,3 +2087,87 @@ def array_task_status_to_status(status: models.ArrayTaskStatus) -> Status:

def task_graph_log_status_to_status(status: models.TaskGraphLogStatus) -> Status:
return _TASK_GRAPH_LOG_STATUS_TO_STATUS_MAP.get(status, Status.NOT_STARTED)


def exec_batch_udf(
func: Union[callable, str],
*args: Any,
name: Optional[str] = None,
namespace: Optional[str] = None,
acn: Optional[str] = None,
resources: Optional[Mapping[str, str]] = None,
image_name: Optional[str] = None,
compute: bool = True,
retry_limit: int = 0,
open_browser: bool = False,
**kwargs: Any,
) -> DAG:
"""Run a function as a batch UDF on TileDB Cloud.

A batch UDF is a single task, task graph run in `dag.Mode.BATCH` mode. This
allows specifying custom resources to a UDF and passing in an access credential
name for accessing underingly storage backends with `tiledb.VFS`

:param func: Name of registered UDF (e.g. <namespace>/<registered_name>)
or in-memory callable.
:param **args: Positional args to pass to batch UDF.
:param name: Task name.
:param namespace: TileDB Cloud namespace to execute in.
:param acn: TileDB Cloud access credential name.
:param resources: Resources to allocate to task (e.g.
{"cpu": "2", "memory": "10Gi"}).
:param image_name: UDF image name.
:param compute: Whether to execute batch UDF.
:param retry_limit: Maximum retry attempts.
:param open_browser: Whether to open browser to batch UDF.
:param **kwargs: Keyword args to pass to batch UDF.
:return: DAG instance, either running or not started depending on `compute` arg.
"""

try:
name = name or func.__name__
except AttributeError:
name = func

# extract acn from deprecated 'access_credentials_name' only if not found in acn
acn_legacy = kwargs.pop("access_credentials_name", None)
acn = acn or acn_legacy

graph = DAG(
name=f"batch->{name}",
namespace=namespace,
mode=Mode.BATCH,
retry_strategy=models.RetryStrategy(
limit=retry_limit,
retry_policy="Always",
),
)

graph.submit(
func,
*args,
name=name,
access_credentials_name=acn,
resources=resources,
image_name=image_name,
**kwargs,
)

if compute:
graph.compute()

task_uri = "https://cloud.tiledb.com/activity/taskgraphs/{}/{}".format(
graph.namespace,
graph.server_graph_uuid,
)

logger.info(f"TileDB Cloud task submitted - {task_uri}")

if open_browser:
try:
webbrowser.open_new_tab(task_uri)
except webbrowser.Error:
pass
logger.debug("Unable to access webrowser.")

return graph
116 changes: 116 additions & 0 deletions tests/test_dag_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Pytest-based tests for tiledb.cloud.dag.dag"""

from unittest.mock import MagicMock
from unittest.mock import patch
from webbrowser import Error

from attrs import define

from tiledb.cloud.client import default_user
from tiledb.cloud.dag.dag import exec_batch_udf
from tiledb.cloud.dag.mode import Mode

_TASK_NAME = "unittest-test-dag-exec-batch-udf"
_NAMESPACE = default_user().username


@define
class ExecutableLoader:
arg: str = "arg1"
registered_udf: str = "TileDB-Inc/ls_uri"

def in_memory(self, arg: str):
return arg

@property
def registered(self):
return self.registered_udf

def all_exec(self):
return (
self.in_memory,
self.registered,
)


@patch("tiledb.cloud.dag.dag.webbrowser.open_new_tab")
@patch("tiledb.cloud.dag.dag.DAG")
def test_exec_batch_udf_mock(mock_dag: MagicMock, mock_open_new_tab: MagicMock) -> None:
"""Test procedure of exec_batch_udf.

This test is concerned only if proper logic is engaged based on args and exceptions.

Additionally by passing both an in-memory callable and a str referencing
a registered UDF, checks that the name of the submitted node is set properly
and no AttributeError thrown.
"""

mock_dag_inst = mock_dag.return_value

loader = ExecutableLoader()

expected_submit_call_count = 0
for callable_to_test in loader.all_exec():
expected_submit_call_count += 1

graph = exec_batch_udf(
callable_to_test,
loader.arg,
compute=False,
)

assert mock_dag_inst.submit.call_count == expected_submit_call_count
assert mock_dag_inst.compute.call_count == 0
assert isinstance(graph, MagicMock)

# checking logic associated with open_browser == True
# ensure 'except' block hits when trying to open webbrowser
mock_open_new_tab.side_effect = Error()

graph = exec_batch_udf(
loader.in_memory,
loader.arg,
compute=True,
open_browser=True,
)

assert mock_dag_inst.submit.call_count == expected_submit_call_count + 1
assert mock_dag_inst.compute.call_count == 1
assert mock_open_new_tab.called # test webbrowser attempted to open


def test_exec_batch_udf() -> None:
"""Test actual loading of DAG.

Concerned primarily that DAG is instantiated appropriately as
specified by exec_batch_udf.

Previous unit test for exec_batch_udf already tested compute method is called
when the 'compute' arg is True. So not actually executing batch UDF, as
DAG.compute is tested elsewhere, outside of scope of these tests.

Does not test whether the registered UDF actually exists, just that
a str passed is acceptable. It is on the user to ensure registered
UDF exists.
"""

loader = ExecutableLoader()

# test multiple retry limit settings
for retry_count, callable_to_test in enumerate(loader.all_exec()):
graph = exec_batch_udf(
callable_to_test,
loader.arg,
name=_TASK_NAME,
namespace=_NAMESPACE,
retry_limit=retry_count,
compute=False,
)

assert graph.name == f"batch->{_TASK_NAME}"
assert graph.namespace == _NAMESPACE
assert graph.mode == Mode.BATCH
assert graph.retry_strategy.retry_policy.lower() == "always"
assert graph.retry_strategy.limit == retry_count
assert len(graph.nodes) == 1
assert graph.status.name.lower() == "not_started"
Loading