Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instrument task runs #15955

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ec85769
WIP:instrument task runs
jeanluciano Nov 8, 2024
f3ae6b3
async engine fix
jeanluciano Nov 11, 2024
c03c73c
Merge branch 'main' of https://github.com/PrefectHQ/prefect into jean…
jeanluciano Nov 11, 2024
943cfe6
more unit tests
jeanluciano Nov 12, 2024
f8f8d09
in memory exporter
jeanluciano Nov 12, 2024
4f4d683
typo fix
jeanluciano Nov 12, 2024
8b62c09
using instrumentation fixture
jeanluciano Nov 13, 2024
a2e19df
all but labels test
jeanluciano Nov 14, 2024
781d0f4
labels function
jeanluciano Nov 14, 2024
f95d35a
keyvalue labels class
jeanluciano Nov 14, 2024
6c83dfc
test refractor
jeanluciano Nov 14, 2024
cc19dc5
Merge branch 'main' of https://github.com/PrefectHQ/prefect into jean…
jeanluciano Nov 14, 2024
5553d21
Merge branch 'main' into jean/cloud-565-task-run-instrumentation
jeanluciano Nov 15, 2024
3dc228e
Merge branch 'main' into jean/cloud-565-task-run-instrumentation
jeanluciano Nov 15, 2024
26c6d2a
Merge branch 'main' into jean/cloud-565-task-run-instrumentation
jeanluciano Nov 18, 2024
c63a0aa
Refactor telemetry to its own class
jeanluciano Nov 19, 2024
440456d
Merge branch 'main' into jean/cloud-565-task-run-instrumentation
jeanluciano Nov 19, 2024
c2afe1f
Unit test sync and async engines
jeanluciano Nov 19, 2024
8544e16
Merge branch 'jean/cloud-565-task-run-instrumentation' of https://git…
jeanluciano Nov 19, 2024
ce45c77
Merge branch 'main' of https://github.com/PrefectHQ/prefect into jean…
jeanluciano Nov 22, 2024
ed90df8
labels no longer mocked
jeanluciano Nov 22, 2024
644b602
Merge branch 'main' into jean/cloud-565-task-run-instrumentation
jeanluciano Nov 22, 2024
ae585e4
run telemetry moved to its own file
jeanluciano Nov 22, 2024
2619f16
Merge branch 'jean/cloud-565-task-run-instrumentation' of https://git…
jeanluciano Nov 22, 2024
30ed2bc
type checking guard
jeanluciano Nov 22, 2024
c752485
type fix
jeanluciano Nov 22, 2024
03317e2
import removed
jeanluciano Nov 22, 2024
8f87b55
scope fix
jeanluciano Nov 22, 2024
6f1cbfb
labels functin removed
jeanluciano Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import anyio
import pendulum
from opentelemetry import trace
from typing_extensions import ParamSpec

from prefect import Task
Expand Down Expand Up @@ -79,6 +80,7 @@
exception_to_failed_state,
return_value_to_state,
)
from prefect.telemetry.run_telemetry import RunTelemetry
from prefect.transactions import IsolationLevel, Transaction, transaction
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import run_coro_as_sync
Expand All @@ -99,6 +101,12 @@
BACKOFF_MAX = 10


def get_labels_from_context(context: Optional[FlowRunContext]) -> Dict[str, Any]:
if context is None:
return {}
return context.flow_run.labels
jeanluciano marked this conversation as resolved.
Show resolved Hide resolved


Comment on lines +104 to +109
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_labels_from_context(context: Optional[FlowRunContext]) -> Dict[str, Any]:
if context is None:
return {}
return context.flow_run.labels

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now unused

class TaskRunTimeoutError(TimeoutError):
"""Raised when a task run exceeds its timeout."""

Expand All @@ -120,6 +128,7 @@ class BaseTaskRunEngine(Generic[P, R]):
_is_started: bool = False
_task_name_set: bool = False
_last_event: Optional[PrefectEvent] = None
_telemetry: RunTelemetry = field(default_factory=RunTelemetry)

def __post_init__(self):
if self.parameters is None:
Expand Down Expand Up @@ -460,7 +469,7 @@ def set_state(self, state: State, force: bool = False) -> State:
validated_state=self.task_run.state,
follows=self._last_event,
)

self._telemetry.update_state(new_state)
return new_state

def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -514,6 +523,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
self.record_terminal_state_timing(terminal_state)
self.set_state(terminal_state)
self._return_value = result

self._telemetry.end_span_on_success(terminal_state.message)
return result

def handle_retry(self, exc: Exception) -> bool:
Expand Down Expand Up @@ -562,6 +573,7 @@ def handle_retry(self, exc: Exception) -> bool:

def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
self._telemetry.record_exception(exc)
if not self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
state = run_coro_as_sync(
Expand All @@ -575,6 +587,7 @@ def handle_exception(self, exc: Exception) -> None:
self.record_terminal_state_timing(state)
self.set_state(state)
self._raised = exc
self._telemetry.end_span_on_failure(state.message)

def handle_timeout(self, exc: TimeoutError) -> None:
if not self.handle_retry(exc):
Expand All @@ -598,6 +611,8 @@ def handle_crash(self, exc: BaseException) -> None:
self.record_terminal_state_timing(state)
self.set_state(state, force=True)
self._raised = exc
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message)

@contextmanager
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
Expand Down Expand Up @@ -655,14 +670,17 @@ def initialize_run(
with SyncClientContext.get_or_create() as client_ctx:
self._client = client_ctx.client
self._is_started = True
flow_run_context = FlowRunContext.get()
parent_task_run_context = TaskRunContext.get()

try:
if not self.task_run:
self.task_run = run_coro_as_sync(
self.task.create_local_run(
id=task_run_id,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
parent_task_run_context=TaskRunContext.get(),
flow_run_context=flow_run_context,
parent_task_run_context=parent_task_run_context,
wait_for=self.wait_for,
extra_task_inputs=dependencies,
)
Expand All @@ -679,6 +697,13 @@ def initialize_run(
self.logger.debug(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
)
labels = (
flow_run_context.flow_run.labels if flow_run_context else {}
)
self._telemetry.start_span(
self.task_run, self.parameters, labels
)

yield self

except TerminationSignal as exc:
Expand Down Expand Up @@ -730,11 +755,12 @@ def start(
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
) -> Generator[None, None, None]:
with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
self.begin_run()
try:
yield
finally:
self.call_hooks()
with trace.use_span(self._telemetry._span):
self.begin_run()
try:
yield
finally:
self.call_hooks()

@contextmanager
def transaction_context(self) -> Generator[Transaction, None, None]:
Expand Down Expand Up @@ -977,6 +1003,7 @@ async def set_state(self, state: State, force: bool = False) -> State:
follows=self._last_event,
)

self._telemetry.update_state(new_state)
return new_state

async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -1025,6 +1052,9 @@ async def handle_success(self, result: R, transaction: Transaction) -> R:
self.record_terminal_state_timing(terminal_state)
await self.set_state(terminal_state)
self._return_value = result

self._telemetry.end_span_on_success(terminal_state.message)

return result

async def handle_retry(self, exc: Exception) -> bool:
Expand Down Expand Up @@ -1073,6 +1103,7 @@ async def handle_retry(self, exc: Exception) -> bool:

async def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
self._telemetry.record_exception(exc)
if not await self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
state = await exception_to_failed_state(
Expand All @@ -1084,7 +1115,10 @@ async def handle_exception(self, exc: Exception) -> None:
await self.set_state(state)
self._raised = exc

self._telemetry.end_span_on_failure(state.message)

async def handle_timeout(self, exc: TimeoutError) -> None:
self._telemetry.record_exception(exc)
if not await self.handle_retry(exc):
if isinstance(exc, TaskRunTimeoutError):
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
Expand All @@ -1098,6 +1132,7 @@ async def handle_timeout(self, exc: TimeoutError) -> None:
)
await self.set_state(state)
self._raised = exc
self._telemetry.end_span_on_failure(state.message)

async def handle_crash(self, exc: BaseException) -> None:
state = await exception_to_crashed_state(exc)
Expand All @@ -1107,6 +1142,9 @@ async def handle_crash(self, exc: BaseException) -> None:
await self.set_state(state, force=True)
self._raised = exc

self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message)

@asynccontextmanager
async def setup_run_context(self, client: Optional[PrefectClient] = None):
from prefect.utilities.engine import (
Expand Down Expand Up @@ -1162,12 +1200,14 @@ async def initialize_run(
async with AsyncClientContext.get_or_create():
self._client = get_client()
self._is_started = True
flow_run_context = FlowRunContext.get()

try:
if not self.task_run:
self.task_run = await self.task.create_local_run(
id=task_run_id,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
flow_run_context=flow_run_context,
parent_task_run_context=TaskRunContext.get(),
wait_for=self.wait_for,
extra_task_inputs=dependencies,
Expand All @@ -1184,6 +1224,14 @@ async def initialize_run(
self.logger.debug(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
)

labels = (
flow_run_context.flow_run.labels if flow_run_context else {}
)
self._telemetry.start_span(
self.task_run, self.parameters, labels
)

yield self

except TerminationSignal as exc:
Expand Down Expand Up @@ -1237,11 +1285,12 @@ async def start(
async with self.initialize_run(
task_run_id=task_run_id, dependencies=dependencies
):
await self.begin_run()
try:
yield
finally:
await self.call_hooks()
with trace.use_span(self._telemetry._span):
await self.begin_run()
try:
yield
finally:
await self.call_hooks()

@asynccontextmanager
async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
Expand Down
73 changes: 73 additions & 0 deletions src/prefect/telemetry/run_telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict

from opentelemetry.trace import (
Status,
StatusCode,
get_tracer,
)

import prefect
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import State

if TYPE_CHECKING:
from opentelemetry.sdk.trace import Tracer


@dataclass
class RunTelemetry:
_tracer: "Tracer" = field(
default_factory=lambda: get_tracer("prefect", prefect.__version__)
)
_span = None

def start_span(
self,
task_run: TaskRun,
parameters: Dict[str, Any] = {},
labels: Dict[str, Any] = {},
Comment on lines +29 to +30
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these mutable defaults should be optionals that default to None

):
parameter_attributes = {
f"prefect.run.parameter.{k}": type(v).__name__
for k, v in parameters.items()
}
self._span = self._tracer.start_span(
name=task_run.name,
attributes={
"prefect.run.type": "task",
"prefect.run.id": str(task_run.id),
"prefect.tags": task_run.tags,
**parameter_attributes,
**labels,
},
)

def end_span_on_success(self, terminal_message: str):
if self._span:
self._span.set_status(Status(StatusCode.OK), terminal_message)
self._span.end(time.time_ns())
self._span = None

def end_span_on_failure(self, terminal_message: str):
if self._span:
self._span.set_status(Status(StatusCode.ERROR, terminal_message))
self._span.end(time.time_ns())
self._span = None

def record_exception(self, exc: Exception):
if self._span:
self._span.record_exception(exc)

def update_state(self, new_state: State):
if self._span:
self._span.add_event(
new_state.name,
{
"prefect.state.message": new_state.message or "",
"prefect.state.type": new_state.type,
"prefect.state.name": new_state.name or new_state.type,
"prefect.state.id": str(new_state.id),
},
)
Loading