Source code for airflow_ha.operator

from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple

from airflow.exceptions import AirflowFailException, AirflowSkipException
from airflow.models.operator import Operator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.python import PythonSensor

__all__ = (
    "HighAvailabilityOperator",
    "HighAvailabilityOperatorMixin",
    "Result",
    "Action",
    "CheckResult",
)


[docs] class Result(str, Enum): PASS = "pass" FAIL = "fail"
[docs] class Action(str, Enum): CONTINUE = "continue" RETRIGGER = "retrigger" STOP = "stop"
CheckResult = Tuple[Result, Action] def skip_(): raise AirflowSkipException def fail_(): raise AirflowFailException def pass_(): pass
[docs] class HighAvailabilityOperatorMixin: _decide_task: BranchPythonOperator _fail: Operator _retrigger_fail: Operator _retrigger_pass: Operator _stop_pass: Operator _stop_fail: Operator _sensor_failed_task: Operator
[docs] def __init__( self, python_callable: Callable[..., CheckResult], pass_trigger_kwargs: Optional[Dict[str, Any]] = None, fail_trigger_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: """The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks. It resembles a BranchPythonOperator with the following predefined set of outcomes: check -> decide -> PASS/RETRIGGER -> PASS/STOP -> FAIL/RETRIGGER -> FAIL/STOP -> */CONTINUE Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first). """ pass_trigger_kwargs = pass_trigger_kwargs or {} fail_trigger_kwargs = fail_trigger_kwargs or {} def _callable_wrapper(**kwargs): task_instance = kwargs["task_instance"] ret: CheckResult = python_callable(**kwargs) if not isinstance(ret, tuple) or not len(ret) == 2 or not isinstance(ret[0], Result) or not isinstance(ret[1], Action): # malformed task_instance.xcom_push(key="return_value", value=(Result.FAIL, Action.STOP)) return True # push to xcom task_instance.xcom_push(key="return_value", value=ret) if ret[1] == Action.CONTINUE: # keep checking return False return True super().__init__(python_callable=_callable_wrapper, **kwargs) # this is needed to ensure the dag fails, since the # retrigger_fail step will pass (to ensure dag retriggers!) self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail_, trigger_rule="all_success") self._retrigger_fail = TriggerDagRunOperator( task_id=f"{self.task_id}-retrigger-fail", **{"trigger_dag_id": self.dag_id, "trigger_rule": "all_success", **fail_trigger_kwargs} ) self._retrigger_pass = TriggerDagRunOperator( task_id=f"{self.task_id}-retrigger-pass", **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **pass_trigger_kwargs} ) self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_, trigger_rule="all_success") self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail_, trigger_rule="all_success") self._sensor_failed_task = PythonOperator(task_id=f"{self.task_id}-sensor-timeout", python_callable=pass_, trigger_rule="all_failed") branch_choices = { (Result.PASS, Action.RETRIGGER): self._retrigger_pass.task_id, (Result.PASS, Action.STOP): self._stop_pass.task_id, (Result.FAIL, Action.RETRIGGER): self._retrigger_fail.task_id, (Result.FAIL, Action.STOP): self._stop_fail.task_id, } def _choose_branch(branch_choices=branch_choices, **kwargs): task_instance = kwargs["task_instance"] check_program_result = task_instance.xcom_pull(key="return_value", task_ids=self.task_id) try: result = Result(check_program_result[0]) action = Action(check_program_result[1]) ret = branch_choices.get((result, action), None) except (ValueError, IndexError, TypeError): ret = None if ret is None: # skip result raise AirflowSkipException return ret self._decide_task = BranchPythonOperator( task_id=f"{self.task_id}-decide", python_callable=_choose_branch, provide_context=True, trigger_rule="all_success", ) self >> self._decide_task >> self._stop_pass self >> self._decide_task >> self._stop_fail self >> self._decide_task >> self._retrigger_pass self >> self._decide_task >> self._retrigger_fail >> self._fail self >> self._sensor_failed_task >> self._retrigger_pass
@property def stop_fail(self) -> Operator: return self._stop_fail @property def stop_pass(self) -> Operator: return self._stop_pass @property def retrigger_fail(self) -> Operator: return self._retrigger_fail @property def retrigger_pass(self) -> Operator: return self._retrigger_pass
[docs] class HighAvailabilityOperator(HighAvailabilityOperatorMixin, PythonSensor): pass