diff --git a/src/prefect/client/schemas/filters.py b/src/prefect/client/schemas/filters.py index fa9c527edfbe..8c2ac8894cd1 100644 --- a/src/prefect/client/schemas/filters.py +++ b/src/prefect/client/schemas/filters.py @@ -235,6 +235,14 @@ class FlowRunFilterNextScheduledStartTime(PrefectBaseModel): ) +class FlowRunFilterParentFlowRunId(PrefectBaseModel, OperatorMixin): + """Filter for subflows of the given flow runs""" + + any_: Optional[List[UUID]] = Field( + default=None, description="A list of flow run parents to include" + ) + + class FlowRunFilterParentTaskRunId(PrefectBaseModel, OperatorMixin): """Filter by `FlowRun.parent_task_run_id`.""" diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py index 2b96a4f254b3..07e088c25377 100644 --- a/src/prefect/server/schemas/filters.py +++ b/src/prefect/server/schemas/filters.py @@ -453,6 +453,31 @@ def _get_filter_list(self, db: "PrefectDBInterface") -> List: return filters +class FlowRunFilterParentFlowRunId(PrefectOperatorFilterBaseModel): + """Filter for subflows of a given flow run""" + + any_: Optional[List[UUID]] = Field( + default=None, description="A list of parent flow run ids to include" + ) + + def _get_filter_list(self, db: "PrefectDBInterface") -> List: + filters = [] + if self.any_ is not None: + filters.append( + db.FlowRun.id.in_( + sa.select(db.FlowRun.id) + .join( + db.TaskRun, + sa.and_( + db.TaskRun.id == db.FlowRun.parent_task_run_id, + ), + ) + .where(db.TaskRun.flow_run_id.in_(self.any_)) + ) + ) + return filters + + class FlowRunFilterParentTaskRunId(PrefectOperatorFilterBaseModel): """Filter by `FlowRun.parent_task_run_id`.""" @@ -530,6 +555,9 @@ class FlowRunFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `FlowRun.next_scheduled_start_time`", ) + parent_flow_run_id: Optional[FlowRunFilterParentFlowRunId] = Field( + default=None, description="Filter criteria for subflows of the given flow runs" + ) parent_task_run_id: Optional[FlowRunFilterParentTaskRunId] = Field( default=None, description="Filter criteria for `FlowRun.parent_task_run_id`" ) @@ -550,6 +578,7 @@ def only_filters_on_id(self): and self.start_time is None and self.expected_start_time is None and self.next_scheduled_start_time is None + and self.parent_flow_run_id is None and self.parent_task_run_id is None and self.idempotency_key is None ) @@ -577,6 +606,8 @@ def _get_filter_list(self, db: "PrefectDBInterface") -> List: filters.append(self.expected_start_time.as_sql_filter(db)) if self.next_scheduled_start_time is not None: filters.append(self.next_scheduled_start_time.as_sql_filter(db)) + if self.parent_flow_run_id is not None: + filters.append(self.parent_flow_run_id.as_sql_filter(db)) if self.parent_task_run_id is not None: filters.append(self.parent_task_run_id.as_sql_filter(db)) if self.idempotency_key is not None: diff --git a/tests/server/api/test_flow_runs.py b/tests/server/api/test_flow_runs.py index 10a96a0a1895..6a0c5fca4b02 100644 --- a/tests/server/api/test_flow_runs.py +++ b/tests/server/api/test_flow_runs.py @@ -1,5 +1,5 @@ from typing import List -from uuid import uuid4 +from uuid import UUID, uuid4 import pendulum import pydantic @@ -645,6 +645,131 @@ async def test_read_flow_runs_sort_succeeds_for_all_sort_values( assert len(response.json()) == 1 assert response.json()[0]["id"] == str(flow_run.id) + @pytest.fixture + async def parent_flow_run(self, flow, session): + flow_run = await models.flow_runs.create_flow_run( + session=session, + flow_run=schemas.core.FlowRun( + flow_id=flow.id, + flow_version="1.0", + state=schemas.states.Pending(), + ), + ) + await session.commit() + return flow_run + + @pytest.fixture + async def child_runs( + self, + flow, + parent_flow_run, + session, + ): + children = [] + for i in range(5): + dummy_task = await models.task_runs.create_task_run( + session=session, + task_run=schemas.core.TaskRun( + flow_run_id=parent_flow_run.id, + name=f"dummy-{i}", + task_key=f"dummy-{i}", + dynamic_key=f"dummy-{i}", + ), + ) + children.append( + await models.flow_runs.create_flow_run( + session=session, + flow_run=schemas.core.FlowRun( + flow_id=flow.id, + flow_version="1.0", + state=schemas.states.Pending(), + parent_task_run_id=dummy_task.id, + ), + ) + ) + return children + + @pytest.fixture + async def grandchild_runs(self, flow, child_runs, session): + grandchildren = [] + for child in child_runs: + for i in range(3): + dummy_task = await models.task_runs.create_task_run( + session=session, + task_run=schemas.core.TaskRun( + flow_run_id=child.id, + name=f"dummy-{i}", + task_key=f"dummy-{i}", + dynamic_key=f"dummy-{i}", + ), + ) + grandchildren.append( + await models.flow_runs.create_flow_run( + session=session, + flow_run=schemas.core.FlowRun( + flow_id=flow.id, + flow_version="1.0", + state=schemas.states.Pending(), + parent_task_run_id=dummy_task.id, + ), + ) + ) + return grandchildren + + async def test_read_subflow_runs( + self, + client, + parent_flow_run, + child_runs, + # included to make sure we're only going 1 level deep + grandchild_runs, + # included to make sure we're not bringing in extra flow runs + flow_runs, + ): + """We should be able to find all subflow runs of a given flow run.""" + subflow_filter = { + "flow_runs": schemas.filters.FlowRunFilter( + parent_flow_run_id=schemas.filters.FlowRunFilterParentFlowRunId( + any_=[parent_flow_run.id] + ) + ).dict(json_compatible=True) + } + + response = await client.post( + "/flow_runs/filter", + json=subflow_filter, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == len(child_runs) + + returned = {UUID(run["id"]) for run in response.json()} + expected = {run.id for run in child_runs} + assert returned == expected + + async def test_read_subflow_runs_non_existant( + self, + client, + # including these to make sure we aren't bringing in extra flow runs + parent_flow_run, + child_runs, + grandchild_runs, + flow_runs, + ): + subflow_filter = { + "flow_runs": schemas.filters.FlowRunFilter( + parent_flow_run_id=schemas.filters.FlowRunFilterParentFlowRunId( + any_=[uuid4()] + ) + ).dict(json_compatible=True) + } + + response = await client.post( + "/flow_runs/filter", + json=subflow_filter, + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 0 + class TestReadFlowRunGraph: @pytest.fixture