Source code for airflow_config.configuration.base

import os
from pathlib import Path
from typing import Dict, Optional

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from pydantic import BaseModel, Field

from airflow_config.configuration.airflow import Dag, DagArgs, Task, TaskArgs
from airflow_config.configuration.python import PythonConfiguration
from airflow_config.exceptions import ConfigNotFoundError
from airflow_config.utils import _get_calling_dag

# Public names exported by this module.
__all__ = (
    "Configuration",
    "load_config",
)


class Configuration(BaseModel):
    """Root configuration model: global defaults plus statically configured DAGs and tasks.

    Instances are normally built with :meth:`load`, which composes a Hydra
    configuration and instantiates it.  :meth:`pre_apply` merges configured
    values into the keyword arguments of a DAG, and :meth:`apply` instantiates
    the tasks configured for a DAG and wires their dependencies.
    """

    default_task_args: TaskArgs = Field(default_factory=TaskArgs, description="Global default default_args (task arguments)", alias="default_args")
    default_dag_args: DagArgs = Field(default_factory=DagArgs, description="Global default dag arguments")
    dags: Optional[Dict[str, Dag]] = Field(default_factory=dict, description="List of dags statically configured via Pydantic")
    tasks: Optional[Dict[str, Task]] = Field(default_factory=dict, description="List of tasks statically configured via Pydantic")
    python: PythonConfiguration = Field(default_factory=PythonConfiguration, description="Global Python configuration")
    extensions: Optional[Dict[str, BaseModel]] = Field(default_factory=dict, description="Any user-defined extensions")

    @property
    def default_args(self):
        """Return the global task arguments (same object as ``default_task_args``)."""
        return self.default_task_args

    @staticmethod
    def _find_parent_config_folder(config_dir: str = "config", config_name: str = "", *, basepath: str = "", _offset: int = 2):
        """Walk up the directory tree until a folder containing ``config_dir`` is found.

        Args:
            config_dir: Name of the configuration directory to look for.
            config_name: Optional config file stem.  When given, a folder only
                matches if ``config_dir/<config_name>.yml`` or
                ``config_dir/<config_name>.yaml`` exists inside it.
            basepath: Optional starting point.  A value ending in ``.py``,
                ``.cfg``, ``.yml`` or ``.yaml`` is treated as a file; anything
                else is treated as a directory.
            _offset: Forwarded to ``_get_calling_dag`` when ``basepath`` is empty
                (presumably a stack-frame offset used to locate the calling DAG
                file -- confirm against ``airflow_config.utils``).

        Returns:
            A 3-tuple ``(folder, resolved_config_dir, config_file)`` where
            ``config_file`` is ``""`` when ``config_name`` is empty, otherwise
            the resolved path of the ``.yml`` file if it exists, else of the
            ``.yaml`` file.

        Raises:
            ConfigNotFoundError: If the filesystem root is reached without a match.
        """
        if basepath:
            if basepath.endswith((".py", ".cfg", ".yml", ".yaml")):
                calling_dag = Path(basepath)
            else:
                # Treat basepath as a directory: append a placeholder file name
                # so that ``.parent`` below yields the directory itself.
                calling_dag = Path(basepath) / "dummy.py"
        else:
            # NOTE: keep this call directly in this function; ``_offset`` is
            # passed through unchanged and presumably counts stack frames.
            calling_dag = Path(_get_calling_dag(offset=_offset))

        def _has_config(candidate: Path) -> bool:
            # True when ``candidate`` holds the config dir (and file, if named).
            target = candidate / config_dir
            if not config_name:
                return target.exists()
            return (target / f"{config_name}.yml").exists() or (target / f"{config_name}.yaml").exists()

        folder = calling_dag.parent.resolve()
        while not _has_config(folder):
            folder = folder.parent
            # A filesystem root is its own parent.  Comparing against
            # ``os.path.abspath(os.sep)`` only recognises the root of the
            # current drive and would loop forever on any other drive/UNC root.
            if folder.parent == folder:
                raise ConfigNotFoundError(config_dir=config_dir, dagfile=calling_dag)

        resolved_dir = (folder / config_dir).resolve()
        if not config_name:
            return folder.resolve(), resolved_dir, ""
        yml_file = resolved_dir / f"{config_name}.yml"
        if yml_file.exists():
            return folder.resolve(), resolved_dir, yml_file.resolve()
        return folder.resolve(), resolved_dir, (resolved_dir / f"{config_name}.yaml").resolve()

    @staticmethod
    def load(
        config_dir: str = "config",
        config_name: str = "",
        overrides: Optional[list[str]] = None,
        *,
        basepath: str = "",
        _offset: int = 3,
    ) -> "Configuration":
        """Compose the Hydra configuration and return it as a ``Configuration``.

        Args:
            config_dir: Name of the user configuration directory to search for
                in parent folders.  When falsy, no user directory is searched
                and only the bundled ``base`` config plus ``overrides`` is used.
            config_name: Optional config file stem; when given (and
                ``config_dir`` is truthy) it is selected via ``+config=<name>``.
            overrides: Extra Hydra override strings.  The caller's list is not
                mutated.
            basepath: Optional starting point for the config directory search.
            _offset: Forwarded to ``_find_parent_config_folder``.

        Returns:
            The instantiated configuration.  If Hydra's ``instantiate`` does not
            yield a ``Configuration``, its result is unpacked into one.

        Raises:
            ConfigNotFoundError: If ``config_dir`` cannot be found.
        """
        overrides = list(overrides) if overrides else []

        with initialize_config_dir(config_dir=str(Path(__file__).resolve().parent / "hydra"), version_base=None):
            if config_dir:
                # NOTE: keep this call directly in ``load``; ``_offset`` is
                # forwarded unchanged to the calling-DAG lookup.
                hydra_folder, config_dir, _ = Configuration._find_parent_config_folder(
                    config_dir=config_dir, config_name=config_name, basepath=basepath, _offset=_offset
                )

                # Compose once only to read the search path already configured
                # in the bundled base config, then append the user folders.
                cfg = compose(config_name="base", overrides=[], return_hydra_config=True)
                # Convert explicitly: the folders are ``Path`` objects and
                # ``str.join`` accepts only strings.
                searchpaths = [str(entry) for entry in cfg["hydra"]["searchpath"]]
                searchpaths.extend([str(hydra_folder), str(config_dir)])
                searchpath_override = f"hydra.searchpath=[{','.join(searchpaths)}]"

                if config_name:
                    overrides = [f"+config={config_name}", *overrides, searchpath_override]
                else:
                    overrides = [*overrides, searchpath_override]

            cfg = compose(config_name="base", overrides=overrides)
            config = instantiate(cfg)

            if not isinstance(config, Configuration):
                config = Configuration(**config)
            return config

    def pre_apply(self, dag, dag_kwargs):
        """Fill ``dag_kwargs`` in place with configured values not set by the DAG file.

        Precedence, highest first: values already present in ``dag_kwargs``
        (hard-coded in the DAG file), per-DAG configuration from ``self.dags``,
        then the global ``default_args`` / ``default_dag_args``.  Only values
        that are not ``None`` are copied, so an explicit per-DAG ``False`` or
        ``0`` is honoured instead of being replaced by a global default.

        Args:
            dag: Unused here; accepted for symmetry with :meth:`apply`.
            dag_kwargs: Keyword arguments of the DAG, mutated in place.
        """
        # A missing or ``None`` default_args is replaced by a dict to merge into.
        if dag_kwargs.get("default_args") is None:
            dag_kwargs["default_args"] = {}
        task_defaults = dag_kwargs["default_args"]

        # ``dags`` is Optional, so tolerate ``None``.
        dags = self.dags or {}
        dag_id = dag_kwargs.get("dag_id", None)

        # Per-DAG options take precedence over the global ones below.
        if dag_id in dags:
            per_dag_kwargs = dags[dag_id]

            # Per-DAG default_args for the tasks of this DAG.
            per_dag_default_args = per_dag_kwargs.default_args
            for attr in TaskArgs.model_fields:
                val = getattr(per_dag_default_args, attr, None)
                if attr not in task_defaults and val is not None:
                    task_defaults[attr] = val

            # Per-DAG arguments of the DAG itself.
            for attr in DagArgs.model_fields:
                if attr == "default_args":
                    # Handled above.
                    continue
                if attr == "dag_id":
                    # Record the id on the configured DAG instead of copying it out.
                    per_dag_kwargs.dag_id = dag_id
                    continue
                val = getattr(per_dag_kwargs, attr, None)
                if attr not in dag_kwargs and val is not None:
                    dag_kwargs[attr] = val

        # Global task defaults fill whatever is still unset.
        for attr in TaskArgs.model_fields:
            val = getattr(self.default_args, attr, None)
            if attr not in task_defaults and val is not None:
                task_defaults[attr] = val

        # Global DAG defaults fill whatever is still unset.
        for attr in DagArgs.model_fields:
            if attr in dag_kwargs:
                continue
            val = getattr(self.default_dag_args, attr, None)
            if val is not None:
                dag_kwargs[attr] = val

    def apply(self, dag, dag_kwargs):
        """Instantiate the tasks configured for ``dag`` and wire their dependencies.

        Does nothing when ``dag.dag_id`` has no entry in ``self.dags`` or that
        entry has no tasks.

        Args:
            dag: The DAG object; its ``dag_id`` selects the configuration entry
                and it is passed to each task's ``instantiate``.
            dag_kwargs: Unused here; accepted for symmetry with :meth:`pre_apply`.

        Raises:
            KeyError: If a task lists a dependency that is not a task id
                configured for the same DAG.
        """
        # ``dags`` is Optional, so tolerate ``None``.
        dags = self.dags or {}
        if dag.dag_id not in dags:
            return

        tasks = dags[dag.dag_id].tasks
        if not tasks:
            return

        # Instantiate every task first so dependencies can reference any of
        # them regardless of declaration order.
        task_insts = {}
        for task_id, task in tasks.items():
            task.task_id = task_id
            task_insts[task_id] = task.instantiate(dag=dag)

        for task_id, task_inst in task_insts.items():
            for dep in tasks[task_id].dependencies or ():
                # ``upstream >> downstream`` registers the dependency.
                task_insts[dep] >> task_inst
# Module-level alias: ``load_config(...)`` is ``Configuration.load(...)``.
load_config = Configuration.load