Source code for airflow_balancer.config.balancer

from fnmatch import fnmatch
from random import choice
from typing import Callable, List, Optional, Union

from airflow.models.pool import Pool, PoolNotFound  # noqa: F401
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from .host import Host
from .port import Port

__all__ = ("BalancerConfiguration",)


[docs] class BalancerConfiguration(BaseModel): hosts: List[Host] = Field(default_factory=list) ports: List[Port] = Field(default_factory=list) default_username: str = "airflow" # Password default_password: Optional[str] = None # If password is stored in a variable default_password_variable: Optional[str] = None # if stored in structured container, access by key default_password_variable_key: Optional[str] = None # Or get key file default_key_file: Optional[str] = None # The queue that might include the host running airflow itself primary_queue: str = "default" # The queue that does not include the host running airflow itself secondary_queue: str = "default" # The default worker queue default_queue: str = "default" # The default pool size default_size: int = Field(default=8) # rewrite pool size from config if differs from airflow variable stored value override_pool_size: bool = False # create connection object in airflow for host create_connection: bool = False @property def all_hosts(self): return sorted(list(set(self.hosts))) @property def all_ports(self): return sorted(list(set(self.ports))) @model_validator(mode="after") def _validate(self) -> Self: # Validate no duplicate hosts seen_hostnames = set() for host in self.hosts: if host.name in seen_hostnames: raise ValueError(f"Duplicate host found: {host.name}") seen_hostnames.add(host.name) # Handle limits for host in self.hosts: if not host.pool: host.pool = host.name if not host.size: host.size = self.default_size # check airflow first try: res = Pool.get_pool(host.pool) # airflow return value differs version-to-version if res is None: raise PoolNotFound elif res.slots != host.size: if self.override_pool_size: Pool.create_or_update_pool( name=host.pool, slots=host.size, description=f"Balancer pool for host({host.name}) pool({host.pool})", include_deferred=False, ) else: host.size = res.slots except PoolNotFound: # else set to default Pool.create_or_update_pool( name=host.pool, slots=host.size, description=f"Balancer pool for host({host.name}) pool({host.pool})", include_deferred=False ) if not host.username and self.default_username: host.username = self.default_username if not host.password and self.default_password: host.password = self.default_password if not host.password_variable and self.default_password_variable: host.password_variable = self.default_password_variable if not host.password_variable_key and self.default_password_variable_key: host.password_variable_key = self.default_password_variable_key if not host.key_file and self.default_key_file: host.key_file = self.default_key_file if not host.size: host.size = self.default_size # Handle ports _used_ports = set() for port in self.ports: if port.host_name and not port.host: port.host = next((host for host in self.all_hosts if host.name == port.host_name), None) if not port.port: raise ValueError("Port must be specified") if not port.host: raise ValueError("Host must be specified") if (port.host.name, port.port) in _used_ports: raise ValueError(f"Duplicate port usage for host: {port.host.name}:{port.port}") _used_ports.add((port.host.name, port.port)) # Create pools # TODO reenable # Pool.create_or_update_pool( # name=port.pool, # slots=1, # description=f"Balancer pool for host({port.port}) port({port.port})", # include_deferred=True, # )
[docs] def filter_hosts( self, name: Optional[Union[str, List[str]]] = None, queue: Optional[Union[str, List[str]]] = None, os: Optional[Union[str, List[str]]] = None, tag: Optional[Union[str, List[str]]] = None, custom: Optional[Callable] = None, ) -> List[Host]: name = name or [] queue = queue or [] os = os or [] tag = tag or [] if isinstance(name, str): name = [name] if isinstance(queue, str): queue = [queue] if isinstance(os, str): os = [os] if isinstance(tag, str): tag = [tag] return [ host for host in self.all_hosts if (not name or any(fnmatch(host.name, n) for n in name)) and (not queue or any(fnmatch(host_queue, queue_pat) for queue_pat in queue for host_queue in host.queues)) and (not tag or any(fnmatch(host_tag, tag_pat) for tag_pat in tag for host_tag in host.tags)) and (not os or any(fnmatch(host.os, o) for o in os)) and (not custom or custom(host)) ]
[docs] def select_host( self, name: Optional[Union[str, List[str]]] = None, queue: Optional[Union[str, List[str]]] = None, os: Union[str, List[str]] = "", tag: Union[str, List[str]] = "", custom: Callable = None, ) -> List[Host]: candidates = self.filter_hosts(name=name, queue=queue, os=os, tag=tag, custom=custom) if not candidates: raise RuntimeError(f"No host found for {name} / {queue} / {os} / {tag}") # TODO more schemes, interrogate usage return choice(candidates)
[docs] def filter_ports( self, name: Optional[Union[str, List[str]]] = None, tag: Optional[Union[str, List[str]]] = None, custom: Optional[Callable] = None, ) -> List[Host]: name = name or [] tag = tag or [] if isinstance(name, str): name = [name] if isinstance(tag, str): tag = [tag] return [ port for port in self.all_ports if (not name or any(fnmatch(port.name, n) for n in name)) and (not tag or any(fnmatch(port_tag, tag_pat) for tag_pat in tag for port_tag in port.tags)) and (not custom or custom(port)) ]
[docs] def select_port( self, name: Optional[Union[str, List[str]]] = None, tag: Union[str, List[str]] = "", custom: Callable = None, ) -> List[Host]: candidates = self.filter_ports(name=name, tag=tag, custom=custom) if not candidates: raise RuntimeError(f"No port found for {name} / {tag}") # TODO more schemes, interrogate usage return choice(candidates)
[docs] def free_port( self, host: Host, min: int = 1000, max: int = 65535, ) -> Port: used_ports = [port.port for port in self.ports if port.host == host] port = Port(host=host, port=choice(range(min, max))) while port.port in used_ports: port = Port(host=host, port=choice(range(min, max))) # TODO add pool around port? or just allow? context manager? return port