

Merge pull request #7 from airflow-laminar/tkp/refactor
refactor results, move to mixin, add recursive
timkpaine authored Aug 27, 2024
2 parents ead2f9d + 443ca27 commit ecea856
Showing 5 changed files with 174 additions and 105 deletions.
102 changes: 79 additions & 23 deletions README.md
@@ -9,25 +9,27 @@ High Availability (HA) DAG Utility

## Overview

This library provides an operator called `HighAvailabilityOperator`, which inherits from `PythonSensor` and runs a user-provided `python_callable`.
The return value can trigger the following actions:

| Return Value | Behavior | Current DAG Run End State |
| :----------- | :------- | :------------------------ |
| `(Result.PASS, Action.RETRIGGER)` | Retrigger the same DAG to run again | `pass` |
| `(Result.PASS, Action.STOP)` | Finish the DAG until its next scheduled run | `pass` |
| `(Result.FAIL, Action.RETRIGGER)` | Retrigger the same DAG to run again | `fail` |
| `(Result.FAIL, Action.STOP)` | Finish the DAG until its next scheduled run | `fail` |
| `(*, Action.CONTINUE)` | Continue to run the sensor | N/A |

Note: if the sensor times out, the behavior matches `(Result.PASS, Action.RETRIGGER)`.

### Example - Always On

Consider the following DAG:

```python
from datetime import datetime, timedelta
from random import choice

from airflow import DAG
from airflow.operators.python import PythonOperator

from airflow_ha import Action, HighAvailabilityOperator, Result

with DAG(
    dag_id="test-high-availability",
    description="Test HA Operator",
    # ... (schedule, start_date, catchup omitted here) ...
):
    ha = HighAvailabilityOperator(
        task_id="ha",
        timeout=30,
        poke_interval=5,
        python_callable=lambda **kwargs: choice(
            (
                (Result.PASS, Action.CONTINUE),
                (Result.PASS, Action.RETRIGGER),
                (Result.PASS, Action.STOP),
                (Result.FAIL, Action.CONTINUE),
                (Result.FAIL, Action.RETRIGGER),
                (Result.FAIL, Action.STOP),
            )
        ),
    )

    pre = PythonOperator(task_id="pre", python_callable=lambda **kwargs: "test")
    pre >> ha

    retrigger_fail = PythonOperator(task_id="retrigger_fail", python_callable=lambda **kwargs: "test")
    ha.retrigger_fail >> retrigger_fail

    stop_fail = PythonOperator(task_id="stop_fail", python_callable=lambda **kwargs: "test")
    ha.stop_fail >> stop_fail

    retrigger_pass = PythonOperator(task_id="retrigger_pass", python_callable=lambda **kwargs: "test")
    ha.retrigger_pass >> retrigger_pass

    stop_pass = PythonOperator(task_id="stop_pass", python_callable=lambda **kwargs: "test")
    ha.stop_pass >> stop_pass
```

This produces a DAG with the following topology:

<img src="https://raw.githubusercontent.com/airflow-laminar/airflow-ha/main/docs/src/top.png" />

This DAG exhibits the following behavior:

- If the check returns `CONTINUE`, the DAG will continue to run the sensor.
- If the check returns `RETRIGGER`, or the sensor interval elapses, the DAG will re-trigger itself and finish.
- If the check returns `STOP`, the DAG will finish and not re-trigger itself.
- If the check returns `PASS`, the current DAG run will end in a successful state.
- If the check returns `FAIL`, the current DAG run will end in a failed state.

This allows one to build "always-on" DAGs without individual long-blocking tasks.
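
In practice the check callable is usually state-aware rather than random. The sketch below is illustrative only and not part of this repository: `service_is_healthy` and the 22:00 UTC cutoff are hypothetical stand-ins for whatever condition and stop criterion your DAG uses.

```python
from datetime import datetime, timezone

from airflow_ha import Action, Result


def service_is_healthy() -> bool:
    """Hypothetical helper: stand-in for whatever condition the DAG monitors."""
    return True  # placeholder


def check(**kwargs):
    if not service_is_healthy():
        # The monitored service died: mark this run failed and immediately start a fresh one
        return (Result.FAIL, Action.RETRIGGER)
    if datetime.now(timezone.utc).hour >= 22:
        # Past a (hypothetical) daily cutoff: finish cleanly until the next scheduled run
        return (Result.PASS, Action.STOP)
    # Otherwise keep the sensor polling
    return (Result.PASS, Action.CONTINUE)
```

A callable like this is passed as `python_callable` to `HighAvailabilityOperator`, exactly as in the example above.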

This library is used to build [airflow-supervisor](https://github.com/airflow-laminar/airflow-supervisor), which uses [supervisor](http://supervisord.org) as a process-monitor while checking and restarting jobs via `airflow-ha`.

### Example - Recursive

You can also use this library to build recursive DAGs, or "cyclic DAGs", despite the oxymoronic name.

The following code makes a DAG that re-triggers itself with a decrementing counter, starting from a default value of 3; a defensive variant of the check callable is sketched after the example:

```python
with DAG(
    dag_id="test-ha-counter",
    description="Test HA Countdown",
    schedule=timedelta(days=1),
    start_date=datetime(2024, 1, 1),
    catchup=False,
):

    def _get_count(**kwargs):
        # Read the counter from the triggering run's conf and decrement it (the default is 3)
        return kwargs["dag_run"].conf.get("counter", 3) - 1

    get_count = PythonOperator(task_id="get-count", python_callable=_get_count)

    def _keep_counting(**kwargs):
        count = kwargs["task_instance"].xcom_pull(key="return_value", task_ids="get-count")
        return (Result.PASS, Action.RETRIGGER) if count > 0 else (Result.PASS, Action.STOP) if count == 0 else (Result.FAIL, Action.STOP)

    keep_counting = HighAvailabilityOperator(
        task_id="ha",
        timeout=30,
        poke_interval=5,
        python_callable=_keep_counting,
        # The conf passed to the re-triggered run is a Jinja template, rendered at trigger time
        pass_trigger_kwargs={"conf": '''{"counter": {{ ti.xcom_pull(key="return_value", task_ids="get-count") }}}'''},
    )

    get_count >> keep_counting
```
<img src="https://raw.githubusercontent.com/airflow-laminar/airflow-ha/main/docs/src/rec.png" />
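
The `fail_trigger_kwargs` argument works the same way for the `(Result.FAIL, Action.RETRIGGER)` branch. For illustration only (this helper is not part of the repository), a defensive variant of `_keep_counting` can stop the loop rather than retrigger when the counter comes through missing or malformed:

```python
from airflow_ha import Action, Result


def _keep_counting_defensively(**kwargs):
    # Illustrative only: stop (and fail) this run if the counter went missing or
    # came through malformed, rather than retriggering forever.
    count = kwargs["task_instance"].xcom_pull(key="return_value", task_ids="get-count")
    if not isinstance(count, int):
        return (Result.FAIL, Action.STOP)
    if count > 0:
        return (Result.PASS, Action.RETRIGGER)
    return (Result.PASS, Action.STOP)
```

Swap it in via `python_callable=_keep_counting_defensively`, leaving the rest of the DAG unchanged.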

## License

This software is licensed under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details.
2 changes: 1 addition & 1 deletion airflow_ha/__init__.py
@@ -1,3 +1,3 @@
__version__ = "0.1.0"

from .operator import HighAvailabilityOperator
from .operator import *
175 changes: 94 additions & 81 deletions airflow_ha/operator.py
@@ -1,19 +1,32 @@
from typing import Literal
from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple

from airflow.models.operator import Operator
from airflow.exceptions import AirflowSkipException, AirflowFailException
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.python import PythonSensor

__all__ = ("HighAvailabilityOperator",)
__all__ = (
"HighAvailabilityOperator",
"Result",
"Action",
"CheckResult",
)


CheckResult = Literal[
"done",
"running",
"failed",
]
class Result(str, Enum):
PASS = "pass"
FAIL = "fail"


class Action(str, Enum):
CONTINUE = "continue"
RETRIGGER = "retrigger"
STOP = "stop"


CheckResult = Tuple[Result, Action]


def skip_():
@@ -28,93 +41,90 @@ def pass_():
pass


class HighAvailabilityOperator(PythonSensor):
class HighAvailabilityOperatorMixin:
_decide_task: BranchPythonOperator

_end_fail: Operator
_end_pass: Operator

_loop_pass: Operator
_loop_fail: Operator

_done_task: Operator
_end_task: Operator
_running_task: Operator
_failed_task: Operator
_kill_task: Operator

_cleanup_task: Operator
_loop_task: Operator
_restart_task: Operator

def __init__(self, **kwargs) -> None:
_fail: Operator
_retrigger_fail: Operator
_retrigger_pass: Operator
_stop_pass: Operator
_stop_fail: Operator
_sensor_failed_task: Operator

def __init__(
self,
python_callable: Callable[..., CheckResult],
pass_trigger_kwargs: Optional[Dict[str, Any]] = None,
fail_trigger_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
"""The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks.
It resembles a BranchPythonOperator with the following predefined set of outcomes:
/-> "done" -> Done -> EndPass
check -> decide -> "running" -> Loop -> EndPass
\-> "failed" -> Loop -> EndFail
\-------------> failed -> Loop -> EndPass
Given a check, there are four outcomes:
- The tasks finished/exited cleanly, and thus the DAG should terminate cleanly
- The tasks finished/exited uncleanly, in which case the DAG should restart
- The tasks did not finish, but the end time has been reached anyway, so the DAG should terminate cleanly
- The tasks did not finish, but we've reached an interval and should loop and rerun the DAG
The last case is particularly important when DAGs have a max run time, e.g. on AWS MWAA where DAGs
cannot run for longer than 12 hours at a time and so must be "restarted".
Additionally, there is a "KillTask" to force kill the DAG.
check -> decide -> PASS/RETRIGGER
                -> PASS/STOP
                -> FAIL/RETRIGGER
                -> FAIL/STOP
                -> */CONTINUE
Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first).
"""
python_callable = kwargs.pop("python_callable")
pass_trigger_kwargs = pass_trigger_kwargs or {}
fail_trigger_kwargs = fail_trigger_kwargs or {}

def _callable_wrapper(**kwargs):
task_instance = kwargs["task_instance"]
ret: CheckResult = python_callable(**kwargs)
if ret == "done":
task_instance.xcom_push(key="return_value", value="done")
# finish
return True
elif ret == "failed":
task_instance.xcom_push(key="return_value", value="failed")
# finish
return True
elif ret == "running":
task_instance.xcom_push(key="return_value", value="running")
# finish

if not isinstance(ret, tuple) or not len(ret) == 2 or not isinstance(ret[0], Result) or not isinstance(ret[1], Action):
# malformed
task_instance.xcom_push(key="return_value", value=(Result.FAIL, Action.STOP))
return True
task_instance.xcom_push(key="return_value", value="")
return False

# push to xcom
task_instance.xcom_push(key="return_value", value=ret)

if ret[1] == Action.CONTINUE:
# keep checking
return False
return True

super().__init__(python_callable=_callable_wrapper, **kwargs)

self._end_fail = PythonOperator(task_id=f"{self.task_id}-dag-fail", python_callable=fail_, trigger_rule="all_success")
self._end_pass = PythonOperator(task_id=f"{self.task_id}-dag-pass", python_callable=pass_, trigger_rule="all_success")
# this is needed to ensure the dag fails, since the
# retrigger_fail step will pass (to ensure dag retriggers!)
self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail_, trigger_rule="all_success")

self._loop_fail = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-fail", trigger_dag_id=self.dag_id, trigger_rule="all_success")
self._loop_pass = TriggerDagRunOperator(task_id=f"{self.task_id}-loop-pass", trigger_dag_id=self.dag_id, trigger_rule="one_success")
self._retrigger_fail = TriggerDagRunOperator(
task_id=f"{self.task_id}-retrigger-fail", **{"trigger_dag_id": self.dag_id, "trigger_rule": "all_success", **fail_trigger_kwargs}
)
self._retrigger_pass = TriggerDagRunOperator(
task_id=f"{self.task_id}-retrigger-pass", **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **pass_trigger_kwargs}
)

self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_, trigger_rule="all_success")
self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail_, trigger_rule="all_success")

self._done_task = PythonOperator(task_id=f"{self.task_id}-done", python_callable=pass_, trigger_rule="all_success")
self._running_task = PythonOperator(task_id=f"{self.task_id}-running", python_callable=pass_, trigger_rule="all_success")
self._failed_task = PythonOperator(task_id=f"{self.task_id}-failed", python_callable=pass_, trigger_rule="all_success")
self._sensor_failed_task = PythonOperator(task_id=f"{self.task_id}-sensor-timeout", python_callable=pass_, trigger_rule="all_failed")

branch_choices = {
"done": self._done_task.task_id,
"running": self._running_task.task_id,
"failed": self._failed_task.task_id,
"": self._sensor_failed_task.task_id,
(Result.PASS, Action.RETRIGGER): self._retrigger_pass.task_id,
(Result.PASS, Action.STOP): self._stop_pass.task_id,
(Result.FAIL, Action.RETRIGGER): self._retrigger_fail.task_id,
(Result.FAIL, Action.STOP): self._stop_fail.task_id,
}

def _choose_branch(branch_choices=branch_choices, **kwargs):
task_instance = kwargs["task_instance"]
check_program_result = task_instance.xcom_pull(key="return_value", task_ids=self.task_id)
ret = branch_choices.get(check_program_result, None)
try:
result = Result(check_program_result[0])
action = Action(check_program_result[1])
ret = branch_choices.get((result, action), None)
except (ValueError, IndexError, TypeError):
ret = None
if ret is None:
# skip result
raise AirflowSkipException
return ret

@@ -125,25 +135,28 @@ def _choose_branch(branch_choices=branch_choices, **kwargs):
trigger_rule="all_success",
)

self >> self._sensor_failed_task >> self._loop_pass >> self._end_pass
self >> self._decide_task >> self._done_task
self >> self._decide_task >> self._running_task >> self._loop_pass >> self._end_pass
self >> self._decide_task >> self._failed_task >> self._loop_fail >> self._end_fail
self >> self._decide_task >> self._stop_pass
self >> self._decide_task >> self._stop_fail
self >> self._decide_task >> self._retrigger_pass
self >> self._decide_task >> self._retrigger_fail >> self._fail
self >> self._sensor_failed_task >> self._retrigger_pass

@property
def check(self) -> Operator:
return self
def stop_fail(self) -> Operator:
return self._stop_fail

@property
def failed(self) -> Operator:
# NOTE: use loop_fail as this will pass, but self._end_fail will fail to mark the DAG failed
return self._loop_fail
def stop_pass(self) -> Operator:
return self._stop_pass

@property
def passed(self) -> Operator:
# NOTE: use loop_pass here to match failed()
return self._loop_pass
def retrigger_fail(self) -> Operator:
return self._retrigger_fail

@property
def done(self) -> Operator:
return self._done_task
def retrigger_pass(self) -> Operator:
return self._retrigger_pass


class HighAvailabilityOperator(HighAvailabilityOperatorMixin, PythonSensor):
pass
Binary file added docs/src/rec.png
Binary file modified docs/src/top.png
