Source code for airflow_pydantic.extras.balancer.balancer

from logging import getLogger
from pathlib import Path
from random import choice
from typing import Callable, List, Optional, Union

from pydantic import Field, model_validator

try:
    from sqlalchemy.exc import OperationalError
except ImportError:
    OperationalError = RuntimeError

from typing_extensions import Self

from ...core import BaseModel
from ...utils import Pool, Variable
from .host import Host
from .port import Port

__all__ = ("BalancerConfiguration", "load_config")

_log = getLogger(__name__)


[docs] class BalancerConfiguration(BaseModel): hosts: List[Host] = Field(default_factory=list) ports: List[Port] = Field(default_factory=list) default_username: str = "airflow" # Password default_password: Optional[Union[str, Variable]] = None # Or get key file default_key_file: Optional[str] = None # The queue that might include the host running airflow itself primary_queue: str = "default" # The queue that does not include the host running airflow itself secondary_queue: str = "default" # The default worker queue default_queue: str = "default" # The default pool size default_size: int = Field(default=8) # rewrite pool size from config if differs from airflow variable stored value override_pool_size: bool = False # create connection object in airflow for host create_connection: bool = False @property def all_hosts(self): return sorted(list(set(self.hosts))) @property def all_ports(self): return sorted(list(set(self.ports))) @model_validator(mode="after") def _validate(self) -> Self: from airflow_pydantic.airflow import Pool as AirflowPool, PoolNotFound, get_parsing_context # Validate no duplicate hosts seen_hostnames = set() for host in self.hosts: if host.name in seen_hostnames: raise ValueError(f"Duplicate host found: {host.name}") seen_hostnames.add(host.name) # Handle limits for host in self.hosts: if not host.size: host.size = self.default_size if not host.pool: host.pool = Pool( pool=host.name, slots=host.size, description=f"Balancer pool for host({host.name})", ) if get_parsing_context().dag_id is not None: # check airflow first try: if isinstance(host.pool, Pool): pool_name = host.pool.pool elif isinstance(host.pool, AirflowPool): pool_name = host.pool.pool elif isinstance(host.pool, dict): pool_name = host.pool.get("pool") elif isinstance(host.pool, str): pool_name = host.pool res = AirflowPool.get_pool(pool_name) # airflow return value differs version-to-version if res is None: raise PoolNotFound elif res.slots != host.size: if self.override_pool_size: AirflowPool.create_or_update_pool( name=pool_name, slots=host.size, description=host.pool.description, include_deferred=False, ) else: host.size = res.slots except PoolNotFound: try: # else set to default AirflowPool.create_or_update_pool( name=host.name, slots=host.size, description=f"Balancer pool for host({host.name})", include_deferred=False, ) except OperationalError: # If the database is not available, we cannot create the pool pass if not host.username and self.default_username: host.username = self.default_username if not host.password and self.default_password: host.password = self.default_password if not host.key_file and self.default_key_file: host.key_file = self.default_key_file if not host.size: host.size = self.default_size # Handle ports _used_ports = set() for port in self.ports: if port.host_name and not port.host: port.host = next((host for host in self.all_hosts if host.name == port.host_name), None) if not port.port: raise ValueError("Port must be specified") if not port.host: raise ValueError("Host must be specified") if (port.host.name, port.port) in _used_ports: raise ValueError(f"Duplicate port usage for host: {port.host.name}:{port.port}") _used_ports.add((port.host.name, port.port)) # Create pools # TODO reenable # Pool.create_or_update_pool( # name=port.pool, # slots=1, # description=f"Balancer pool for host({port.port}) port({port.port})", # include_deferred=True, # ) # sort hosts by name, sort ports by host name then port number # NOTE: since we're in a validator and we have validate_on_assignment, bypass here self.__dict__["hosts"] = sorted(self.hosts, key=lambda host: host.name) self.__dict__["ports"] = sorted(self.ports, key=lambda port: (port.host.name, port.port)) return self
[docs] def filter_hosts( self, name: Optional[Union[str, List[str]]] = None, queue: Optional[Union[str, List[str]]] = None, os: Optional[Union[str, List[str]]] = None, tag: Optional[Union[str, List[str]]] = None, custom: Optional[Callable] = None, ) -> List[Host]: from .query import BalancerHostQueryConfiguration query = BalancerHostQueryConfiguration( kind="filter", name=name, queue=queue, os=os, tag=tag, custom=custom, balancer=self, ) return query.execute()
[docs] def select_host( self, name: Optional[Union[str, List[str]]] = None, queue: Optional[Union[str, List[str]]] = None, os: Union[str, List[str]] = "", tag: Union[str, List[str]] = "", custom: Callable = None, ) -> List[Host]: from .query import BalancerHostQueryConfiguration query = BalancerHostQueryConfiguration( kind="select", name=name, queue=queue, os=os, tag=tag, custom=custom, balancer=self, ) return query.execute()
[docs] def filter_ports( self, name: Optional[Union[str, List[str]]] = None, tag: Optional[Union[str, List[str]]] = None, custom: Optional[Callable] = None, ) -> List[Host]: from .query import BalancerPortQueryConfiguration query = BalancerPortQueryConfiguration( kind="filter", name=name, tag=tag, custom=custom, balancer=self, ) return query.execute()
[docs] def select_port( self, name: Optional[Union[str, List[str]]] = None, tag: Union[str, List[str]] = "", custom: Callable = None, ) -> List[Host]: from .query import BalancerPortQueryConfiguration query = BalancerPortQueryConfiguration( kind="select", name=name, tag=tag, custom=custom, balancer=self, ) return query.execute()
[docs] def free_port( self, host: Host, min: int = 1000, max: int = 65535, ) -> Port: used_ports = [port.port for port in self.ports if port.host == host] port = Port(host=host, port=choice(range(min, max))) while port.port in used_ports: port = Port(host=host, port=choice(range(min, max))) # TODO add pool around port? or just allow? context manager? return port
[docs] @staticmethod def load_path(yaml_file: str | Path, _config_dir: str | Path = None) -> Self: """Load configuration from yaml file""" from hydra import compose, initialize_config_dir from hydra.utils import instantiate if not isinstance(yaml_file, Path): yaml_file = Path(yaml_file).resolve() if not yaml_file.suffix == ".yaml": raise ValueError(f"File {yaml_file} must end in .yaml") file_name = yaml_file.stem config_dir = str(_config_dir or yaml_file.parent) # TODO: how to get the underlying value directly with initialize_config_dir(config_dir=config_dir, version_base=None): cfg = compose(config_name=file_name, overrides=[]) inst_cfg = instantiate(cfg) # Try to find the BalancerConfiguration in the config # TODO: recursively search for BalancerConfiguration # TODO: more than one? if yaml_file.parent.stem in inst_cfg: config = inst_cfg[yaml_file.parent.stem] else: config = inst_cfg if "extensions" in config: config = config["extensions"] if not isinstance(config, BalancerConfiguration): # Try to look through values again for value in config.values(): if isinstance(value, BalancerConfiguration): config = value break else: raise ValueError(f"Config {file_name} does not contain a BalancerConfiguration") return config
[docs] @staticmethod def load( config_dir: Path | str = "config", config_name: Path | str = "", overrides: Optional[list[str]] = None, *, basepath: str = "", _offset: int = 4, ) -> "BalancerConfiguration": from airflow_config import load_config as load_airflow_config from hydra.errors import HydraException try: _log.info(f"Loading balancer configuration from {config_dir} with name {config_name}") cfg = load_airflow_config( config_dir=str(config_dir), config_name=str(config_name), overrides=overrides, basepath=basepath, _offset=_offset, ) if cfg.extensions: for ext in cfg.extensions.values(): if isinstance(ext, BalancerConfiguration): return ext except HydraException: pass _log.warning(f"Balancer configuration not found in {config_dir} with name {config_name}, loading default") return BalancerConfiguration.load_path( yaml_file=Path(config_dir) / f"{config_name or 'balancer'}.yaml", _config_dir=config_dir, )
load_config = BalancerConfiguration.load