diff --git a/README.md b/README.md
index f5a045b..e03d6df 100644
--- a/README.md
+++ b/README.md
@@ -9,25 +9,27 @@ High Availability (HA) DAG Utility
## Overview
-This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and does the following:
+This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and runs a user-provided `python_callable`.
+The return value can trigger the following actions:
-- runs a user-provided `python_callable` as a sensor
- - if this returns `"done"`, mark the DAG as passed and finish
- - if this returns `"running"`, keep checking
- - if this returns `"failed"`, mark the DAG as failed and re-run
-- if the sensor times out, mark the DAG as passed and re-run
+| Return | Result | Current DAGrun End State |
+| :----- | :----- | :----------------------- |
+| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same DAG to run again | `pass` |
+| `(Result.PASS, Action.STOP)` | Finish the DAG, until its next scheduled run | `pass` |
+| `(Result.FAIL, Action.RETRIGGER)` | Retrigger the same DAG to run again | `fail` |
+| `(Result.FAIL, Action.STOP)` | Finish the DAG, until its next scheduled run | `fail` |
+| `(*, Action.RETRIGGER)` | Continue to run the Sensor | N/A |
+| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` |
+| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` |
+| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` |
-Consider the following DAG:
-
-```python
-from datetime import datetime, timedelta
-from random import choice
+Note: if the sensor times out, the behavior matches `(Result.PASS, Action.RETRIGGER)`.
-from airflow import DAG
-from airflow.operators.python import PythonOperator
-from airflow_ha import HighAvailabilityOperator
+### Example - Always On
+Consider the following DAG:
+```python
with DAG(
dag_id="test-high-availability",
description="Test HA Operator",
@@ -39,20 +41,32 @@ with DAG(
task_id="ha",
timeout=30,
poke_interval=5,
- python_callable=lambda **kwargs: choice(("done", "failed", "running", ""))
+ python_callable=lambda **kwargs: choice(
+ (
+ (Result.PASS, Action.CONTINUE),
+ (Result.PASS, Action.RETRIGGER),
+ (Result.PASS, Action.STOP),
+ (Result.FAIL, Action.CONTINUE),
+ (Result.FAIL, Action.RETRIGGER),
+ (Result.FAIL, Action.STOP),
+ )
+ ),
)
pre = PythonOperator(task_id="pre", python_callable=lambda **kwargs: "test")
pre >> ha
- fail = PythonOperator(task_id="fail", python_callable=lambda **kwargs: "test")
- ha.failed >> fail
+ retrigger_fail = PythonOperator(task_id="retrigger_fail", python_callable=lambda **kwargs: "test")
+ ha.retrigger_fail >> retrigger_fail
+
+ stop_fail = PythonOperator(task_id="stop_fail", python_callable=lambda **kwargs: "test")
+ ha.stop_fail >> stop_fail
- passed = PythonOperator(task_id="passed", python_callable=lambda **kwargs: "test")
- ha.passed >> passed
+ retrigger_pass = PythonOperator(task_id="retrigger_pass", python_callable=lambda **kwargs: "test")
+ ha.retrigger_pass >> retrigger_pass
- done = PythonOperator(task_id="done", python_callable=lambda **kwargs: "test")
- ha.done >> done
+ stop_pass = PythonOperator(task_id="stop_pass", python_callable=lambda **kwargs: "test")
+ ha.stop_pass >> stop_pass
```
This produces a DAG with the following topology:
@@ -60,12 +74,54 @@ This produces a DAG with the following topology:
This DAG exhibits cool behavior.
-If a check fails or the interval elapses, the DAG will re-trigger itself.
-If the check passes, the DAG will finish.
+If the check returns `CONTINUE`, the DAG will continue to run the sensor.
+If the check returns `RETRIGGER` or the interval elapses, the DAG will re-trigger itself and finish.
+If the check returns `STOP`, the DAG will finish and not retrigger itself.
+If the check returns `PASS`, the current DAG run will end in a successful state.
+If the check returns `FAIL`, the current DAG run will end in a failed state.
+
This allows the one to build "always-on" DAGs without having individual long blocking tasks.
This library is used to build [airflow-supervisor](https://github.com/airflow-laminar/airflow-supervisor), which uses [supervisor](http://supervisord.org) as a process-monitor while checking and restarting jobs via `airflow-ha`.
+### Example - Recursive
+
+You can also use this library to build recursive DAGs - or "Cyclic DAGs", despite the oxymoronic name.
+
+The following code makes a DAG that triggers itself with some decrementing counter, starting with value 3:
+
+```python
+
+with DAG(
+ dag_id="test-ha-counter",
+ description="Test HA Countdown",
+ schedule=timedelta(days=1),
+ start_date=datetime(2024, 1, 1),
+ catchup=False,
+):
+
+ def _get_count(**kwargs):
+ # The default is 3
+ return kwargs['dag_run'].conf.get('counter', 3) - 1
+
+ get_count = PythonOperator(task_id="get-count", python_callable=_get_count)
+
+ def _keep_counting(**kwargs):
+ count = kwargs["task_instance"].xcom_pull(key="return_value", task_ids="get-count")
+ return (Result.PASS, Action.RETRIGGER) if count > 0 else (Result.PASS, Action.STOP) if count == 0 else (Result.FAIL, Action.STOP)
+
+ keep_counting = HighAvailabilityOperator(
+ task_id="ha",
+ timeout=30,
+ poke_interval=5,
+ python_callable=_keep_counting,
+ pass_trigger_kwargs={"conf": '''{"counter": {{ ti.xcom_pull(key="return_value", task_ids="get-count") }}}'''},
+ )
+
+ get_count >> keep_counting
+```
+
+
## License
This software is licensed under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details.
diff --git a/airflow_ha/__init__.py b/airflow_ha/__init__.py
index 0e37288..3b326f1 100644
--- a/airflow_ha/__init__.py
+++ b/airflow_ha/__init__.py
@@ -1,3 +1,3 @@
__version__ = "0.1.0"
-from .operator import HighAvailabilityOperator
+from .operator import *
diff --git a/airflow_ha/operator.py b/airflow_ha/operator.py
index 395d689..36067c3 100644
--- a/airflow_ha/operator.py
+++ b/airflow_ha/operator.py
@@ -1,4 +1,5 @@
-from typing import Literal
+from enum import Enum
+from typing import Any, Callable, Dict, Optional, Tuple
from airflow.models.operator import Operator
from airflow.exceptions import AirflowSkipException, AirflowFailException
@@ -6,14 +7,26 @@
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.python import PythonSensor
-__all__ = ("HighAvailabilityOperator",)
+__all__ = (
+ "HighAvailabilityOperator",
+ "Result",
+ "Action",
+ "CheckResult",
+)
-CheckResult = Literal[
- "done",
- "running",
- "failed",
-]
+class Result(str, Enum):
+ PASS = "pass"
+ FAIL = "fail"
+
+
+class Action(str, Enum):
+ CONTINUE = "continue"
+ RETRIGGER = "retrigger"
+ STOP = "stop"
+
+
+CheckResult = Tuple[Result, Action]
def skip_():
@@ -28,93 +41,90 @@ def pass_():
pass
-class HighAvailabilityOperator(PythonSensor):
+class HighAvailabilityOperatorMixin:
_decide_task: BranchPythonOperator
-
- _end_fail: Operator
- _end_pass: Operator
-
- _loop_pass: Operator
- _loop_fail: Operator
-
- _done_task: Operator
- _end_task: Operator
- _running_task: Operator
- _failed_task: Operator
- _kill_task: Operator
-
- _cleanup_task: Operator
- _loop_task: Operator
- _restart_task: Operator
-
- def __init__(self, **kwargs) -> None:
+ _fail: Operator
+ _retrigger_fail: Operator
+ _retrigger_pass: Operator
+ _stop_pass: Operator
+ _stop_fail: Operator
+ _sensor_failed_task: Operator
+
+ def __init__(
+ self,
+ python_callable: Callable[..., CheckResult],
+ pass_trigger_kwargs: Optional[Dict[str, Any]] = None,
+ fail_trigger_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
"""The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks.
It resembles a BranchPythonOperator with the following predefined set of outcomes:
- /-> "done" -> Done -> EndPass
- check -> decide -> "running" -> Loop -> EndPass
- \-> "failed" -> Loop -> EndFail
- \-------------> failed -> Loop -> EndPass
-
- Given a check, there are four outcomes:
- - The tasks finished/exited cleanly, and thus the DAG should terminate cleanly
- - The tasks finished/exited uncleanly, in which case the DAG should restart
- - The tasks did not finish, but the end time has been reached anyway, so the DAG should terminate cleanly
- - The tasks did not finish, but we've reached an interval and should loop and rerun the DAG
-
- The last case is particularly important when DAGs have a max run time, e.g. on AWS MWAA where DAGs
- cannot run for longer than 12 hours at a time and so must be "restarted".
-
- Additionally, there is a "KillTask" to force kill the DAG.
+ check -> decide -> PASS/RETRIGGER
+ -> PASS/STOP
+ -> FAIL/RETRIGGER
+ -> FAIL/STOP
+ -> */CONTINUE
Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first).
"""
- python_callable = kwargs.pop("python_callable")
+ pass_trigger_kwargs = pass_trigger_kwargs or {}
+ fail_trigger_kwargs = fail_trigger_kwargs or {}
def _callable_wrapper(**kwargs):
task_instance = kwargs["task_instance"]
ret: CheckResult = python_callable(**kwargs)
- if ret == "done":
- task_instance.xcom_push(key="return_value", value="done")
- # finish
- return True
- elif ret == "failed":
- task_instance.xcom_push(key="return_value", value="failed")
- # finish
- return True
- elif ret == "running":
- task_instance.xcom_push(key="return_value", value="running")
- # finish
+
+ if not isinstance(ret, tuple) or not len(ret) == 2 or not isinstance(ret[0], Result) or not isinstance(ret[1], Action):
+ # malformed
+ task_instance.xcom_push(key="return_value", value=(Result.FAIL, Action.STOP))
return True
- task_instance.xcom_push(key="return_value", value="")
- return False
+
+ # push to xcom
+ task_instance.xcom_push(key="return_value", value=ret)
+
+ if ret[1] == Action.CONTINUE:
+ # keep checking
+ return False
+ return True
super().__init__(python_callable=_callable_wrapper, **kwargs)
- self._end_fail = PythonOperator(task_id=f"{self.task_id}-dag-fail", python_callable=fail_, trigger_rule="all_success")
- self._end_pass = PythonOperator(task_id=f"{self.task_id}-dag-pass", python_callable=pass_, trigger_rule="all_success")
+ # this is needed to ensure the dag fails, since the
+ # retrigger_fail step will pass (to ensure dag retriggers!)
+ self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail_, trigger_rule="all_success")
- self._loop_fail = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-fail", trigger_dag_id=self.dag_id, trigger_rule="all_success")
- self._loop_pass = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-pass", trigger_dag_id=self.dag_id, trigger_rule="one_success")
+ self._retrigger_fail = TriggerDagRunOperator(
+ task_id=f"{self.task_id}-retrigger-fail", **{"trigger_dag_id": self.dag_id, "trigger_rule": "all_success", **fail_trigger_kwargs}
+ )
+ self._retrigger_pass = TriggerDagRunOperator(
+ task_id=f"{self.task_id}-retrigger-pass", **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **pass_trigger_kwargs}
+ )
+
+ self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_, trigger_rule="all_success")
+ self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail_, trigger_rule="all_success")
- self._done_task = PythonOperator(task_id=f"{self.task_id}-done", python_callable=pass_, trigger_rule="all_success")
- self._running_task = PythonOperator(task_id=f"{self.task_id}-running", python_callable=pass_, trigger_rule="all_success")
- self._failed_task = PythonOperator(task_id=f"{self.task_id}-failed", python_callable=pass_, trigger_rule="all_success")
self._sensor_failed_task = PythonOperator(task_id=f"{self.task_id}-sensor-timeout", python_callable=pass_, trigger_rule="all_failed")
branch_choices = {
- "done": self._done_task.task_id,
- "running": self._running_task.task_id,
- "failed": self._failed_task.task_id,
- "": self._sensor_failed_task.task_id,
+ (Result.PASS, Action.RETRIGGER): self._retrigger_pass.task_id,
+ (Result.PASS, Action.STOP): self._stop_pass.task_id,
+ (Result.FAIL, Action.RETRIGGER): self._retrigger_fail.task_id,
+ (Result.FAIL, Action.STOP): self._stop_fail.task_id,
}
def _choose_branch(branch_choices=branch_choices, **kwargs):
task_instance = kwargs["task_instance"]
check_program_result = task_instance.xcom_pull(key="return_value", task_ids=self.task_id)
- ret = branch_choices.get(check_program_result, None)
+ try:
+ result = Result(check_program_result[0])
+ action = Action(check_program_result[1])
+ ret = branch_choices.get((result, action), None)
+ except (ValueError, IndexError, TypeError):
+ ret = None
if ret is None:
+ # skip result
raise AirflowSkipException
return ret
@@ -125,25 +135,28 @@ def _choose_branch(branch_choices=branch_choices, **kwargs):
trigger_rule="all_success",
)
- self >> self._sensor_failed_task >> self._loop_pass >> self._end_pass
- self >> self._decide_task >> self._done_task
- self >> self._decide_task >> self._running_task >> self._loop_pass >> self._end_pass
- self >> self._decide_task >> self._failed_task >> self._loop_fail >> self._end_fail
+ self >> self._decide_task >> self._stop_pass
+ self >> self._decide_task >> self._stop_fail
+ self >> self._decide_task >> self._retrigger_pass
+ self >> self._decide_task >> self._retrigger_fail >> self._fail
+ self >> self._sensor_failed_task >> self._retrigger_pass
@property
- def check(self) -> Operator:
- return self
+ def stop_fail(self) -> Operator:
+ return self._stop_fail
@property
- def failed(self) -> Operator:
- # NOTE: use loop_fail as this will pass, but self._end_fail will fail to mark the DAG failed
- return self._loop_fail
+ def stop_pass(self) -> Operator:
+ return self._stop_pass
@property
- def passed(self) -> Operator:
- # NOTE: use loop_pass here to match failed()
- return self._loop_pass
+ def retrigger_fail(self) -> Operator:
+ return self._retrigger_fail
@property
- def done(self) -> Operator:
- return self._done_task
+ def retrigger_pass(self) -> Operator:
+ return self._retrigger_pass
+
+
+class HighAvailabilityOperator(HighAvailabilityOperatorMixin, PythonSensor):
+ pass
diff --git a/docs/src/rec.png b/docs/src/rec.png
new file mode 100644
index 0000000..dd2555a
Binary files /dev/null and b/docs/src/rec.png differ
diff --git a/docs/src/top.png b/docs/src/top.png
index 1a359ff..614eb99 100644
Binary files a/docs/src/top.png and b/docs/src/top.png differ