Skip to content

Commit

Permalink
Merge pull request #2 from airflow-laminar/tkp/init
Browse files Browse the repository at this point in the history
Setup HA operator
  • Loading branch information
timkpaine authored Aug 24, 2024
2 parents 2d9a46c + 349b9ce commit dcc9a46
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 25 deletions.
57 changes: 57 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,63 @@ High Availability (HA) DAG Utility

## Overview

This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and does the following:

- runs a user-provided `python_callable` as a sensor
- if this returns `"done"`, mark the DAG as passed and finish
- if this returns `"running"`, keep checking
- if this returns `"failed"`, mark the DAG as failed and re-run
- if the sensor times out, mark the DAG as passed and re-run

Consider the following DAG:

```python
from datetime import datetime, timedelta
from random import choice

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow_ha import HighAvailabilityOperator


with DAG(
dag_id="test-high-availability",
description="Test HA Operator",
schedule=timedelta(days=1),
start_date=datetime(2024, 1, 1),
catchup=False,
):
ha = HighAvailabilityOperator(
task_id="ha",
timeout=30,
poke_interval=5,
python_callable=lambda **kwargs: choice(("done", "failed", "running", ""))
)

pre = PythonOperator(task_id="pre", python_callable=lambda **kwargs: "test")
pre >> ha

fail = PythonOperator(task_id="fail", python_callable=lambda **kwargs: "test")
ha.failed >> fail

passed = PythonOperator(task_id="passed", python_callable=lambda **kwargs: "test")
ha.passed >> passed

done = PythonOperator(task_id="done", python_callable=lambda **kwargs: "test")
ha.done >> done
```

This produces a DAG with the following topology:

<img src="./docs/src/top.png" />

This DAG exhibits cool behavior.
If a check fails or the interval elapses, the DAG will re-trigger itself.
If the check passes, the DAG will finish.
This allows one to build "always-on" DAGs without having individual long blocking tasks.

This library is used to build [airflow-supervisor](https://github.com/airflow-laminar/airflow-supervisor), which uses [supervisor](http://supervisord.org) as a process-monitor while checking and restarting jobs via `airflow-ha`.

## License

This software is licensed under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details.
2 changes: 2 additions & 0 deletions airflow_ha/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
__version__ = "0.1.0"

from .operator import HighAvailabilityOperator
153 changes: 128 additions & 25 deletions airflow_ha/operator.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,61 @@
from airflow.models.dag import DAG # noqa: F401
from airflow.models.operator import Operator # noqa: F401
from airflow.operators.python import BranchPythonOperator, PythonOperator # noqa: F401
from airflow.operators.trigger_dagrun import TriggerDagRunOperator # noqa: F401
from typing import Literal

from airflow.models.operator import Operator
from airflow.exceptions import AirflowSkipException, AirflowFailException
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.python import PythonSensor

__all__ = ("HighAvailabilityOperator",)


# The three meaningful results a user-provided check callable may return;
# any other value is treated as a sensor miss (keep poking until timeout).
CheckResult = Literal["done", "running", "failed"]


def skip_():
    """Task callable that always marks its task instance as skipped."""
    raise AirflowSkipException()


def fail_():
    """Task callable that always fails its task instance (no retries)."""
    raise AirflowFailException()


def pass_():
    """No-op task callable used for placeholder tasks that always succeed."""
    return None


class HighAvailabilityOperator(PythonSensor):
_decide_task: BranchPythonOperator

_end_fail: Operator
_end_pass: Operator

_loop_pass: Operator
_loop_fail: Operator

_done_task: Operator
_end_task: Operator
_running_task: Operator
_failed_task: Operator
_kill_task: Operator

_cleanup_task: Operator
_loop_task: Operator
_restart_task: Operator

class HighAvailabilityOperator(PythonOperator):
def __init__(self, **kwargs) -> None:
"""The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks.
It resembles a BranchPythonOperator with the following predefined set of outcomes:
HA-Task (the instance of HighAvailabilityOperator itself)
|-> CheckTask (run a user-provided check or task)
| DoneTask (tasks finished cleanly)------|
| EndTask (end time reached)-------------|
|
|--> CleanupTask (Finish DAG, success)
| RunningTask (tasks are still running)--|
|--> LoopTask (Re-trigger DAG, success)
| FailedTask (tasks finished uncleanly)--|
|--> RestartTask (Re-trigger DAG, failure)
| KillTask-------------------------------|
|--> CleanupTask (Finish DAG, failure)
/-> "done" -> Done -> EndPass
check -> decide -> "running" -> Loop -> EndPass
\-> "failed" -> Loop -> EndFail
\-------------> failed -> Loop -> EndPass
Given a check, there are four outcomes:
- The tasks finished/exited cleanly, and thus the DAG should terminate cleanly
Expand All @@ -36,11 +70,80 @@ def __init__(self, **kwargs) -> None:
Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first).
"""
...

kwargs.pop("python_callable", None)
kwargs.pop("op_args", None)
kwargs.pop("op_kwargs", None)
kwargs.pop("templates_dict", None)
kwargs.pop("templates_exts", None)
kwargs.pop("show_return_value_in_logs", None)
python_callable = kwargs.pop("python_callable")

def _callable_wrapper(**kwargs):
task_instance = kwargs["task_instance"]
ret: CheckResult = python_callable(**kwargs)
if ret == "done":
task_instance.xcom_push(key="return_value", value="done")
# finish
return True
elif ret == "failed":
task_instance.xcom_push(key="return_value", value="failed")
# finish
return True
elif ret == "running":
task_instance.xcom_push(key="return_value", value="running")
# finish
return True
task_instance.xcom_push(key="return_value", value="")
return False

super().__init__(python_callable=_callable_wrapper, **kwargs)

self._end_fail = PythonOperator(task_id=f"{self.task_id}-dag-fail", python_callable=fail_, trigger_rule="all_success")
self._end_pass = PythonOperator(task_id=f"{self.task_id}-dag-pass", python_callable=pass_, trigger_rule="all_success")

self._loop_fail = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-fail", trigger_dag_id=self.dag_id, trigger_rule="all_success")
self._loop_pass = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-pass", trigger_dag_id=self.dag_id, trigger_rule="one_success")

self._done_task = PythonOperator(task_id=f"{self.task_id}-done", python_callable=pass_, trigger_rule="all_success")
self._running_task = PythonOperator(task_id=f"{self.task_id}-running", python_callable=pass_, trigger_rule="all_success")
self._failed_task = PythonOperator(task_id=f"{self.task_id}-failed", python_callable=pass_, trigger_rule="all_success")
self._sensor_failed_task = PythonOperator(task_id=f"{self.task_id}-sensor-timeout", python_callable=pass_, trigger_rule="all_failed")

branch_choices = {
"done": self._done_task.task_id,
"running": self._running_task.task_id,
"failed": self._failed_task.task_id,
"": self._sensor_failed_task.task_id,
}

def _choose_branch(branch_choices=branch_choices, **kwargs):
task_instance = kwargs["task_instance"]
check_program_result = task_instance.xcom_pull(key="return_value", task_ids=self.task_id)
ret = branch_choices.get(check_program_result, None)
if ret is None:
raise AirflowSkipException
return ret

self._decide_task = BranchPythonOperator(
task_id=f"{self.task_id}-decide",
python_callable=_choose_branch,
provide_context=True,
trigger_rule="all_success",
)

self >> self._sensor_failed_task >> self._loop_pass >> self._end_pass
self >> self._decide_task >> self._done_task
self >> self._decide_task >> self._running_task >> self._loop_pass >> self._end_pass
self >> self._decide_task >> self._failed_task >> self._loop_fail >> self._end_fail

@property
def check(self) -> Operator:
    """The check task itself (this operator doubles as the sensor task)."""
    return self

@property
def failed(self) -> Operator:
    """Attachment point for downstream tasks that run when the check failed.

    NOTE: this exposes the loop-fail trigger task (which itself succeeds);
    the separate end-fail task exists only to mark the DAG run as failed.
    """
    return self._loop_fail

@property
def passed(self) -> Operator:
    """Attachment point for downstream tasks that run when the check passed.

    NOTE: exposes the loop-pass trigger task, mirroring ``failed``.
    """
    return self._loop_pass

@property
def done(self) -> Operator:
    """Attachment point for downstream tasks that run when the check is done."""
    return self._done_task
Binary file added docs/src/top.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit dcc9a46

Please sign in to comment.