Skip to content

Commit

Permalink
Added callback_file & callback_name to default_args DAG level and tes…
Browse files Browse the repository at this point in the history
…ts (#218)

Hi astro team!

This PR brings the following enhancements:

1) Unit tests for DAG-level callbacks configured with the parameters:

**_on_success_callback_name_** & **_on_success_callback_file_**
**_on_failure_callback_name_** & **_on_failure_callback_file_**

[Astro instruction on how to apply these
parameters](https://www.astronomer.io/docs/learn/dag-factory#step-4-optional-add-a-dag-level-callback)

2) The existing callback feature is useful because it allows callback
code to be specified at any location. This PR takes it a step further by
enabling callbacks to be specified within the DAG's _default_args_. With
this enhancement, the callbacks automatically propagate to the task
level as well. This update aligns with how
[default_args](https://airflow.apache.org/docs/apache-airflow/2.9.2/core-concepts/dags.html#default-arguments)
are passed in Airflow. Additionally, I’ve included unit tests.
  • Loading branch information
subbota19 authored Oct 16, 2024
1 parent bf9272c commit 93e32e6
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
26 changes: 26 additions & 0 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,32 @@ def get_dag_params(self) -> Dict[str, Any]:
dag_params["on_failure_callback_file"],
)

if utils.check_dict_key(
dag_params["default_args"], "on_success_callback_name"
) and utils.check_dict_key(
dag_params["default_args"], "on_success_callback_file"
):

dag_params["default_args"]["on_success_callback"]: Callable = (
utils.get_python_callable(
dag_params["default_args"]["on_success_callback_name"],
dag_params["default_args"]["on_success_callback_file"],
)
)

if utils.check_dict_key(
dag_params["default_args"], "on_failure_callback_name"
) and utils.check_dict_key(
dag_params["default_args"], "on_failure_callback_file"
):

dag_params["default_args"]["on_failure_callback"]: Callable = (
utils.get_python_callable(
dag_params["default_args"]["on_failure_callback_name"],
dag_params["default_args"]["on_failure_callback_file"],
)
)

if utils.check_dict_key(dag_params, "template_searchpath"):
if isinstance(dag_params["template_searchpath"], (list, str)) and utils.check_template_searchpath(
dag_params["template_searchpath"]
Expand Down
103 changes: 103 additions & 0 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,69 @@
},
},
}

# DAG config fixture that declares DAG-level callbacks indirectly, via the
# top-level "on_*_callback_name" / "on_*_callback_file" key pairs rather than
# by passing a callable. Both pairs name "print_context_callback" and point
# at this test module (__file__); presumably that function is defined
# elsewhere in this module -- confirm. "default_args" carries no callback
# keys, and the three BashOperator tasks declare no callbacks of their own.
DAG_CONFIG_CALLBACK_NAME_AND_FILE = {
"doc_md": "##here is a doc md string",
"default_args": {
"owner": "custom_owner",
},
"description": "this is an example dag",
"schedule_interval": "0 3 * * *",
"tags": ["tag1", "tag2"],
"on_failure_callback_name": "print_context_callback",
"on_failure_callback_file": __file__,
"on_success_callback_name": "print_context_callback",
"on_success_callback_file": __file__,
"tasks": {
"task_1": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 1",
"execution_timeout_secs": 5,
},
"task_2": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 2",
"dependencies": ["task_1"],
},
"task_3": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 3",
"dependencies": ["task_1"],
},
},
}

# DAG config fixture that declares callbacks indirectly inside "default_args",
# via the "on_*_callback_name" / "on_*_callback_file" key pairs. Both pairs
# name "print_context_callback" and point at this test module (__file__);
# presumably that function is defined elsewhere in this module -- confirm.
# No callback keys appear at the DAG top level or on the individual tasks.
DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS = {
"doc_md": "##here is a doc md string",
"default_args": {
"owner": "custom_owner",
"on_failure_callback_name": "print_context_callback",
"on_failure_callback_file": __file__,
"on_success_callback_name": "print_context_callback",
"on_success_callback_file": __file__,
},
"description": "this is an example dag",
"schedule_interval": "0 3 * * *",
"tags": ["tag1", "tag2"],
"tasks": {
"task_1": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 1",
"execution_timeout_secs": 5,
},
"task_2": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 2",
"dependencies": ["task_1"],
},
"task_3": {
"operator": "airflow.operators.bash_operator.BashOperator",
"bash_command": "echo 3",
"dependencies": ["task_1"],
},
},
}

UTC = pendulum.timezone("UTC")

DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS = {
Expand Down Expand Up @@ -667,6 +730,46 @@ def test_make_task_with_callback():
assert callable(actual.on_retry_callback)


def test_dag_with_callback_name_and_file():
    """Check DAG-level callbacks declared as ``*_callback_name`` + ``*_callback_file``.

    Builds a DAG from ``DAG_CONFIG_CALLBACK_NAME_AND_FILE`` and asserts that
    both ``on_success_callback`` and ``on_failure_callback`` are present in the
    builder's ``dag_config`` as callables named ``print_context_callback``,
    and that no task carries a callable success/failure callback.
    """
    td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE, DEFAULT_CONFIG)
    dag = td.build().get("dag")

    # Verify that the callbacks have been set up properly per DAG after specifying:
    # - 'on_success_callback_file' & 'on_success_callback_name' for 'on_success_callback'
    # - 'on_failure_callback_file' & 'on_failure_callback_name' for 'on_failure_callback'
    # Looping over both keys keeps the success and failure checks symmetric;
    # the failure callback's name was previously never verified.
    for callback_key in ("on_success_callback", "on_failure_callback"):
        assert callback_key in td.dag_config
        assert callable(td.dag_config[callback_key])
        assert td.dag_config[callback_key].__name__ == "print_context_callback"

    # Ensure that no callbacks were directly provided at the task level.
    for td_task in dag.task_dict.values():
        assert not callable(td_task.on_success_callback)
        assert not callable(td_task.on_failure_callback)


def test_dag_with_callback_name_and_file_default_args():
    """Check callbacks declared via name/file pairs inside ``default_args``.

    Builds a DAG from ``DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS`` and
    asserts that ``default_args`` in the builder's ``dag_config`` holds callable
    ``on_success_callback`` / ``on_failure_callback`` entries, and that every
    task in the built DAG carries both callbacks as callables named
    ``print_context_callback``.
    """
    td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS, DEFAULT_CONFIG)
    dag = td.build().get("dag")

    # Verify that the callbacks have been set up properly per DAG and tasks after specifying through default_args:
    # - 'on_success_callback_file' & 'on_success_callback_name' for 'on_success_callback'
    # - 'on_failure_callback_file' & 'on_failure_callback_name' for 'on_failure_callback'
    td_default_args = td.dag_config.get("default_args")
    for callback_key in ("on_success_callback", "on_failure_callback"):
        assert callback_key in td_default_args
        assert callable(td_default_args[callback_key])

    # Each task must have received both callbacks; the failure callback's
    # name was previously never verified.
    for td_task in dag.task_dict.values():
        assert callable(td_task.on_success_callback)
        assert callable(td_task.on_failure_callback)
        assert td_task.on_success_callback.__name__ == "print_context_callback"
        assert td_task.on_failure_callback.__name__ == "print_context_callback"


def test_make_timetable():
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
Expand Down

0 comments on commit 93e32e6

Please sign in to comment.