diff --git a/README.md b/README.md index f5a045b..e03d6df 100644 --- a/README.md +++ b/README.md @@ -9,25 +9,27 @@ High Availability (HA) DAG Utility ## Overview -This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and does the following: +This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and runs a user-provided `python_callable`. +The return value can trigger the following actions: -- runs a user-provided `python_callable` as a sensor - - if this returns `"done"`, mark the DAG as passed and finish - - if this returns `"running"`, keep checking - - if this returns `"failed"`, mark the DAG as failed and re-run -- if the sensor times out, mark the DAG as passed and re-run +| Return | Result | Current DAGrun End State | +| :----- | :----- | :----------------------- | +| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same DAG to run again | `pass` | +| `(Result.PASS, Action.STOP)` | Finish the DAG, until its next scheduled run | `pass` | +| `(Result.FAIL, Action.RETRIGGER)` | Retrigger the same DAG to run again | `fail` | +| `(Result.FAIL, Action.STOP)` | Finish the DAG, until its next scheduled run | `fail` | +| `(*, Action.RETRIGGER)` | Continue to run the Sensor | N/A | +| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` | +| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` | +| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same dag to run again | `pass` | -Consider the following DAG: - -```python -from datetime import datetime, timedelta -from random import choice +Note: if the sensor times out, the behavior matches `(Result.PASS, Action.RETRIGGER)`. -from airflow import DAG -from airflow.operators.python import PythonOperator -from airflow_ha import HighAvailabilityOperator +### Example - Always On +Consider the following DAG: +```python with DAG( dag_id="test-high-availability", description="Test HA Operator", @@ -39,20 +41,32 @@ with DAG( task_id="ha", timeout=30, poke_interval=5, - python_callable=lambda **kwargs: choice(("done", "failed", "running", "")) + python_callable=lambda **kwargs: choice( + ( + (Result.PASS, Action.CONTINUE), + (Result.PASS, Action.RETRIGGER), + (Result.PASS, Action.STOP), + (Result.FAIL, Action.CONTINUE), + (Result.FAIL, Action.RETRIGGER), + (Result.FAIL, Action.STOP), + ) + ), ) pre = PythonOperator(task_id="pre", python_callable=lambda **kwargs: "test") pre >> ha - fail = PythonOperator(task_id="fail", python_callable=lambda **kwargs: "test") - ha.failed >> fail + retrigger_fail = PythonOperator(task_id="retrigger_fail", python_callable=lambda **kwargs: "test") + ha.retrigger_fail >> retrigger_fail + + stop_fail = PythonOperator(task_id="stop_fail", python_callable=lambda **kwargs: "test") + ha.stop_fail >> stop_fail - passed = PythonOperator(task_id="passed", python_callable=lambda **kwargs: "test") - ha.passed >> passed + retrigger_pass = PythonOperator(task_id="retrigger_pass", python_callable=lambda **kwargs: "test") + ha.retrigger_pass >> retrigger_pass - done = PythonOperator(task_id="done", python_callable=lambda **kwargs: "test") - ha.done >> done + stop_pass = PythonOperator(task_id="stop_pass", python_callable=lambda **kwargs: "test") + ha.stop_pass >> stop_pass ``` This produces a DAG with the following topology: @@ -60,12 +74,54 @@ This produces a DAG with the following topology: This DAG exhibits cool behavior. -If a check fails or the interval elapses, the DAG will re-trigger itself. -If the check passes, the DAG will finish. +If the check returns `CONTINUE`, the DAG will continue to run the sensor. +If the check returns `RETRIGGER` or the interval elapses, the DAG will re-trigger itself and finish. +If the check returns `STOP`, the DAG will finish and not retrigger itself. +If the check returns `PASS`, the current DAG run will end in a successful state. +If the check returns `FAIL`, the current DAG run will end in a failed state. + This allows the one to build "always-on" DAGs without having individual long blocking tasks. This library is used to build [airflow-supervisor](https://github.com/airflow-laminar/airflow-supervisor), which uses [supervisor](http://supervisord.org) as a process-monitor while checking and restarting jobs via `airflow-ha`. +### Example - Recursive + +You can also use this library to build recursive DAGs - or "Cyclic DAGs", despite the oxymoronic name. + +The following code makes a DAG that triggers itself with some decrementing counter, starting with value 3: + +```python + +with DAG( + dag_id="test-ha-counter", + description="Test HA Countdown", + schedule=timedelta(days=1), + start_date=datetime(2024, 1, 1), + catchup=False, +): + + def _get_count(**kwargs): + # The default is 3 + return kwargs['dag_run'].conf.get('counter', 3) - 1 + + get_count = PythonOperator(task_id="get-count", python_callable=_get_count) + + def _keep_counting(**kwargs): + count = kwargs["task_instance"].xcom_pull(key="return_value", task_ids="get-count") + return (Result.PASS, Action.RETRIGGER) if count > 0 else (Result.PASS, Action.STOP) if count == 0 else (Result.FAIL, Action.STOP) + + keep_counting = HighAvailabilityOperator( + task_id="ha", + timeout=30, + poke_interval=5, + python_callable=_keep_counting, + pass_trigger_kwargs={"conf": '''{"counter": {{ ti.xcom_pull(key="return_value", task_ids="get-count") }}}'''}, + ) + + get_count >> keep_counting +``` + + ## License This software is licensed under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details. diff --git a/airflow_ha/__init__.py b/airflow_ha/__init__.py index 0e37288..3b326f1 100644 --- a/airflow_ha/__init__.py +++ b/airflow_ha/__init__.py @@ -1,3 +1,3 @@ __version__ = "0.1.0" -from .operator import HighAvailabilityOperator +from .operator import * diff --git a/airflow_ha/operator.py b/airflow_ha/operator.py index 395d689..36067c3 100644 --- a/airflow_ha/operator.py +++ b/airflow_ha/operator.py @@ -1,4 +1,5 @@ -from typing import Literal +from enum import Enum +from typing import Any, Callable, Dict, Optional, Tuple from airflow.models.operator import Operator from airflow.exceptions import AirflowSkipException, AirflowFailException @@ -6,14 +7,26 @@ from airflow.operators.trigger_dagrun import TriggerDagRunOperator from airflow.sensors.python import PythonSensor -__all__ = ("HighAvailabilityOperator",) +__all__ = ( + "HighAvailabilityOperator", + "Result", + "Action", + "CheckResult", +) -CheckResult = Literal[ - "done", - "running", - "failed", -] +class Result(str, Enum): + PASS = "pass" + FAIL = "fail" + + +class Action(str, Enum): + CONTINUE = "continue" + RETRIGGER = "retrigger" + STOP = "stop" + + +CheckResult = Tuple[Result, Action] def skip_(): @@ -28,93 +41,90 @@ def pass_(): pass -class HighAvailabilityOperator(PythonSensor): +class HighAvailabilityOperatorMixin: _decide_task: BranchPythonOperator - - _end_fail: Operator - _end_pass: Operator - - _loop_pass: Operator - _loop_fail: Operator - - _done_task: Operator - _end_task: Operator - _running_task: Operator - _failed_task: Operator - _kill_task: Operator - - _cleanup_task: Operator - _loop_task: Operator - _restart_task: Operator - - def __init__(self, **kwargs) -> None: + _fail: Operator + _retrigger_fail: Operator + _retrigger_pass: Operator + _stop_pass: Operator + _stop_fail: Operator + _sensor_failed_task: Operator + + def __init__( + self, + python_callable: Callable[..., CheckResult], + pass_trigger_kwargs: Optional[Dict[str, Any]] = None, + fail_trigger_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: """The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks. It resembles a BranchPythonOperator with the following predefined set of outcomes: - /-> "done" -> Done -> EndPass - check -> decide -> "running" -> Loop -> EndPass - \-> "failed" -> Loop -> EndFail - \-------------> failed -> Loop -> EndPass - - Given a check, there are four outcomes: - - The tasks finished/exited cleanly, and thus the DAG should terminate cleanly - - The tasks finished/exited uncleanly, in which case the DAG should restart - - The tasks did not finish, but the end time has been reached anyway, so the DAG should terminate cleanly - - The tasks did not finish, but we've reached an interval and should loop and rerun the DAG - - The last case is particularly important when DAGs have a max run time, e.g. on AWS MWAA where DAGs - cannot run for longer than 12 hours at a time and so must be "restarted". - - Additionally, there is a "KillTask" to force kill the DAG. + check -> decide -> PASS/RETRIGGER + -> PASS/STOP + -> FAIL/RETRIGGER + -> FAIL/STOP + -> */CONTINUE Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first). """ - python_callable = kwargs.pop("python_callable") + pass_trigger_kwargs = pass_trigger_kwargs or {} + fail_trigger_kwargs = fail_trigger_kwargs or {} def _callable_wrapper(**kwargs): task_instance = kwargs["task_instance"] ret: CheckResult = python_callable(**kwargs) - if ret == "done": - task_instance.xcom_push(key="return_value", value="done") - # finish - return True - elif ret == "failed": - task_instance.xcom_push(key="return_value", value="failed") - # finish - return True - elif ret == "running": - task_instance.xcom_push(key="return_value", value="running") - # finish + + if not isinstance(ret, tuple) or not len(ret) == 2 or not isinstance(ret[0], Result) or not isinstance(ret[1], Action): + # malformed + task_instance.xcom_push(key="return_value", value=(Result.FAIL, Action.STOP)) return True - task_instance.xcom_push(key="return_value", value="") - return False + + # push to xcom + task_instance.xcom_push(key="return_value", value=ret) + + if ret[1] == Action.CONTINUE: + # keep checking + return False + return True super().__init__(python_callable=_callable_wrapper, **kwargs) - self._end_fail = PythonOperator(task_id=f"{self.task_id}-dag-fail", python_callable=fail_, trigger_rule="all_success") - self._end_pass = PythonOperator(task_id=f"{self.task_id}-dag-pass", python_callable=pass_, trigger_rule="all_success") + # this is needed to ensure the dag fails, since the + # retrigger_fail step will pass (to ensure dag retriggers!) + self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail_, trigger_rule="all_success") - self._loop_fail = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-fail", trigger_dag_id=self.dag_id, trigger_rule="all_success") - self._loop_pass = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-pass", trigger_dag_id=self.dag_id, trigger_rule="one_success") + self._retrigger_fail = TriggerDagRunOperator( + task_id=f"{self.task_id}-retrigger-fail", **{"trigger_dag_id": self.dag_id, "trigger_rule": "all_success", **fail_trigger_kwargs} + ) + self._retrigger_pass = TriggerDagRunOperator( + task_id=f"{self.task_id}-retrigger-pass", **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **pass_trigger_kwargs} + ) + + self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_, trigger_rule="all_success") + self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail_, trigger_rule="all_success") - self._done_task = PythonOperator(task_id=f"{self.task_id}-done", python_callable=pass_, trigger_rule="all_success") - self._running_task = PythonOperator(task_id=f"{self.task_id}-running", python_callable=pass_, trigger_rule="all_success") - self._failed_task = PythonOperator(task_id=f"{self.task_id}-failed", python_callable=pass_, trigger_rule="all_success") self._sensor_failed_task = PythonOperator(task_id=f"{self.task_id}-sensor-timeout", python_callable=pass_, trigger_rule="all_failed") branch_choices = { - "done": self._done_task.task_id, - "running": self._running_task.task_id, - "failed": self._failed_task.task_id, - "": self._sensor_failed_task.task_id, + (Result.PASS, Action.RETRIGGER): self._retrigger_pass.task_id, + (Result.PASS, Action.STOP): self._stop_pass.task_id, + (Result.FAIL, Action.RETRIGGER): self._retrigger_fail.task_id, + (Result.FAIL, Action.STOP): self._stop_fail.task_id, } def _choose_branch(branch_choices=branch_choices, **kwargs): task_instance = kwargs["task_instance"] check_program_result = task_instance.xcom_pull(key="return_value", task_ids=self.task_id) - ret = branch_choices.get(check_program_result, None) + try: + result = Result(check_program_result[0]) + action = Action(check_program_result[1]) + ret = branch_choices.get((result, action), None) + except (ValueError, IndexError, TypeError): + ret = None if ret is None: + # skip result raise AirflowSkipException return ret @@ -125,25 +135,28 @@ def _choose_branch(branch_choices=branch_choices, **kwargs): trigger_rule="all_success", ) - self >> self._sensor_failed_task >> self._loop_pass >> self._end_pass - self >> self._decide_task >> self._done_task - self >> self._decide_task >> self._running_task >> self._loop_pass >> self._end_pass - self >> self._decide_task >> self._failed_task >> self._loop_fail >> self._end_fail + self >> self._decide_task >> self._stop_pass + self >> self._decide_task >> self._stop_fail + self >> self._decide_task >> self._retrigger_pass + self >> self._decide_task >> self._retrigger_fail >> self._fail + self >> self._sensor_failed_task >> self._retrigger_pass @property - def check(self) -> Operator: - return self + def stop_fail(self) -> Operator: + return self._stop_fail @property - def failed(self) -> Operator: - # NOTE: use loop_fail as this will pass, but self._end_fail will fail to mark the DAG failed - return self._loop_fail + def stop_pass(self) -> Operator: + return self._stop_pass @property - def passed(self) -> Operator: - # NOTE: use loop_pass here to match failed() - return self._loop_pass + def retrigger_fail(self) -> Operator: + return self._retrigger_fail @property - def done(self) -> Operator: - return self._done_task + def retrigger_pass(self) -> Operator: + return self._retrigger_pass + + +class HighAvailabilityOperator(HighAvailabilityOperatorMixin, PythonSensor): + pass diff --git a/docs/src/rec.png b/docs/src/rec.png new file mode 100644 index 0000000..dd2555a Binary files /dev/null and b/docs/src/rec.png differ diff --git a/docs/src/top.png b/docs/src/top.png index 1a359ff..614eb99 100644 Binary files a/docs/src/top.png and b/docs/src/top.png differ