diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index e0d8f371f09d6..bbc557d012463 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -21,7 +21,7 @@ from datetime import timedelta from typing import Annotated, Any, Literal, Union -from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema +from pydantic import Discriminator, Field, Tag, WithJsonSchema from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.core_api.base import BaseModel @@ -135,7 +135,3 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int | None = None - - -"""Schema for setting RTIF for a task instance.""" -RTIFPayload = RootModel[dict[str, str]] diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 90bbe1c1d3e5b..e06798209c5da 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -22,6 +22,7 @@ from uuid import UUID from fastapi import Body, HTTPException, status +from pydantic import JsonValue from sqlalchemy import update from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.sql import select @@ -29,7 +30,6 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( - RTIFPayload, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, @@ -237,7 +237,7 @@ def ti_heartbeat( ) def ti_put_rtif( task_instance_id: UUID, - put_rtif_payload: RTIFPayload, + put_rtif_payload: Annotated[dict[str, JsonValue], Body()], session: SessionDep, ): """Add an RTIF entry for a task instance, sent by the worker.""" @@ -247,6 +247,6 @@ def ti_put_rtif( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, ) - _update_rtif(task_instance, put_rtif_payload.model_dump(), session) + _update_rtif(task_instance, put_rtif_payload, session) return {"message": "Rendered task instance fields successfully set"} diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py index 85bf3a1cc551c..dc1aabbca986c 100644 --- a/airflow/serialization/helpers.py +++ b/airflow/serialization/helpers.py @@ -36,6 +36,11 @@ def serialize_template_field(template_field: Any, name: str) -> str | dict | lis def is_jsonable(x): try: json.dumps(x) + if isinstance(x, tuple): + # Tuple is converted to list in json.dumps + # so while it is jsonable, it changes the type which might be a surprise + # for the user, so instead we return False here -- which will convert it to string + return False except (TypeError, OverflowError): return False else: diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 34d6a9e3156d4..9e6093a092da0 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -176,7 +176,7 @@ class SetRenderedFields(BaseModel): # We are using a BaseModel here compared to server using RootModel because we # have a discriminator running with "type", and RootModel doesn't support type - rendered_fields: dict[str, str | None] + rendered_fields: dict[str, JsonValue] type: Literal["SetRenderedFields"] = "SetRenderedFields" diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index c01677ce1a798..5aca25f590e5e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -27,7 +27,7 @@ import attrs import structlog -from pydantic import BaseModel, ConfigDict, TypeAdapter +from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.definitions.baseoperator import BaseOperator @@ -196,22 +196,26 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: # 1. Implementing the part where we pull in the logic to render fields and add that here # for all operators, we should do setattr(task, templated_field, rendered_templated_field) # task.templated_fields should give all the templated_fields and each of those fields should - # give the rendered values. + # give the rendered values. task.templated_fields should already be in a JSONable format and + # we should not have to handle that here. # 2. Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB - templated_fields = ti.task.template_fields - payload = {} - - for field in templated_fields: - if field not in payload: - payload[field] = getattr(ti.task, field) # so that we do not call the API unnecessarily - if payload: - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=payload)) + if rendered_fields := _get_rendered_fields(ti.task): + SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) return ti, log +def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]: + # TODO: Port one of the following to Task SDK + # airflow.serialization.helpers.serialize_template_field or + # airflow.models.renderedtifields.get_serialized_template_fields + from airflow.serialization.helpers import serialize_template_field + + return {field: serialize_template_field(getattr(task, field), field) for field in task.template_fields} + + def run(ti: RuntimeTaskInstance, log: Logger): """Run the task in this process.""" from airflow.exceptions import ( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 517157e0a7a90..c9755c252bbe6 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -260,3 +260,61 @@ def test_startup_basic_templated_dag(mocked_parse): ), log=mock.ANY, ) + + +@pytest.mark.parametrize( + ["task_params", "expected_rendered_fields"], + [ + pytest.param( + {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + id="no_templates", + ), + pytest.param( + { + "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}], + "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": {"nested_key": "nested_value"}}, + }, + { + "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}], + "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": {"nested_key": "nested_value"}}, + }, + id="mixed_types", + ), + pytest.param( + {"my_tup": (1, 2), "my_set": {1, 2, 3}}, + {"my_tup": "(1, 2)", "my_set": "{1, 2, 3}"}, + id="tuples_and_sets", + ), + ], +) +def test_startup_dag_with_templated_fields(mocked_parse, task_params, expected_rendered_fields): + """Test startup of a DAG with various templated fields.""" + + class CustomOperator(BaseOperator): + template_fields = tuple(task_params.keys()) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for key, value in task_params.items(): + setattr(self, key, value) + + task = CustomOperator(task_id="templated_task") + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), + file="", + requests_fd=0, + ) + mocked_parse(what, "basic_dag", task) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = what + + startup() + mock_supervisor_comms.send_request.assert_called_once_with( + msg=SetRenderedFields(rendered_fields=expected_rendered_fields), + log=mock.ANY, + ) diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index f39cff8ae7650..a4089c9785a98 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -351,7 +351,7 @@ def test_should_respond_200_task_instance_with_rendered(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -403,7 +403,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -2371,7 +2371,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 9b427253b2965..7ce944a4d47ae 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -344,7 +344,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -444,7 +444,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -3070,7 +3070,7 @@ def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_c "try_number": 0, "unixname": getuser(), "dag_run_id": self.RUN_ID, - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index c13effee0bb16..15e56bbc58710 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -422,16 +422,31 @@ def teardown_method(self): clear_db_runs() clear_rendered_ti_fields() - def test_ti_put_rtif_success(self, client, session, create_task_instance): + @pytest.mark.parametrize( + "payload", + [ + # string value + {"field1": "string_value", "field2": "another_string"}, + # dictionary value + {"field1": {"nested_key": "nested_value"}}, + # string lists value + {"field1": ["123"], "field2": ["a", "b", "c"]}, + # list of JSON values + {"field1": [1, "string", 3.14, True, None, {"nested": "dict"}]}, + # nested dictionary with mixed types in lists + { + "field1": {"nested_dict": {"key1": 123, "key2": "value"}}, + "field2": [3.14, {"sub_key": "sub_value"}, [1, 2]], + }, + ], + ) + def test_ti_put_rtif_success(self, client, session, create_task_instance, payload): ti = create_task_instance( task_id="test_ti_put_rtif_success", state=State.RUNNING, session=session, ) session.commit() - - payload = {"field1": "rendered_value1", "field2": "rendered_value2"} - response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) assert response.status_code == 201 assert response.json() == {"message": "Rendered task instance fields successfully set"} @@ -461,28 +476,3 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload) assert response.status_code == 404 assert response.json()["detail"] == "Not Found" - - def test_ti_put_rtif_extra_fields(self, client, session, create_task_instance): - ti = create_task_instance( - task_id="test_ti_put_rtif_missing_ti", - state=State.RUNNING, - session=session, - ) - session.commit() - - payload = { - "field1": "rendered_value1", - "field2": "rendered_value2", - "invalid_key": {"field3": "rendered_value3"}, - } - - response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) - assert response.status_code == 422 - assert response.json()["detail"] == [ - { - "input": {"field3": "rendered_value3"}, - "loc": ["body", "invalid_key"], - "msg": "Input should be a valid string", - "type": "string_type", - } - ] diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index 3f1b13cd1a35d..ded755c4d01d9 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -19,6 +19,7 @@ from __future__ import annotations +import ast import os from collections import Counter from datetime import date, timedelta @@ -100,8 +101,12 @@ def teardown_method(self): (None, None), ([], []), ({}, {}), + ((), "()"), + (set(), "set()"), ("test-string", "test-string"), ({"foo": "bar"}, {"foo": "bar"}), + (("foo", "bar"), "('foo', 'bar')"), + ({"foo", "bar"}, "{'foo', 'bar'}"), ("{{ task.task_id }}", "test"), (date(2018, 12, 6), "2018-12-06"), (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"), @@ -158,16 +163,35 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da assert ti.dag_id == rtif.dag_id assert ti.task_id == rtif.task_id assert ti.run_id == rtif.run_id - assert expected_rendered_field == rtif.rendered_fields.get("bash_command") + if type(templated_field) is set: + # the output order of a set is non-deterministic and can change per process. + # this validation can fail if that happens before stringification, so we convert to set and compare. + assert ast.literal_eval(expected_rendered_field) == ast.literal_eval( + rtif.rendered_fields.get("bash_command") + ) + else: + assert expected_rendered_field == rtif.rendered_fields.get("bash_command") session.add(rtif) session.flush() - assert RTIF.get_templated_fields(ti=ti, session=session) == { - "bash_command": expected_rendered_field, - "env": None, - "cwd": None, - } + if type(templated_field) is set: + # the output order of a set is non-deterministic and can change per process. + # this validation can fail if that happens before stringification, so we convert to set and compare. + expected = RTIF.get_templated_fields(ti=ti, session=session) + expected["bash_command"] = ast.literal_eval(expected["bash_command"]) + actual = { + "bash_command": ast.literal_eval(expected_rendered_field), + "env": None, + "cwd": None, + } + assert expected == actual + else: + assert RTIF.get_templated_fields(ti=ti, session=session) == { + "bash_command": expected_rendered_field, + "env": None, + "cwd": None, + } # Test the else part of get_templated_fields # i.e. for the TIs that are not stored in RTIF table # Fetching them will return None