Skip to content

Commit

Permalink
Make filters param optional and fix typing (apache#44226)
Browse files Browse the repository at this point in the history
Given that sometimes we don't want to apply any filters, it makes sense to make the param optional.  I also fix the typing on `paginated_select`.
  • Loading branch information
dstandish authored Nov 21, 2024
1 parent 9bc2840 commit 22d1406
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 55 deletions.
56 changes: 45 additions & 11 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Database helpers for Airflow REST API.
:meta private:
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Literal, Sequence, overload

from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -47,30 +52,59 @@ def your_route(session: Annotated[Session, Depends(get_session)]):
yield session


def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select:
base_select = base_select
for filter in filters:
if filter is None:
def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
if filters is None:
return base_select
for f in filters:
if f is None:
continue
base_select = filter.to_orm(base_select)
base_select = f.to_orm(base_select)

return base_select


@overload
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...


@overload
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...


@provide_session
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam],
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> Select:
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
select,
filters,
base_select=select,
filters=filters,
)

total_entries = None
Expand All @@ -82,6 +116,6 @@ def paginated_select(
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select, [order_by, offset, limit])
base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])

return base_select, total_entries
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def get_assets(
limit=limit,
session=session,
)

assets = session.scalars(
assets_select.options(
subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks)
Expand Down Expand Up @@ -211,7 +212,7 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
dag_asset_queued_events_select, total_entries = paginated_select(select=query)
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down Expand Up @@ -270,9 +271,8 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
dag_asset_queued_events_select, total_entries = paginated_select(select=query)
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with dag_id: `{dag_id}` was not found")

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def list_backfills(
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select=select(Backfill).where(Backfill.dag_id == dag_id),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

backfills = session.scalars(select_stmt)

return BackfillCollectionResponse(
backfills=[BackfillResponse.model_validate(x, from_attributes=True) for x in backfills],
backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills],
total_entries=total_entries,
)

Expand Down
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def get_connections(
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select=select(Connection),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
3 changes: 1 addition & 2 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ def get_dag_runs(
limit=limit,
session=session,
)

dag_runs = session.scalars(dag_run_select)
return DAGRunCollectionResponse(
dag_runs=[DAGRunResponse.model_validate(dag_run, from_attributes=True) for dag_run in dag_runs],
dag_runs=[DAGRunResponse.model_validate(dr, from_attributes=True) for dr in dag_runs],
total_entries=total_entries,
)
6 changes: 1 addition & 5 deletions airflow/api_fastapi/core_api/routes/public/dag_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,9 @@ def list_dag_warnings(
limit=limit,
session=session,
)

dag_warnings = session.scalars(dag_warnings_select)

return DAGWarningCollectionResponse(
dag_warnings=[
DAGWarningResponse.model_validate(dag_warning, from_attributes=True)
for dag_warning in dag_warnings
],
dag_warnings=[DAGWarningResponse.model_validate(w, from_attributes=True) for w in dag_warnings],
total_entries=total_entries,
)
8 changes: 3 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_dag_tags(
session=session,
)
dag_tags = session.execute(dag_tags_select).scalars().all()
return DAGTagCollectionResponse(tags=[dag_tag for dag_tag in dag_tags], total_entries=total_entries)
return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)


@dags_router.get(
Expand Down Expand Up @@ -259,6 +259,7 @@ def patch_dags(
status.HTTP_400_BAD_REQUEST, "Only `is_paused` field can be updated through the REST API"
)
else:
# todo: this is not used?
update_mask = ["is_paused"]

dags_select, total_entries = paginated_select(
Expand All @@ -269,11 +270,8 @@ def patch_dags(
limit=limit,
session=session,
)

dags = session.scalars(dags_select).all()

dags_to_update = {dag.dag_id for dag in dags}

session.execute(
update(DagModel)
.where(DagModel.dag_id.in_(dags_to_update))
Expand All @@ -282,7 +280,7 @@ def patch_dags(
)

return DAGCollectionResponse(
dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags],
dags=[DAGResponse.model_validate(d, from_attributes=True) for d in dags],
total_entries=total_entries,
)

Expand Down
9 changes: 1 addition & 8 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def get_event_logs(
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
select=base_select,
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand All @@ -135,12 +134,6 @@ def get_event_logs(
event_logs = session.scalars(event_logs_select)

return EventLogCollectionResponse(
event_logs=[
EventLogResponse.model_validate(
event_log,
from_attributes=True,
)
for event_log in event_logs
],
event_logs=[EventLogResponse.model_validate(e, from_attributes=True) for e in event_logs],
total_entries=total_entries,
)
5 changes: 1 addition & 4 deletions airflow/api_fastapi/core_api/routes/public/import_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def get_import_errors(
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select=select(ParseImportError),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand All @@ -99,8 +98,6 @@ def get_import_errors(
import_errors = session.scalars(import_errors_select)

return ImportErrorCollectionResponse(
import_errors=[
ImportErrorResponse.model_validate(error, from_attributes=True) for error in import_errors
],
import_errors=[ImportErrorResponse.model_validate(i, from_attributes=True) for i in import_errors],
total_entries=total_entries,
)
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def get_pools(
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select=select(Pool),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
14 changes: 2 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def get_mapped_task_instances(
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
Expand Down Expand Up @@ -335,14 +334,9 @@ def get_task_instances(
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(task_instance, from_attributes=True)
for task_instance in task_instances
],
task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances],
total_entries=total_entries,
)

Expand Down Expand Up @@ -411,18 +405,14 @@ def get_task_instances_batch(
limit=limit,
session=session,
)

task_instance_select = task_instance_select.options(
joinedload(TI.rendered_task_instance_fields), joinedload(TI.task_instance_note)
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(task_instance, from_attributes=True)
for task_instance in task_instances
],
task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances],
total_entries=total_entries,
)

Expand Down
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_variables(
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
select=select(Variable),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down

0 comments on commit 22d1406

Please sign in to comment.