Skip to content

Commit

Permalink
Flow run filter for fetching the (first-level) subflows of a given fl…
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisguidry authored Sep 22, 2023
1 parent 326e619 commit d739f53
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/prefect/client/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ class FlowRunFilterNextScheduledStartTime(PrefectBaseModel):
)


class FlowRunFilterParentFlowRunId(PrefectBaseModel, OperatorMixin):
"""Filter for subflows of the given flow runs"""

any_: Optional[List[UUID]] = Field(
default=None, description="A list of flow run parents to include"
)


class FlowRunFilterParentTaskRunId(PrefectBaseModel, OperatorMixin):
"""Filter by `FlowRun.parent_task_run_id`."""

Expand Down
31 changes: 31 additions & 0 deletions src/prefect/server/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,31 @@ def _get_filter_list(self, db: "PrefectDBInterface") -> List:
return filters


class FlowRunFilterParentFlowRunId(PrefectOperatorFilterBaseModel):
"""Filter for subflows of a given flow run"""

any_: Optional[List[UUID]] = Field(
default=None, description="A list of parent flow run ids to include"
)

def _get_filter_list(self, db: "PrefectDBInterface") -> List:
filters = []
if self.any_ is not None:
filters.append(
db.FlowRun.id.in_(
sa.select(db.FlowRun.id)
.join(
db.TaskRun,
sa.and_(
db.TaskRun.id == db.FlowRun.parent_task_run_id,
),
)
.where(db.TaskRun.flow_run_id.in_(self.any_))
)
)
return filters


class FlowRunFilterParentTaskRunId(PrefectOperatorFilterBaseModel):
"""Filter by `FlowRun.parent_task_run_id`."""

Expand Down Expand Up @@ -530,6 +555,9 @@ class FlowRunFilter(PrefectOperatorFilterBaseModel):
default=None,
description="Filter criteria for `FlowRun.next_scheduled_start_time`",
)
parent_flow_run_id: Optional[FlowRunFilterParentFlowRunId] = Field(
default=None, description="Filter criteria for subflows of the given flow runs"
)
parent_task_run_id: Optional[FlowRunFilterParentTaskRunId] = Field(
default=None, description="Filter criteria for `FlowRun.parent_task_run_id`"
)
Expand All @@ -550,6 +578,7 @@ def only_filters_on_id(self):
and self.start_time is None
and self.expected_start_time is None
and self.next_scheduled_start_time is None
and self.parent_flow_run_id is None
and self.parent_task_run_id is None
and self.idempotency_key is None
)
Expand Down Expand Up @@ -577,6 +606,8 @@ def _get_filter_list(self, db: "PrefectDBInterface") -> List:
filters.append(self.expected_start_time.as_sql_filter(db))
if self.next_scheduled_start_time is not None:
filters.append(self.next_scheduled_start_time.as_sql_filter(db))
if self.parent_flow_run_id is not None:
filters.append(self.parent_flow_run_id.as_sql_filter(db))
if self.parent_task_run_id is not None:
filters.append(self.parent_task_run_id.as_sql_filter(db))
if self.idempotency_key is not None:
Expand Down
127 changes: 126 additions & 1 deletion tests/server/api/test_flow_runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List
from uuid import uuid4
from uuid import UUID, uuid4

import pendulum
import pydantic
Expand Down Expand Up @@ -645,6 +645,131 @@ async def test_read_flow_runs_sort_succeeds_for_all_sort_values(
assert len(response.json()) == 1
assert response.json()[0]["id"] == str(flow_run.id)

@pytest.fixture
async def parent_flow_run(self, flow, session):
flow_run = await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.core.FlowRun(
flow_id=flow.id,
flow_version="1.0",
state=schemas.states.Pending(),
),
)
await session.commit()
return flow_run

@pytest.fixture
async def child_runs(
self,
flow,
parent_flow_run,
session,
):
children = []
for i in range(5):
dummy_task = await models.task_runs.create_task_run(
session=session,
task_run=schemas.core.TaskRun(
flow_run_id=parent_flow_run.id,
name=f"dummy-{i}",
task_key=f"dummy-{i}",
dynamic_key=f"dummy-{i}",
),
)
children.append(
await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.core.FlowRun(
flow_id=flow.id,
flow_version="1.0",
state=schemas.states.Pending(),
parent_task_run_id=dummy_task.id,
),
)
)
return children

@pytest.fixture
async def grandchild_runs(self, flow, child_runs, session):
grandchildren = []
for child in child_runs:
for i in range(3):
dummy_task = await models.task_runs.create_task_run(
session=session,
task_run=schemas.core.TaskRun(
flow_run_id=child.id,
name=f"dummy-{i}",
task_key=f"dummy-{i}",
dynamic_key=f"dummy-{i}",
),
)
grandchildren.append(
await models.flow_runs.create_flow_run(
session=session,
flow_run=schemas.core.FlowRun(
flow_id=flow.id,
flow_version="1.0",
state=schemas.states.Pending(),
parent_task_run_id=dummy_task.id,
),
)
)
return grandchildren

async def test_read_subflow_runs(
self,
client,
parent_flow_run,
child_runs,
# included to make sure we're only going 1 level deep
grandchild_runs,
# included to make sure we're not bringing in extra flow runs
flow_runs,
):
"""We should be able to find all subflow runs of a given flow run."""
subflow_filter = {
"flow_runs": schemas.filters.FlowRunFilter(
parent_flow_run_id=schemas.filters.FlowRunFilterParentFlowRunId(
any_=[parent_flow_run.id]
)
).dict(json_compatible=True)
}

response = await client.post(
"/flow_runs/filter",
json=subflow_filter,
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == len(child_runs)

returned = {UUID(run["id"]) for run in response.json()}
expected = {run.id for run in child_runs}
assert returned == expected

async def test_read_subflow_runs_non_existant(
self,
client,
# including these to make sure we aren't bringing in extra flow runs
parent_flow_run,
child_runs,
grandchild_runs,
flow_runs,
):
subflow_filter = {
"flow_runs": schemas.filters.FlowRunFilter(
parent_flow_run_id=schemas.filters.FlowRunFilterParentFlowRunId(
any_=[uuid4()]
)
).dict(json_compatible=True)
}

response = await client.post(
"/flow_runs/filter",
json=subflow_filter,
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 0


class TestReadFlowRunGraph:
@pytest.fixture
Expand Down

0 comments on commit d739f53

Please sign in to comment.