diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index f233983f3c93..16efecc27a72 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -2209,6 +2209,13 @@ async def read_flow_run_states( response.json() ) + async def set_flow_run_name(self, flow_run_id: UUID, name: str): + flow_run_data = FlowRunUpdate(name=name) + return await self._client.patch( + f"/flow_runs/{flow_run_id}", + json=flow_run_data.model_dump(mode="json", exclude_unset=True), + ) + async def set_task_run_name(self, task_run_id: UUID, name: str): task_run_data = TaskRunUpdate(name=name) return await self._client.patch( @@ -4019,7 +4026,7 @@ def set_flow_run_state( return OrchestrationResult.model_validate(response.json()) def set_flow_run_name(self, flow_run_id: UUID, name: str): - flow_run_data = TaskRunUpdate(name=name) + flow_run_data = FlowRunUpdate(name=name) return self._client.patch( f"/flow_runs/{flow_run_id}", json=flow_run_data.model_dump(mode="json", exclude_unset=True), diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 49516459b2c3..239d9306ffb1 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -2,7 +2,7 @@ import logging import os import time -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field from typing import ( Any, @@ -22,19 +22,25 @@ ) from uuid import UUID +from anyio import CancelScope from opentelemetry import trace from opentelemetry.trace import Tracer, get_tracer from typing_extensions import ParamSpec import prefect from prefect import Task -from prefect.client.orchestration import SyncPrefectClient, get_client +from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import FlowRun, TaskRun from prefect.client.schemas.filters import FlowRunFilter from prefect.client.schemas.sorting import FlowRunSort from prefect.concurrency.context import ConcurrencyContext from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1 -from prefect.context import FlowRunContext, SyncClientContext, TagsContext +from prefect.context import ( + AsyncClientContext, + FlowRunContext, + SyncClientContext, + TagsContext, +) from prefect.exceptions import ( Abort, Pause, @@ -79,6 +85,7 @@ _resolve_custom_flow_run_name, capture_sigterm, link_state_to_result, + propose_state, propose_state_sync, resolve_to_final_result, ) @@ -747,10 +754,10 @@ class AsyncFlowRunEngine(BaseFlowRunEngine[P, R]): not being fully asyncified. """ - _client: Optional[SyncPrefectClient] = None + _client: Optional[PrefectClient] = None @property - def client(self) -> SyncPrefectClient: + def client(self) -> PrefectClient: if not self._is_started or self._client is None: raise RuntimeError("Engine has not started.") return self._client @@ -794,12 +801,12 @@ def _wait_for_dependencies(self): context={}, ) - def begin_run(self) -> State: + async def begin_run(self) -> State: try: self._resolve_parameters() self._wait_for_dependencies() except UpstreamTaskError as upstream_exc: - state = self.set_state( + state = await self.set_state( Pending( name="NotReady", message=str(upstream_exc), @@ -817,7 +824,7 @@ def begin_run(self) -> State: except Exception as exc: message = "Validation of flow parameters failed with error:" self.logger.error("%s %s", message, exc) - self.handle_exception( + await self.handle_exception( exc, msg=message, result_store=get_result_store().update_for_flow( @@ -825,22 +832,22 @@ def begin_run(self) -> State: ), ) self.short_circuit = True - self.call_hooks() + await self.call_hooks() new_state = Running() - state = self.set_state(new_state) + state = await self.set_state(new_state) while state.is_pending(): - time.sleep(0.2) - state = self.set_state(new_state) + await asyncio.sleep(0.2) + state = await self.set_state(new_state) return state - def set_state(self, state: State, force: bool = False) -> State: + async def set_state(self, state: State, force: bool = False) -> State: """ """ # prevents any state-setting activity if self.short_circuit: return self.state - state = propose_state_sync( + state = await propose_state( self.client, state, flow_run_id=self.flow_run.id, force=force ) # type: ignore self.flow_run.state = state # type: ignore @@ -859,7 +866,7 @@ def set_state(self, state: State, force: bool = False) -> State: ) return state - def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": + async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": if self._return_value is not NotSet and not isinstance( self._return_value, State ): @@ -871,7 +878,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": if asyncio.iscoroutine(_result): # getting the value for a BaseResult may return an awaitable # depending on whether the parent frame is sync or not - _result = run_coro_as_sync(_result) + _result = await _result return _result if self._raised is not NotSet: @@ -888,29 +895,27 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": # state.result is a `sync_compatible` function that may or may not return an awaitable # depending on whether the parent frame is sync or not if asyncio.iscoroutine(_result): - _result = run_coro_as_sync(_result) + _result = await _result return _result - def handle_success(self, result: R) -> R: + async def handle_success(self, result: R) -> R: result_store = getattr(FlowRunContext.get(), "result_store", None) if result_store is None: raise ValueError("Result store is not set") resolved_result = resolve_futures_to_states(result) - terminal_state = run_coro_as_sync( - return_value_to_state( - resolved_result, - result_store=result_store, - write_result=should_persist_result(), - ) + terminal_state = await return_value_to_state( + resolved_result, + result_store=result_store, + write_result=should_persist_result(), ) - self.set_state(terminal_state) + await self.set_state(terminal_state) self._return_value = resolved_result self._end_span_on_success() return result - def handle_exception( + async def handle_exception( self, exc: Exception, msg: Optional[str] = None, @@ -919,16 +924,14 @@ def handle_exception( context = FlowRunContext.get() terminal_state = cast( State, - run_coro_as_sync( - exception_to_failed_state( - exc, - message=msg or "Flow run encountered an exception:", - result_store=result_store or getattr(context, "result_store", None), - write_result=True, - ) + await exception_to_failed_state( + exc, + message=msg or "Flow run encountered an exception:", + result_store=result_store or getattr(context, "result_store", None), + write_result=True, ), ) - state = self.set_state(terminal_state) + state = await self.set_state(terminal_state) if self.state.is_scheduled(): self.logger.info( ( @@ -936,14 +939,14 @@ def handle_exception( f" state {terminal_state.name!r} and will attempt to run again..." ), ) - state = self.set_state(Running()) + state = await self.set_state(Running()) self._raised = exc self._end_span_on_error(exc, state.message) return state - def handle_timeout(self, exc: TimeoutError) -> None: + async def handle_timeout(self, exc: TimeoutError) -> None: if isinstance(exc, FlowRunTimeoutError): message = ( f"Flow run exceeded timeout of {self.flow.timeout_seconds} second(s)" @@ -956,24 +959,27 @@ def handle_timeout(self, exc: TimeoutError) -> None: message=message, name="TimedOut", ) - self.set_state(state) + await self.set_state(state) self._raised = exc self._end_span_on_error(exc, message) - def handle_crash(self, exc: BaseException) -> None: - state = run_coro_as_sync(exception_to_crashed_state(exc)) - self.logger.error(f"Crash detected! {state.message}") - self.logger.debug("Crash details:", exc_info=exc) - self.set_state(state, force=True) - self._raised = exc + async def handle_crash(self, exc: BaseException) -> None: + # need to shield from asyncio cancellation to ensure we update the state + # on the server before exiting + with CancelScope(shield=True): + state = await exception_to_crashed_state(exc) + self.logger.error(f"Crash detected! {state.message}") + self.logger.debug("Crash details:", exc_info=exc) + await self.set_state(state, force=True) + self._raised = exc - self._end_span_on_error(exc, state.message) + self._end_span_on_error(exc, state.message) - def load_subflow_run( + async def load_subflow_run( self, parent_task_run: TaskRun, - client: SyncPrefectClient, + client: PrefectClient, context: FlowRunContext, ) -> Union[FlowRun, None]: """ @@ -1007,7 +1013,7 @@ def load_subflow_run( rerunning and not parent_task_run.state.is_completed() ): # return the most recent flow run, if it exists - flow_runs = client.read_flow_runs( + flow_runs = await client.read_flow_runs( flow_run_filter=FlowRunFilter( parent_task_run_id={"any_": [parent_task_run.id]} ), @@ -1019,7 +1025,7 @@ def load_subflow_run( self._return_value = loaded_flow_run.state return loaded_flow_run - def create_flow_run(self, client: SyncPrefectClient) -> FlowRun: + async def create_flow_run(self, client: PrefectClient) -> FlowRun: flow_run_ctx = FlowRunContext.get() parameters = self.parameters or {} @@ -1032,21 +1038,19 @@ def create_flow_run(self, client: SyncPrefectClient) -> FlowRun: name=self.flow.name, fn=self.flow.fn, version=self.flow.version ) - parent_task_run = run_coro_as_sync( - parent_task.create_run( - flow_run_context=flow_run_ctx, - parameters=self.parameters, - wait_for=self.wait_for, - ) + parent_task_run = await parent_task.create_run( + flow_run_context=flow_run_ctx, + parameters=self.parameters, + wait_for=self.wait_for, ) # check if there is already a flow run for this subflow - if subflow_run := self.load_subflow_run( + if subflow_run := await self.load_subflow_run( parent_task_run=parent_task_run, client=client, context=flow_run_ctx ): return subflow_run - flow_run = client.create_flow_run( + flow_run = await client.create_flow_run( flow=self.flow, parameters=self.flow.serialize_parameters(parameters), state=Pending(), @@ -1065,7 +1069,7 @@ def create_flow_run(self, client: SyncPrefectClient) -> FlowRun: return flow_run - def call_hooks(self, state: Optional[State] = None): + async def call_hooks(self, state: Optional[State] = None): if state is None: state = self.state flow = self.flow @@ -1112,7 +1116,7 @@ def call_hooks(self, state: Optional[State] = None): ) result = hook(flow, flow_run, state) if asyncio.iscoroutine(result): - run_coro_as_sync(result) + await result except Exception: self.logger.error( f"An error was encountered while running hook {hook_name!r}", @@ -1121,8 +1125,8 @@ def call_hooks(self, state: Optional[State] = None): else: self.logger.info(f"Hook {hook_name!r} finished running successfully") - @contextmanager - def setup_run_context(self, client: Optional[SyncPrefectClient] = None): + @asynccontextmanager + async def setup_run_context(self, client: Optional[PrefectClient] = None): from prefect.utilities.engine import ( should_log_prints, ) @@ -1132,7 +1136,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): if not self.flow_run: raise ValueError("Flow run not set") - self.flow_run = client.read_flow_run(self.flow_run.id) + self.flow_run = await client.read_flow_run(self.flow_run.id) log_prints = should_log_prints(self.flow) with ExitStack() as stack: @@ -1169,7 +1173,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): flow_run_name = _resolve_custom_flow_run_name( flow=self.flow, parameters=self.parameters ) - self.client.set_flow_run_name( + await self.client.set_flow_run_name( flow_run_id=self.flow_run.id, name=flow_run_name ) self.logger.extra["flow_run_name"] = flow_run_name @@ -1180,17 +1184,17 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): self._flow_run_name_set = True yield - @contextmanager - def initialize_run(self): + @asynccontextmanager + async def initialize_run(self): """ Enters a client context and creates a flow run if needed. """ - with SyncClientContext.get_or_create() as client_ctx: + async with AsyncClientContext.get_or_create() as client_ctx: self._client = client_ctx.client self._is_started = True if not self.flow_run: - self.flow_run = self.create_flow_run(self.client) + self.flow_run = await self.create_flow_run(self.client) flow_run_url = url_for(self.flow_run) if flow_run_url: @@ -1207,7 +1211,7 @@ def initialize_run(self): if self.flow_run.empirical_policy.retries is None: self.flow_run.empirical_policy.retries = self.flow.retries - self.client.update_flow_run( + await self.client.update_flow_run( flow_run_id=self.flow_run.id, flow_version=self.flow.version, empirical_policy=self.flow_run.empirical_policy, @@ -1229,7 +1233,7 @@ def initialize_run(self): except TerminationSignal as exc: self.cancel_all_tasks() - self.handle_crash(exc) + await self.handle_crash(exc) raise except Exception: # regular exceptions are caught and re-raised to the user @@ -1241,7 +1245,7 @@ def initialize_run(self): raise except BaseException as exc: # BaseExceptions are caught and handled as crashes - self.handle_crash(exc) + await self.handle_crash(exc) raise finally: # If debugging, use the more complete `repr` than the usual `str` description @@ -1262,20 +1266,21 @@ def initialize_run(self): # # -------------------------- - @contextmanager - def start(self) -> Generator[None, None, None]: - with self.initialize_run(), trace.use_span(self._span): - self.begin_run() + @asynccontextmanager + async def start(self) -> AsyncGenerator[None, None]: + async with self.initialize_run(): + with trace.use_span(self._span): + await self.begin_run() - if self.state.is_running(): - self.call_hooks() - yield + if self.state.is_running(): + await self.call_hooks() + yield - @contextmanager - def run_context(self): + @asynccontextmanager + async def run_context(self): timeout_context = timeout_async if self.flow.isasync else timeout # reenter the run context to ensure it is up to date for every run - with self.setup_run_context(): + async with self.setup_run_context(): try: with timeout_context( seconds=self.flow.timeout_seconds, @@ -1286,29 +1291,24 @@ def run_context(self): ) yield self except TimeoutError as exc: - self.handle_timeout(exc) + await self.handle_timeout(exc) except Exception as exc: self.logger.exception("Encountered exception during execution: %r", exc) - self.handle_exception(exc) + await self.handle_exception(exc) finally: if self.state.is_final() or self.state.is_cancelling(): - self.call_hooks() + await self.call_hooks() - def call_flow_fn(self) -> Union[R, Coroutine[Any, Any, R]]: + async def call_flow_fn(self) -> Coroutine[Any, Any, R]: """ Convenience method to call the flow function. Returns a coroutine if the flow is async. """ - if self.flow.isasync: - - async def _call_flow_fn(): - result = await call_with_parameters(self.flow.fn, self.parameters) - self.handle_success(result) + assert self.flow.isasync, "Flow must be async to be run with AsyncFlowRunEngine" - return _call_flow_fn() - else: - result = call_with_parameters(self.flow.fn, self.parameters) - self.handle_success(result) + result = await call_with_parameters(self.flow.fn, self.parameters) + await self.handle_success(result) + return result def run_flow_sync( @@ -1344,12 +1344,12 @@ async def run_flow_async( flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for ) - with engine.start(): + async with engine.start(): while engine.is_running(): - with engine.run_context(): + async with engine.run_context(): await engine.call_flow_fn() - return engine.state if return_type == "state" else engine.result() + return engine.state if return_type == "state" else await engine.result() def run_generator_flow_sync( @@ -1402,9 +1402,9 @@ async def run_generator_flow_async( flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for ) - with engine.start(): + async with engine.start(): while engine.is_running(): - with engine.run_context(): + async with engine.run_context(): call_args, call_kwargs = parameters_to_args_kwargs( flow.fn, engine.parameters or {} ) @@ -1417,13 +1417,13 @@ async def run_generator_flow_async( link_state_to_result(engine.state, gen_result) yield gen_result except (StopAsyncIteration, GeneratorExit) as exc: - engine.handle_success(None) + await engine.handle_success(None) if isinstance(exc, GeneratorExit): gen.throw(exc) # async generators can't return, but we can raise failures here if engine.state.is_failed(): - engine.result() + await engine.result() def run_flow( diff --git a/tests/public/flows/test_flow_crashes.py b/tests/public/flows/test_flow_crashes.py index a4ca50aadb07..a461b6810fd8 100644 --- a/tests/public/flows/test_flow_crashes.py +++ b/tests/public/flows/test_flow_crashes.py @@ -28,7 +28,7 @@ async def assert_flow_run_crashed(flow_run: FlowRun, expected_message: str): """ Utility for asserting that flow runs are crashed. """ - assert flow_run.state.is_crashed() + assert flow_run.state.is_crashed(), flow_run.state assert expected_message in flow_run.state.message with pytest.raises(prefect.exceptions.CrashedRun, match=expected_message): await flow_run.state.result() diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index c291804285f8..d1bc1a9bbdfd 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -32,6 +32,7 @@ Pause, ) from prefect.flow_engine import ( + AsyncFlowRunEngine, FlowRunEngine, load_flow_and_flow_run, run_flow, @@ -57,24 +58,24 @@ async def foo(): class TestFlowRunEngine: - async def test_basic_init(self): + def test_basic_init(self): engine = FlowRunEngine(flow=foo) assert isinstance(engine.flow, Flow) assert engine.flow.name == "foo" assert engine.parameters == {} - async def test_empty_init(self): + def test_empty_init(self): with pytest.raises( TypeError, match="missing 1 required positional argument: 'flow'" ): FlowRunEngine() - async def test_client_attr_raises_informative_error(self): + def test_client_attr_raises_informative_error(self): engine = FlowRunEngine(flow=foo) with pytest.raises(RuntimeError, match="not started"): engine.client - async def test_client_attr_returns_client_after_starting(self): + def test_client_attr_returns_client_after_starting(self): engine = FlowRunEngine(flow=foo) with engine.initialize_run(): client = engine.client @@ -83,21 +84,33 @@ async def test_client_attr_returns_client_after_starting(self): with pytest.raises(RuntimeError, match="not started"): engine.client - async def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run): - flow_code = """ - from prefect import flow - @flow - def dog(): - return "woof!" - """ - fpath = tmp_path / "f.py" - fpath.write_text(dedent(flow_code)) +class TestAsyncFlowRunEngine: + def test_basic_init(self): + engine = AsyncFlowRunEngine(flow=foo) + assert isinstance(engine.flow, Flow) + assert engine.flow.name == "foo" + assert engine.parameters == {} - monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog") - loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id) - assert loaded_flow_run.id == flow_run.id - assert flow.fn() == "woof!" + def test_empty_init(self): + with pytest.raises( + TypeError, match="missing 1 required positional argument: 'flow'" + ): + AsyncFlowRunEngine() + + def test_client_attr_raises_informative_error(self): + engine = AsyncFlowRunEngine(flow=foo) + with pytest.raises(RuntimeError, match="not started"): + engine.client + + async def test_client_attr_returns_client_after_starting(self): + engine = AsyncFlowRunEngine(flow=foo) + async with engine.initialize_run(): + client = engine.client + assert isinstance(client, PrefectClient) + + with pytest.raises(RuntimeError, match="not started"): + engine.client class TestStartFlowRunEngine: @@ -119,6 +132,25 @@ def flow_with_retries(): engine.begin_run() +class TestStartAsyncFlowRunEngine: + async def test_start_updates_empirical_policy_on_provided_flow_run( + self, prefect_client: PrefectClient + ): + @flow(retries=3, retry_delay_seconds=10) + def flow_with_retries(): + pass + + flow_run = await prefect_client.create_flow_run(flow_with_retries) + + engine = AsyncFlowRunEngine(flow=flow_with_retries, flow_run=flow_run) + async with engine.start(): + assert engine.flow_run.empirical_policy.retries == 3 + assert engine.flow_run.empirical_policy.retry_delay == 10 + + # avoid error on teardown + await engine.begin_run() + + class TestFlowRunsAsync: async def test_basic(self): @flow @@ -1744,6 +1776,22 @@ def g(required: str, model: TheModel = {"x": [1, 2, 3]}): # type: ignore class TestLoadFlowAndFlowRun: + def test_load_flow_from_entrypoint(self, monkeypatch, tmp_path, flow_run): + flow_code = """ + from prefect import flow + + @flow + def dog(): + return "woof!" + """ + fpath = tmp_path / "f.py" + fpath.write_text(dedent(flow_code)) + + monkeypatch.setenv("PREFECT__FLOW_ENTRYPOINT", f"{fpath}:dog") + loaded_flow_run, flow = load_flow_and_flow_run(flow_run.id) + assert loaded_flow_run.id == flow_run.id + assert flow.fn() == "woof!" + async def test_load_flow_from_script_with_module_level_sync_compatible_call( self, prefect_client: PrefectClient, tmp_path ):