Skip to content

Commit

Permalink
AIP-72: Extending SET RTIF endpoint to accept all JSONable types (ap…
Browse files Browse the repository at this point in the history
…ache#44843)

An endpoint to set RTIF was added in apache#44359. It only allowed `dict[str, str]` entries to be passed to the API, which led to issues when running tests with DAGs like:
```py
from __future__ import annotations

import sys
import time
from datetime import datetime

from airflow import DAG
from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator


@dag(
    # No schedule is configured for this DAG (schedule=None below).
    catchup=False,
    tags=[],
    schedule=None,
    start_date=datetime(2021, 1, 1),
)
def hello_dag():
    """
    ### TaskFlow API Tutorial Documentation
    This is a simple data pipeline example which demonstrates the use of
    the TaskFlow API using three simple tasks for Extract, Transform, and Load.
    Documentation that goes along with the Airflow TaskFlow API tutorial is
    located
    [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
    """
    # NOTE(review): the docstring above looks like tutorial boilerplate; this DAG
    # only defines the single `hello` task below, not Extract/Transform/Load tasks.

    @task()
    def hello():
        # Print to stdout, pause for 3 seconds, then print to stdout and stderr.
        print("hello")
        time.sleep(3)
        print("goodbye")
        print("err mesg", file=sys.stderr)

    # Invoke the decorated task function once inside the DAG body.
    hello()


hello_dag()
```

The reason for this is that arguments such as `op_args` and `op_kwargs` for PythonOperator can hold non-`str` values. This leads to the conclusion that we should accept `str` keys but `JsonAble` values.

Some points to note for reviewers:
1. Type we store in the table: https://github.com/apache/airflow/blob/1eb683be3a79c80927e9af1e89dabb5e78ce3136/airflow/models/renderedtifields.py#L76. Hence we should be able to accept any JsonAble type and store it as-is; values that cannot be stored faithfully as JSON, such as tuples and sets, should be converted (stringified) before they are stored.


### What does this PR change?
- Get rid of the `RTIFPayload` and consume the payload directly in the api handler.
- Handling the special case of tuples: they are JSON-serialisable, but because `json.dumps()` converts them to lists, they used to be stored as lists. `is_jsonable` now returns `False` for tuples, so they are converted to strings instead:
```
    def is_jsonable(x):
        try:
            json.dumps(x)
            if isinstance(x, tuple):
                # Tuple is converted to list in json.dumps
                # so while it is jsonable, it changes the type which might be a surprise
                # for the user, so instead we return False here -- which will convert it to string
                return False
```
- Reusing `serialize_template_field` from `airflow.serialization.helpers`, because copy-pasting the code would be expensive and hard to maintain. We will revisit it anyway when we port the templating logic to the Task SDK. Discussion: https://github.com/apache/airflow/pull/44843/files#r1882834039
- Added test cases with different scopes and different types to handle different cases of templated_fields well.
  • Loading branch information
amoghrajesh authored Dec 13, 2024
1 parent 2c01457 commit ff7e700
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 60 deletions.
6 changes: 1 addition & 5 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema
from pydantic import Discriminator, Field, Tag, WithJsonSchema

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
Expand Down Expand Up @@ -135,7 +135,3 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None


"""Schema for setting RTIF for a task instance."""
RTIFPayload = RootModel[dict[str, str]]
6 changes: 3 additions & 3 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from uuid import UUID

from fastapi import Body, HTTPException, status
from pydantic import JsonValue
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
RTIFPayload,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
Expand Down Expand Up @@ -237,7 +237,7 @@ def ti_heartbeat(
)
def ti_put_rtif(
task_instance_id: UUID,
put_rtif_payload: RTIFPayload,
put_rtif_payload: Annotated[dict[str, JsonValue], Body()],
session: SessionDep,
):
"""Add an RTIF entry for a task instance, sent by the worker."""
Expand All @@ -247,6 +247,6 @@ def ti_put_rtif(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
)
_update_rtif(task_instance, put_rtif_payload.model_dump(), session)
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}
5 changes: 5 additions & 0 deletions airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def serialize_template_field(template_field: Any, name: str) -> str | dict | lis
def is_jsonable(x):
try:
json.dumps(x)
if isinstance(x, tuple):
# Tuple is converted to list in json.dumps
# so while it is jsonable, it changes the type which might be a surprise
# for the user, so instead we return False here -- which will convert it to string
return False
except (TypeError, OverflowError):
return False
else:
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class SetRenderedFields(BaseModel):
# We are using a BaseModel here compared to server using RootModel because we
# have a discriminator running with "type", and RootModel doesn't support type

rendered_fields: dict[str, str | None]
rendered_fields: dict[str, JsonValue]
type: Literal["SetRenderedFields"] = "SetRenderedFields"


Expand Down
24 changes: 14 additions & 10 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import attrs
import structlog
from pydantic import BaseModel, ConfigDict, TypeAdapter
from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter

from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
Expand Down Expand Up @@ -196,22 +196,26 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
# 1. Implementing the part where we pull in the logic to render fields and add that here
# for all operators, we should do setattr(task, templated_field, rendered_templated_field)
# task.templated_fields should give all the templated_fields and each of those fields should
# give the rendered values.
# give the rendered values. task.templated_fields should already be in a JSONable format and
# we should not have to handle that here.

# 2. Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB
templated_fields = ti.task.template_fields
payload = {}

for field in templated_fields:
if field not in payload:
payload[field] = getattr(ti.task, field)

# so that we do not call the API unnecessarily
if payload:
SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=payload))
if rendered_fields := _get_rendered_fields(ti.task):
SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields))
return ti, log


def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]:
    """Map each name in ``task.template_fields`` to its serialized attribute value."""
    # TODO: Port one of the following to Task SDK
    # airflow.serialization.helpers.serialize_template_field or
    # airflow.models.renderedtifields.get_serialized_template_fields
    from airflow.serialization.helpers import serialize_template_field

    rendered: dict[str, JsonValue] = {}
    for name in task.template_fields:
        # Read the attribute off the task and pass it through the shared serializer.
        rendered[name] = serialize_template_field(getattr(task, name), name)
    return rendered


def run(ti: RuntimeTaskInstance, log: Logger):
"""Run the task in this process."""
from airflow.exceptions import (
Expand Down
58 changes: 58 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,61 @@ def test_startup_basic_templated_dag(mocked_parse):
),
log=mock.ANY,
)


@pytest.mark.parametrize(
    ["task_params", "expected_rendered_fields"],
    [
        pytest.param(
            {"op_args": [], "op_kwargs": {}, "templates_dict": None},
            {"op_args": [], "op_kwargs": {}, "templates_dict": None},
            id="no_templates",
        ),
        pytest.param(
            {
                "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
                "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": {"nested_key": "nested_value"}},
            },
            {
                "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
                "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": {"nested_key": "nested_value"}},
            },
            id="mixed_types",
        ),
        pytest.param(
            {"my_tup": (1, 2), "my_set": {1, 2, 3}},
            {"my_tup": "(1, 2)", "my_set": "{1, 2, 3}"},
            id="tuples_and_sets",
        ),
    ],
)
def test_startup_dag_with_templated_fields(mocked_parse, task_params, expected_rendered_fields):
    """Check that startup() sends templated fields of various types via SetRenderedFields."""

    class CustomOperator(BaseOperator):
        # Every supplied parameter name is declared as a templated field.
        template_fields = tuple(task_params)

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Attach each parameter to the operator so it can be read back by name.
            for attr_name, attr_value in task_params.items():
                setattr(self, attr_name, attr_value)

    operator = CustomOperator(task_id="templated_task")

    runtime_ti = TaskInstance(
        id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1
    )
    startup_details = StartupDetails(ti=runtime_ti, file="", requests_fd=0)
    mocked_parse(startup_details, "basic_dag", operator)

    comms_patch = mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
    with comms_patch as comms:
        comms.get_message.return_value = startup_details

        startup()

        # Exactly one request is expected, carrying the parametrized rendered fields.
        expected_msg = SetRenderedFields(rendered_fields=expected_rendered_fields)
        comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY)
6 changes: 3 additions & 3 deletions tests/api_connexion/endpoints/test_task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_should_respond_200_task_instance_with_rendered(self, session):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down Expand Up @@ -403,7 +403,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down Expand Up @@ -2371,7 +2371,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down Expand Up @@ -444,7 +444,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down Expand Up @@ -3070,7 +3070,7 @@ def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_c
"try_number": 0,
"unixname": getuser(),
"dag_run_id": self.RUN_ID,
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
Expand Down
48 changes: 19 additions & 29 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,16 +422,31 @@ def teardown_method(self):
clear_db_runs()
clear_rendered_ti_fields()

def test_ti_put_rtif_success(self, client, session, create_task_instance):
@pytest.mark.parametrize(
"payload",
[
# string value
{"field1": "string_value", "field2": "another_string"},
# dictionary value
{"field1": {"nested_key": "nested_value"}},
# string lists value
{"field1": ["123"], "field2": ["a", "b", "c"]},
# list of JSON values
{"field1": [1, "string", 3.14, True, None, {"nested": "dict"}]},
# nested dictionary with mixed types in lists
{
"field1": {"nested_dict": {"key1": 123, "key2": "value"}},
"field2": [3.14, {"sub_key": "sub_value"}, [1, 2]],
},
],
)
def test_ti_put_rtif_success(self, client, session, create_task_instance, payload):
ti = create_task_instance(
task_id="test_ti_put_rtif_success",
state=State.RUNNING,
session=session,
)
session.commit()

payload = {"field1": "rendered_value1", "field2": "rendered_value2"}

response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
assert response.status_code == 201
assert response.json() == {"message": "Rendered task instance fields successfully set"}
Expand Down Expand Up @@ -461,28 +476,3 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance):
response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Not Found"

def test_ti_put_rtif_extra_fields(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_put_rtif_missing_ti",
state=State.RUNNING,
session=session,
)
session.commit()

payload = {
"field1": "rendered_value1",
"field2": "rendered_value2",
"invalid_key": {"field3": "rendered_value3"},
}

response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)
assert response.status_code == 422
assert response.json()["detail"] == [
{
"input": {"field3": "rendered_value3"},
"loc": ["body", "invalid_key"],
"msg": "Input should be a valid string",
"type": "string_type",
}
]
36 changes: 30 additions & 6 deletions tests/models/test_renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

import ast
import os
from collections import Counter
from datetime import date, timedelta
Expand Down Expand Up @@ -100,8 +101,12 @@ def teardown_method(self):
(None, None),
([], []),
({}, {}),
((), "()"),
(set(), "set()"),
("test-string", "test-string"),
({"foo": "bar"}, {"foo": "bar"}),
(("foo", "bar"), "('foo', 'bar')"),
({"foo", "bar"}, "{'foo', 'bar'}"),
("{{ task.task_id }}", "test"),
(date(2018, 12, 6), "2018-12-06"),
(datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
Expand Down Expand Up @@ -158,16 +163,35 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da
assert ti.dag_id == rtif.dag_id
assert ti.task_id == rtif.task_id
assert ti.run_id == rtif.run_id
assert expected_rendered_field == rtif.rendered_fields.get("bash_command")
if type(templated_field) is set:
# the output order of a set is non-deterministic and can change per process.
# this validation can fail if that happens before stringification, so we convert to set and compare.
assert ast.literal_eval(expected_rendered_field) == ast.literal_eval(
rtif.rendered_fields.get("bash_command")
)
else:
assert expected_rendered_field == rtif.rendered_fields.get("bash_command")

session.add(rtif)
session.flush()

assert RTIF.get_templated_fields(ti=ti, session=session) == {
"bash_command": expected_rendered_field,
"env": None,
"cwd": None,
}
if type(templated_field) is set:
# the output order of a set is non-deterministic and can change per process.
# this validation can fail if that happens before stringification, so we convert to set and compare.
expected = RTIF.get_templated_fields(ti=ti, session=session)
expected["bash_command"] = ast.literal_eval(expected["bash_command"])
actual = {
"bash_command": ast.literal_eval(expected_rendered_field),
"env": None,
"cwd": None,
}
assert expected == actual
else:
assert RTIF.get_templated_fields(ti=ti, session=session) == {
"bash_command": expected_rendered_field,
"env": None,
"cwd": None,
}
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
# Fetching them will return None
Expand Down

0 comments on commit ff7e700

Please sign in to comment.