Skip to content

Commit

Permalink
AIP-72: Adding Endpoint to set rendered task instance fields (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 9, 2024
1 parent 7d05a47 commit 8b1492e
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 4 deletions.
6 changes: 5 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import Discriminator, Field, Tag, WithJsonSchema
from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
Expand Down Expand Up @@ -135,3 +135,7 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None


"""Schema for setting RTIF for a task instance."""
# Root model: the request body is itself the JSON object, with no wrapping key.
# Both keys and values must be strings; anything else fails validation.
RTIFPayload = RootModel[dict[str, str]]
33 changes: 32 additions & 1 deletion airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
RTIFPayload,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
Expand Down Expand Up @@ -219,3 +220,33 @@ def ti_heartbeat(
# Update the last heartbeat time!
session.execute(update(TI).where(TI.id == ti_id_str).values(last_heartbeat_at=timezone.utcnow()))
log.debug("Task with %s state heartbeated", previous_state)


@router.put(
    "/{task_instance_id}/rtif",
    status_code=status.HTTP_201_CREATED,
    # TODO: Add description to the operation
    # TODO: Add Operation ID to control the function name in the OpenAPI spec
    # TODO: Do we need to use create_openapi_http_exception_doc here?
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
        status.HTTP_422_UNPROCESSABLE_ENTITY: {
            "description": "Invalid payload for the setting rendered task instance fields"
        },
    },
)
def ti_put_rtif(
    task_instance_id: UUID,
    put_rtif_payload: RTIFPayload,
    session: SessionDep,
):
    """
    Store the rendered task instance fields (RTIF) sent by a worker.

    The task instance is looked up by its ID. If no matching row exists, a 404 is raised;
    otherwise the dumped payload is handed to ``_update_rtif`` together with the session.
    """
    lookup = select(TI).where(TI.id == str(task_instance_id))
    ti = session.scalar(lookup)
    if not ti:
        # No explicit detail: the response carries the framework's default 404 message.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    rendered_fields = put_rtif_payload.model_dump()
    _update_rtif(ti, rendered_fields, session)

    return {"message": "Rendered task instance fields successfully set"}
8 changes: 8 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def defer(self, id: uuid.UUID, msg):
# Create a deferred state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
    """
    Send the rendered task instance fields for task instance ``id`` to the API server.

    Issues a PUT to ``task-instances/{id}/rtif`` with ``body`` as the JSON payload and
    returns a fixed ``{"ok": True}`` acknowledgement.
    """
    endpoint = f"task-instances/{id}/rtif"
    self.client.put(endpoint, json=body)
    # The server's reply is deliberately not returned: server errors are expected to reach
    # the supervisor on their own, so a generic acknowledgement keeps the supervisor
    # decoupled from the wording of the server response.
    acknowledgement = {"ok": True}
    return acknowledgement


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
42 changes: 42 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import httpx
import pytest
import uuid6

from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError
from airflow.sdk.api.datamodels._generated import VariableResponse, XComResponse
Expand Down Expand Up @@ -84,6 +85,47 @@ def make_client(transport: httpx.MockTransport) -> Client:
return Client(base_url="test://server", token="", transport=transport)


class TestTaskInstanceOperations:
    """
    Test that the task instance operations of the client work as expected. While the operations are
    simple, it still catches the basic functionality of the client for task instances including
    endpoint and response parsing.
    """

    # TODO: Add tests for different ti endpoints

    @pytest.mark.parametrize(
        "rendered_fields",
        [
            pytest.param({"field1": "rendered_value1", "field2": "rendered_value2"}, id="simple-rendering"),
            pytest.param(
                {
                    "field1": "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
                    "{'att1': 'test', 'att2': 'test2'), "
                    "'nested2': ClassWithCustomAttributes("
                    "{'att3': 'test3', 'att4': 'test4')"
                },
                id="complex-rendering",
            ),
        ],
    )
    def test_taskinstance_set_rtif_success(self, rendered_fields):
        """``set_rtif`` PUTs the rendered fields to the RTIF endpoint and returns ``{"ok": True}``."""
        import json

        TI_ID = uuid6.uuid7()

        def handle_request(request: httpx.Request) -> httpx.Response:
            # Match on method and body as well as path, so a request with the wrong verb or a
            # mangled payload falls through to the 400 instead of passing silently.
            targets_rtif = request.method == "PUT" and request.url.path == f"/task-instances/{TI_ID}/rtif"
            if targets_rtif and json.loads(request.content) == rendered_fields:
                return httpx.Response(
                    status_code=201,
                    json={"message": "Rendered task instance fields successfully set"},
                )
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})

        client = make_client(transport=httpx.MockTransport(handle_request))
        result = client.task_instances.set_rtif(id=TI_ID, body=rendered_fields)

        assert result == {"ok": True}


class TestVariableOperations:
"""
Test that the VariableOperations class works as expected. While the operations are simple, it
Expand Down
80 changes: 78 additions & 2 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from unittest import mock

import pytest
import uuid6
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from airflow.models import Trigger
from airflow.models import RenderedTaskInstanceFields, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState

from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -410,3 +411,78 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
# If successful, ensure last_heartbeat_at is updated
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)


class TestTIPutRTIF:
    """Tests for the ``PUT /execution/task-instances/{id}/rtif`` endpoint."""

    def setup_method(self):
        clear_db_runs()
        clear_rendered_ti_fields()

    def teardown_method(self):
        clear_db_runs()
        clear_rendered_ti_fields()

    def test_ti_put_rtif_success(self, client, session, create_task_instance):
        """A valid payload returns 201 and stores exactly one RTIF row for the task instance."""
        ti = create_task_instance(
            task_id="test_ti_put_rtif_success",
            state=State.RUNNING,
            session=session,
        )
        session.commit()

        payload = {"field1": "rendered_value1", "field2": "rendered_value2"}

        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
        assert response.status_code == 201
        assert response.json() == {"message": "Rendered task instance fields successfully set"}

        # Drop cached ORM state so the query below reads what the endpoint wrote.
        session.expire_all()

        rtifs = session.query(RenderedTaskInstanceFields).all()
        assert len(rtifs) == 1

        assert rtifs[0].dag_id == "dag"
        assert rtifs[0].run_id == "test"
        assert rtifs[0].task_id == "test_ti_put_rtif_success"
        assert rtifs[0].map_index == -1
        assert rtifs[0].rendered_fields == payload

    def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance):
        """An ID that matches no task instance returns 404, even when other task instances exist."""
        create_task_instance(
            task_id="test_ti_put_rtif_missing_ti",
            state=State.RUNNING,
            session=session,
        )
        session.commit()

        payload = {"field1": "rendered_value1", "field2": "rendered_value2"}

        random_id = uuid6.uuid7()
        response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload)
        assert response.status_code == 404
        assert response.json()["detail"] == "Not Found"

    def test_ti_put_rtif_extra_fields(self, client, session, create_task_instance):
        """A payload containing a non-string value is rejected with a 422 validation error."""
        ti = create_task_instance(
            task_id="test_ti_put_rtif_extra_fields",
            state=State.RUNNING,
            session=session,
        )
        session.commit()

        payload = {
            "field1": "rendered_value1",
            "field2": "rendered_value2",
            "invalid_key": {"field3": "rendered_value3"},
        }

        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
        assert response.status_code == 422
        assert response.json()["detail"] == [
            {
                "input": {"field3": "rendered_value3"},
                "loc": ["body", "invalid_key"],
                "msg": "Input should be a valid string",
                "type": "string_type",
            }
        ]

0 comments on commit 8b1492e

Please sign in to comment.