Source code for airflow_supervisor.airflow.ssh

from shlex import quote
from typing import Dict

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow_balancer import Host, Port
from supervisor_pydantic.convenience import SupervisorTaskStep

from airflow_supervisor.config import SupervisorSSHAirflowConfiguration

from .local import Supervisor

# Public API of this module: only the SSH-backed supervisor wrapper is exported.
__all__ = ("SupervisorSSH",)


class SupervisorSSH(Supervisor):
    """Supervisor wrapper whose lifecycle steps run on a remote host over SSH.

    Mimics the SSH operator interface:
    https://airflow.apache.org/docs/apache-airflow-providers-ssh/stable/_api/airflow/providers/ssh/operators/ssh/index.html

    The steps listed in ``_SSH_STEPS`` are executed with an ``SSHOperator``;
    every other step is delegated to :class:`Supervisor` (which, per the
    comments below, can use the XML-RPC API instead).
    """

    # SSHOperator keyword arguments accepted either as ``__init__`` kwargs or
    # as attributes of the configuration object.
    _SSH_OPERATOR_ATTRS = (
        "ssh_hook",
        "ssh_conn_id",
        "remote_host",
        "conn_timeout",
        "cmd_timeout",
        "environment",
        "get_pty",
        "banner_timeout",
        "skip_on_exit_code",
    )

    # Steps that must be executed on the remote host via SSHOperator.
    _SSH_STEPS = (
        "configure-supervisor",
        "start-supervisor",
        "stop-supervisor",
        "unconfigure-supervisor",
        "force-kill",
    )

    def __init__(
        self,
        dag: DAG,
        cfg: SupervisorSSHAirflowConfiguration,
        host: "Host" = None,
        port: "Port" = None,
        **kwargs,
    ):
        """Collect SSH settings, optionally bind to an airflow-balancer host/port.

        Args:
            dag: DAG forwarded to ``Supervisor.__init__``.
            cfg: Supervisor SSH configuration. NOTE: it is mutated in place.
                SSH options given explicitly in ``kwargs`` are written back onto
                it. ``cfg.convenience.host`` / ``cfg.convenience.port`` are
                overwritten when ``host`` / ``port`` are given, after which
                ``cfg._setup_convenience_defaults()`` is re-run.
            host: Optional airflow-balancer ``Host``. When truthy, its ``name``
                becomes ``remote_host`` and ``host.hook()`` becomes ``ssh_hook``.
                Both override any value from kwargs or cfg.
            port: Optional airflow-balancer ``Port``. When truthy,
                ``cfg.convenience.port`` is set to ``"*:<port.port>"``.
            **kwargs: ``command_prefix`` and any name in ``_SSH_OPERATOR_ATTRS``
                are consumed here. Everything else is passed to
                ``Supervisor.__init__``.
        """
        # Shell line emitted before every remote command.
        # Precedence: explicit kwarg, then a truthy cfg value, else "".
        if "command_prefix" in kwargs:
            self._command_prefix = kwargs.pop("command_prefix")
        elif cfg and getattr(cfg, "command_prefix"):
            self._command_prefix = getattr(cfg, "command_prefix")
        else:
            self._command_prefix = ""

        self._ssh_operator_kwargs = {}
        for attr in self._SSH_OPERATOR_ATTRS:
            if attr in kwargs:
                value = kwargs.pop(attr)
                self._ssh_operator_kwargs[attr] = value
                # Keep cfg in sync with the explicit override (presumably so the
                # configuration serialized for configure-supervisor agrees).
                setattr(cfg, attr, value)
            elif cfg and getattr(cfg, attr):
                # Falsy cfg values (None, 0, False, empty) are not forwarded,
                # so SSHOperator's own defaults apply for them.
                self._ssh_operator_kwargs[attr] = getattr(cfg, attr)

        # Integrate with airflow-balancer: an explicit host wins over any
        # remote_host / ssh_hook gathered above.
        if host:
            self._ssh_operator_kwargs["remote_host"] = host.name
            self._ssh_operator_kwargs["ssh_hook"] = host.hook()
            # Ensure host matches the configuration
            cfg.convenience.host = host.name
        if port:
            # Ensure port matches the configuration
            cfg.convenience.port = f"*:{port.port}"
        if host or port:
            # Revalidate derived convenience settings after the overrides.
            cfg._setup_convenience_defaults()

        super().__init__(dag=dag, cfg=cfg, **kwargs)

    def _remote_command(self, args: str) -> str:
        """Return the shell script sent over SSH.

        The script is the command prefix on its own line, followed by a
        ``_supervisor_convenience`` invocation with *args* (already shell-quoted).
        """
        return f"\n{self._command_prefix}\n_supervisor_convenience {args}\n"

    def get_step_kwargs(self, step: SupervisorTaskStep) -> Dict:
        """Return operator kwargs for *step*.

        For SSH steps this is the SSH options plus a ``command``. Any other step
        is delegated to ``Supervisor.get_step_kwargs``.
        """
        if step == "configure-supervisor":
            # TODO
            # The whole configuration is sent inline as JSON. shlex.quote is
            # required because string fields may contain single quotes, which
            # would otherwise terminate the quoted argument and corrupt (or
            # inject into) the remote shell command.
            return {
                **self._ssh_operator_kwargs,
                "command": self._remote_command(f"{step} {quote(self._cfg.model_dump_json())}"),
            }
        if step in ("start-supervisor", "stop-supervisor", "unconfigure-supervisor", "force-kill"):
            # must be done via SSH
            return {
                **self._ssh_operator_kwargs,
                "command": self._remote_command(f"{step} --cfg {quote(str(self._cfg._pydantic_path))}"),
            }
        # can be done via XMLRPC API
        return super().get_step_kwargs(step=step)

    def get_step_operator(self, step: SupervisorTaskStep) -> Operator:
        """Return the Airflow operator that performs *step*.

        SSH steps become an ``SSHOperator``. Other steps are delegated to
        ``Supervisor.get_step_operator``.
        """
        if step not in self._SSH_STEPS:
            # can be done via XMLRPC API
            return super().get_step_operator(step=step)
        # Merge order matters: base kwargs may override task_id, and
        # step-specific kwargs override both.
        return SSHOperator(
            **{
                "task_id": f"{self._dag.dag_id}-{step}",
                **self.get_base_operator_kwargs(),
                **self.get_step_kwargs(step),
            }
        )