Source code for airflow_supervisor.airflow.ssh
from logging import getLogger
from shlex import quote
from typing import Dict
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow_balancer import Host, Port
from airflow_pydantic import SSHOperatorArgs
from supervisor_pydantic.convenience import SupervisorTaskStep
from airflow_supervisor.config import SupervisorSSHAirflowConfiguration
from .local import Supervisor
__all__ = ("SupervisorSSH",)
_log = getLogger(__name__)
[docs]
class SupervisorSSH(Supervisor):
    """Supervisor variant whose process-management steps are issued over SSH.

    The ``configure-supervisor``, ``start-supervisor``, ``stop-supervisor``,
    ``unconfigure-supervisor`` and ``force-kill`` steps are built as
    ``SSHOperator`` tasks that invoke ``_supervisor_convenience`` remotely.
    Every other step is delegated unchanged to the parent class.
    """

    # Mimic SSH Operator: https://airflow.apache.org/docs/apache-airflow-providers-ssh/stable/_api/airflow/providers/ssh/operators/ssh/index.html

    def __init__(
        self,
        dag: DAG,
        cfg: SupervisorSSHAirflowConfiguration,
        host: "Host" = None,
        port: "Port" = None,
        **kwargs,
    ):
        """Collect SSH settings, reconcile them with ``cfg`` and initialise the parent.

        Args:
            dag: DAG forwarded to the parent initialiser.
            cfg: SSH supervisor configuration. It is mutated in place: SSH
                operator overrides, host, port and pool may all be written
                back onto it.
            host: Optional airflow-balancer host. When truthy, its ``name`` and
                ``hook()`` replace ``remote_host`` / ``ssh_hook``, and its pool
                is used if ``cfg.airflow.pool`` is not already set.
            port: Optional airflow-balancer port. When truthy,
                ``cfg.convenience.port`` becomes ``"*:<port>"``.
            **kwargs: ``command_prefix`` and any key matching a field of
                ``cfg.ssh_operator_args`` are consumed here; whatever remains
                is forwarded to the parent initialiser.
        """
        # Resolution order for each setting: explicit kwarg, then the value on
        # cfg, then an empty string. The result is stored as ``self._<attr>``.
        for attr in ("command_prefix",):
            if attr in kwargs:
                _log.info(f"Setting {attr} to {kwargs.get(attr)}")
                setattr(self, f"_{attr}", kwargs.pop(attr))
            elif cfg and getattr(cfg, attr):
                _log.info(f"Setting {attr} to {getattr(cfg, attr)}")
                setattr(self, f"_{attr}", getattr(cfg, attr))
            else:
                _log.info(f"Setting {attr} to empty string")
                setattr(self, f"_{attr}", "")

        # Start from the configured SSH operator arguments, omitting unset ones.
        self._ssh_operator_kwargs = cfg.ssh_operator_args.model_dump(exclude_none=True)

        # Use whichever field registry the ssh_operator_args object exposes.
        pydantic_fields = (
            cfg.ssh_operator_args.__pydantic_fields__.keys()
            if hasattr(cfg.ssh_operator_args, "__pydantic_fields__")
            else cfg.ssh_operator_args.__fields__
        )
        # Pull SSH operator overrides out of kwargs so they are not forwarded
        # to the parent initialiser.
        for attr in pydantic_fields:
            if attr in kwargs:
                _log.info(f"Setting {attr} to {kwargs.get(attr)}")
                self._ssh_operator_kwargs[attr] = kwargs.pop(attr)
                # NOTE(review): this writes the override onto ``cfg`` itself,
                # not onto ``cfg.ssh_operator_args`` - confirm this is intended.
                setattr(cfg, attr, self._ssh_operator_kwargs[attr])

        # Integrate with airflow-balancer, use host if provided
        if host:
            _log.info(f"Setting host to {host.name}")
            self._ssh_operator_kwargs["remote_host"] = host.name
            self._ssh_operator_kwargs["ssh_hook"] = host.hook()
            cfg.ssh_operator_args = SSHOperatorArgs(**self._ssh_operator_kwargs)

            # Ensure host matches the configuration
            cfg.convenience.host = host.name

            # Extract pool if available
            if host.pool and not cfg.airflow.pool:
                _log.info(f"Setting airflow pool to {host.pool}")
                cfg.airflow.pool = host.pool

        if port:
            _log.info(f"Setting port to {port.port}")
            # Ensure port matches the configuration
            cfg.convenience.port = f"*:{port.port}"

        if host or port:
            # revalidate
            _log.info("Revalidating configuration")
            cfg._setup_convenience_defaults()

        super().__init__(dag=dag, cfg=cfg, **kwargs)

    def get_step_kwargs(self, step: SupervisorTaskStep) -> Dict:
        """Return the operator keyword arguments for ``step``.

        For the SSH-driven steps the result is the stored SSH operator
        arguments plus a ``command`` made of the command prefix followed by a
        ``_supervisor_convenience`` invocation. Other steps are delegated to
        the parent class.
        """
        if step == "configure-supervisor":
            # TODO
            # The whole configuration is passed inline as JSON. shlex.quote
            # yields the same '<json>' form as hand-written single quotes when
            # the payload has no single quote, and escapes it safely otherwise.
            argument = quote(self._cfg.model_dump_json())
        elif step in ("start-supervisor", "stop-supervisor", "unconfigure-supervisor", "force-kill"):
            # must be done via SSH
            argument = f"--cfg {quote(str(self._cfg._pydantic_path))}"
        else:
            # can be done via XMLRPC API
            return super().get_step_kwargs(step=step)

        # Layout: blank line, prefix line, command line, trailing newline.
        command = f"\n{self._command_prefix}\n_supervisor_convenience {step} {argument}\n"
        return {
            **self._ssh_operator_kwargs,
            "command": command,
        }

    def get_step_operator(self, step: SupervisorTaskStep) -> Operator:
        """Return the operator for ``step``.

        SSH-driven steps become an ``SSHOperator`` with task id
        ``"<dag_id>-<step>"``. Later entries in the merged dict win, so the
        step kwargs override the base operator kwargs, which in turn override
        the task id. Other steps are delegated to the parent class.
        """
        if step in (
            "configure-supervisor",
            "start-supervisor",
            "stop-supervisor",
            "unconfigure-supervisor",
            "force-kill",
        ):
            return SSHOperator(
                **{
                    "task_id": f"{self._dag.dag_id}-{step}",
                    **self.get_base_operator_kwargs(),
                    **self.get_step_kwargs(step),
                }
            )
        else:
            # can be done via XMLRPC API
            return super().get_step_operator(step=step)