Source code for airflow_ha.operator

from datetime import UTC, datetime, time, timedelta
from logging import getLogger
from typing import Any, Callable, Dict, Literal, Optional, Union

from airflow.models.param import Param
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.python import PythonSensor
from airflow_common_operators import fail, pass_

from .common import Action, CheckResult, Result

__all__ = ("HighAvailabilityOperator",)
_log = getLogger(__name__)


[docs] class HighAvailabilityOperator(PythonSensor): _decide_task: BranchPythonOperator _fail: PythonOperator _retrigger_fail: TriggerDagRunOperator _retrigger_pass: TriggerDagRunOperator _stop_pass: PythonOperator _stop_fail: PythonOperator _pass_trigger_kwargs: Optional[Dict[str, Any]] = None _pass_trigger_kwargs_conf: str = "{}" _fail_trigger_kwargs: Optional[Dict[str, Any]] = None _fail_trigger_kwargs_conf: str = "{}" _check_end_conditions: Optional[Callable] = None _runtime: Optional[timedelta] = None _endtime: Optional[time] = None _maxretrigger: Optional[int] = None _reference_date: Optional[str] = None
[docs] def __init__( self, python_callable: Callable[..., CheckResult], pass_trigger_kwargs: Optional[Dict[str, Any]] = None, fail_trigger_kwargs: Optional[Dict[str, Any]] = None, *, runtime: Optional[Union[int, timedelta]] = None, endtime: Optional[Union[str, time]] = None, maxretrigger: Optional[int] = None, reference_date: Literal["start_date", "logical_date", "data_interval_end"] = "data_interval_end", **kwargs, ) -> None: """The HighAvailabilityOperator is an Airflow Meta-Operator for long-running or "always-on" tasks. It resembles a BranchPythonOperator with the following predefined set of outcomes: check -> decide -> PASS/RETRIGGER -> PASS/STOP -> FAIL/RETRIGGER -> FAIL/STOP -> */CONTINUE Any setup should be state-aware (e.g. don't just start a process, check if it is currently started first). """ # These options control the behavior of the sensor self._runtime = timedelta(seconds=runtime) if isinstance(runtime, int) else runtime self._endtime = time.fromisoformat(endtime) if isinstance(endtime, str) else endtime self._maxretrigger = maxretrigger or None self._reference_date = reference_date or "data_interval_end" # These are kwarsg to pass to the trigger operators self._pass_trigger_kwargs = pass_trigger_kwargs or {} self._fail_trigger_kwargs = fail_trigger_kwargs or {} self._pass_trigger_kwargs_conf = self._pass_trigger_kwargs.pop("conf", {}) self._fail_trigger_kwargs_conf = self._fail_trigger_kwargs.pop("conf", {}) # Function to check end conditions self._check_end_conditions = ( # noqa: E731 lambda task_id=kwargs.get("task_id"), runtime=self._runtime, endtime=self._endtime, maxretrigger=self._maxretrigger, reference_date=self._reference_date, **kwargs: _check_end_conditions( task_id=task_id, runtime=runtime, endtime=endtime, maxretrigger=maxretrigger, reference_date=reference_date, **kwargs, ) ) # Function to control the sensor callable_wrapper = lambda python_callable=python_callable, check_end_conditions=self._check_end_conditions, **kwargs: _callable_wrapper( # noqa: E731 python_callable=python_callable, check_end_conditions=check_end_conditions, **kwargs ) if not kwargs.get("trigger_rule"): kwargs["trigger_rule"] = "none_failed" # Initialize the sensor super().__init__(python_callable=callable_wrapper, **kwargs) # Add params to dag to control overrides self.dag.params.update( { f"{self.task_id}-force-run": Param(False, "Ignore runtime/endtime/maxretrigger and force run the task", type="boolean"), f"{self.task_id}-force-runtime": Param( None, "Override `runtime` in seconds, incompatible with `force-run`", type=["null", "integer"] ), f"{self.task_id}-force-endtime": Param( None, "Override `endtime`, incompatible with `force-run`", type=["null", "string"], format="time" ), f"{self.task_id}-force-maxretrigger": Param(None, "Override `maxretrigger`, incompatible with `force-run`", type=["null", "integer"]), } ) # this is needed to ensure the dag fails, since the # retrigger_fail step will pass (to ensure dag retriggers!) self._fail = PythonOperator(task_id=f"{self.task_id}-force-dag-fail", python_callable=fail) self._stop_pass = PythonOperator(task_id=f"{self.task_id}-stop-pass", python_callable=pass_) self._stop_fail = PythonOperator(task_id=f"{self.task_id}-stop-fail", python_callable=fail) # Update the retrigger counts in trigger kwargs retrigger_count_conf = f'''{{{{ (ti.dag_run.conf.get("{self.task_id}-retrigger", 0)|int) + 1 }}}}''' referencedate_conf = f'''{{{{ ti.dag_run.conf.get("{self.task_id}-referencedate", ti.dag_run.start_date.isoformat()) }}}}''' if isinstance(self._pass_trigger_kwargs_conf, dict): self._pass_trigger_kwargs_conf[f"{self.task_id}-retrigger"] = retrigger_count_conf self._pass_trigger_kwargs_conf[f"{self.task_id}-referencedate"] = referencedate_conf else: if not isinstance(self._pass_trigger_kwargs_conf, str) or not self._pass_trigger_kwargs_conf.strip().endswith("}"): raise ValueError("pass_trigger_kwargs must be a dict or a JSON string") self._pass_trigger_kwargs_conf = self._pass_trigger_kwargs_conf.strip()[:-1] + f', "{self.task_id}-retrigger": {retrigger_count_conf} }}' self._pass_trigger_kwargs_conf = ( self._pass_trigger_kwargs_conf.strip()[:-1] + f', "{self.task_id}-referencedate": "{referencedate_conf}" }}' ) if isinstance(self._fail_trigger_kwargs_conf, dict): self._fail_trigger_kwargs_conf[f"{self.task_id}-retrigger"] = retrigger_count_conf self._fail_trigger_kwargs_conf[f"{self.task_id}-referencedate"] = referencedate_conf else: if not isinstance(self._fail_trigger_kwargs_conf, str) or not self._fail_trigger_kwargs_conf.strip().endswith("}"): raise ValueError("fail_trigger_kwargs must be a dict or a JSON string") self._fail_trigger_kwargs_conf = self._fail_trigger_kwargs_conf.strip()[:-1] + f', "{self.task_id}-retrigger": {retrigger_count_conf} }}' self._fail_trigger_kwargs_conf = ( self._fail_trigger_kwargs_conf.strip()[:-1] + f', "{self.task_id}-referencedate": "{referencedate_conf}" }}' ) # Create the retrigger pass/fail operators self._retrigger_fail = TriggerDagRunOperator( task_id=f"{self.task_id}-retrigger-fail", conf=self._fail_trigger_kwargs_conf, **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **self._fail_trigger_kwargs}, ) self._retrigger_pass = TriggerDagRunOperator( task_id=f"{self.task_id}-retrigger-pass", conf=self._pass_trigger_kwargs_conf, **{"trigger_dag_id": self.dag_id, "trigger_rule": "one_success", **self._pass_trigger_kwargs}, ) # Create the branch operator branch_choices = { (Result.PASS, Action.RETRIGGER): self._retrigger_pass.task_id, (Result.PASS, Action.STOP): self._stop_pass.task_id, (Result.FAIL, Action.RETRIGGER): self._retrigger_fail.task_id, (Result.FAIL, Action.STOP): self._stop_fail.task_id, } choose_branch = lambda task_id=self.task_id, branch_choices=branch_choices, **kwargs: _choose_branch( # noqa: E731 task_id=task_id, branch_choices=branch_choices, check_end_conditions=self._check_end_conditions, **kwargs ) self._decide_task = BranchPythonOperator( task_id=f"{self.task_id}-decide", python_callable=choose_branch, provide_context=True, # NOTE: use none_skipped here as the sensor will fail in a timeout trigger_rule="none_skipped", ) self >> self._decide_task self._decide_task >> self._stop_pass self._decide_task >> self._stop_fail self._decide_task >> self._retrigger_pass self._decide_task >> self._retrigger_fail >> self._fail
@property def decide_task(self) -> PythonOperator: return self._decide_task @property def stop_fail(self) -> PythonOperator: return self._stop_fail @property def stop_pass(self) -> PythonOperator: return self._stop_pass @property def retrigger_fail(self) -> TriggerDagRunOperator: return self._retrigger_fail @property def retrigger_pass(self) -> TriggerDagRunOperator: return self._retrigger_pass @property def check_end_conditions(self) -> Callable: return self._check_end_conditions
# Function to check end conditions def _check_end_conditions(task_id, runtime, endtime, maxretrigger, reference_date, **kwargs): # Check if force run in dag run kwargs force_run_conf = kwargs["dag_run"].conf.get("airflow_ha_force_run", False) force_run_param = kwargs["params"].get(f"{task_id}-force-run", False) _log.info(f"airflow-ha configuration -- force_run (conf): {force_run_conf}, force_run (param): {force_run_param}") # Grab the dag start date dag_reference_date = kwargs["dag_run"].conf.get(f"{task_id}-referencedate") if dag_reference_date is not None: dag_reference_date = datetime.fromisoformat(dag_reference_date) else: dag_reference_date = getattr(kwargs["dag_run"], reference_date) if not force_run_conf and not force_run_param: runtime = kwargs["params"].get(f"{task_id}-force-runtime", None) or runtime endtime = kwargs["params"].get(f"{task_id}-force-endtime", None) or endtime # Check if runtime has exceeded # NOTE: start date will always be normalize to UTC by airflow elapsed_time = (datetime.now(tz=UTC) - dag_reference_date).total_seconds() _log.info( f"airflow-ha configuration -- runtime: {runtime}, reference_date({reference_date}): {dag_reference_date}, elapsed_time: {elapsed_time}" ) if runtime is not None and elapsed_time > runtime.total_seconds(): # Runtime has exceeded, end _log.info(f"runtime exceeded for {task_id}, stopping") return Result.PASS, Action.STOP # Check if endtime has passed if endtime is not None: # NOTE: we normalize to the DAG's timezone # and use the provided endtime for convenience dag_timezone = kwargs.get("dag").timezone endtime_as_datetime = datetime.combine( date=dag_reference_date.astimezone(dag_timezone).date(), time=endtime, tzinfo=dag_timezone, ).astimezone(UTC) _log.info( f"airflow-ha configuration -- endtime: {endtime}, endtime_as_datetime: {endtime_as_datetime}, datetime.now(tz=UTC): {datetime.now(tz=UTC)}" ) if endtime_as_datetime <= datetime.now(tz=UTC): # Endtime has passed, end _log.info(f"endtime passed for {task_id}, stopping") return Result.PASS, Action.STOP # Handle retrigger couts maxretrigger = kwargs["params"].get(f"{task_id}-force-maxretrigger", None) or maxretrigger or -1 current_retrigger = int(kwargs["dag_run"].conf.get(f"{task_id}-retrigger", 0)) _log.info(f"airflow-ha configuration -- current_retrigger: {current_retrigger}, maxretrigger: {maxretrigger}") # Check if maxretrigger has exceeded if maxretrigger > -1 and current_retrigger >= maxretrigger: # maxretrigger has exceeded, end _log.info(f"maxretrigger exceeded for {task_id}: {maxretrigger} / {current_retrigger}, stopping") return None, Action.STOP # Function to control the sensor def _callable_wrapper(python_callable, check_end_conditions, **kwargs): task_instance = kwargs["task_instance"] ret: CheckResult = python_callable(**kwargs) if not isinstance(ret, tuple) or not len(ret) == 2 or not isinstance(ret[0], Result) or not isinstance(ret[1], Action): # malformed task_instance.xcom_push(key="return_value", value=(Result.FAIL, Action.STOP)) return True # push to xcom task_instance.xcom_push(key="return_value", value=ret) end_conditions = check_end_conditions(**kwargs) if end_conditions is not None and end_conditions[1] == Action.STOP: # end conditions met _log.info(f"End conditions met, stopping: {end_conditions}") return True if ret[1] == Action.CONTINUE: # keep checking return False return True def _choose_branch(branch_choices, task_id, check_end_conditions, **kwargs): # Grab the task instance task_instance = kwargs["task_instance"] # Check if force run in dag run kwargs end_conditions = check_end_conditions(task_id=task_id, **kwargs) if end_conditions is not None and end_conditions[0] is not None: return branch_choices[end_conditions] if end_conditions is not None: retrigger_exceeded = True else: retrigger_exceeded = False # Otherwise, continue to evaluate check_program_result = task_instance.xcom_pull(key="return_value", task_ids=task_id) try: result = Result(check_program_result[0]) action = Action(check_program_result[1]) ret = branch_choices.get((result, action), branch_choices[(Result.PASS, Action.RETRIGGER)]) if retrigger_exceeded: ret[1] = Action.STOP _log.info(f"Sensor returned {result.name}, {action.name}, branching to {ret}") except (ValueError, IndexError, TypeError): # Sensor has failed, retrigger _log.warning("Sensor failed, pass/retrigger") ret = branch_choices[(Result.PASS, Action.RETRIGGER if not retrigger_exceeded else Action.STOP)] return ret