"""Source code for airflow_config.configuration.base."""

import os
import sys
from inspect import currentframe
from logging import getLogger
from pathlib import Path
from typing import Dict, Optional

from airflow.operators.bash import BashOperator
from airflow.utils.session import NEW_SESSION, provide_session
from airflow_pydantic import Dag, DagArgs, Task, TaskArgs
from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from pydantic import AliasChoices, BaseModel, Field, model_validator

from airflow_config.exceptions import ConfigNotFoundError
from airflow_config.utils import _get_calling_dag

# Public API of this module: the Configuration model and its module-level
# loader alias (bound at the bottom of the file).
__all__ = (
    "Configuration",
    "load_config",
)

# Module-level logger, named after this module.
_log = getLogger(__name__)


class Configuration(BaseModel):
    """Top-level configuration model for airflow-config.

    Aggregates global default task/dag arguments, statically-declared dags and
    tasks, and arbitrary user-defined extensions. Instances are typically built
    from YAML via Hydra (see :meth:`load`).
    """

    # Global per-task defaults; YAML may use either the "default_args" or the
    # "default_task_args" key thanks to the validation alias.
    default_task_args: TaskArgs = Field(
        default_factory=TaskArgs,
        description="Global default default_args (task arguments)",
        validation_alias=AliasChoices("default_args", "default_task_args"),
    )
    # Global per-dag defaults (start_date, schedule, enabled flag, ...).
    default_dag_args: DagArgs = Field(default_factory=DagArgs, description="Global default dag arguments")
    # Statically-declared dags, keyed by dag id.
    dags: Optional[Dict[str, Dag]] = Field(default_factory=dict, description="List of dags statically configured via Pydantic")
    # Statically-declared tasks, keyed by task id.
    tasks: Optional[Dict[str, Task]] = Field(default_factory=dict, description="List of dags statically configured via Pydantic")

    # Extensions
    extensions: Optional[Dict[str, BaseModel]] = Field(default_factory=dict, description="Any user-defined extensions")

    # Generic options
    env: Optional[str] = Field(default="", description="Environment to use for this configuration")
    name: Optional[str] = Field(default="", description="Name of the configuration")
    root: Optional[Path] = Field(default=None, description="Root path")
    tags: Optional[Dict[str, str]] = Field(default_factory=dict, description="Generic Tags for config. NOTE: Not related to dag tags")

    @property
    def default_args(self):
        # Convenience alias so callers can use the familiar Airflow name
        # "default_args" for the global task defaults.
        return self.default_task_args

    @model_validator(mode="after")
    def _validate(self):
        # TODO add more validation here
        if not self.default_dag_args.start_date:
            _log.warning("No start_date set in default_dag_args, please set a start_date in the default_dag_args or in the dag itself.")
        return self

    @staticmethod
    def _find_parent_config_folder(config_dir: str = "config", config_name: str = "", *, basepath: str = "", _offset: int = 2):
        """Walk up from the calling DAG file until a folder containing
        ``config_dir`` (and, if given, ``<config_name>.yaml`` inside it) is found.

        Returns a 3-tuple ``(parent_folder, resolved_config_dir, resolved_config_file)``
        where the last element is ``""`` when no ``config_name`` was requested.
        Raises ``ConfigNotFoundError`` when the filesystem root is reached, and
        a plain ``Exception`` for ``.yml`` (rather than ``.yaml``) config files.

        ``_offset`` is the stack depth passed to ``_get_calling_dag`` to locate
        the DAG file of the caller — it must match the call chain's depth.
        """
        if config_name.endswith(".yml"):
            raise Exception("Config file must be .yaml, not .yml")
        if config_name and not config_name.endswith(".yaml"):
            # Normalize bare names to their .yaml filename.
            config_name = f"{config_name}.yaml"
        if basepath:
            if basepath.endswith((".py", ".cfg", ".yaml")):
                calling_dag = Path(basepath)
            else:
                # A bare directory: pretend a file lives inside it so that
                # .parent below yields the directory itself.
                calling_dag = Path(basepath) / "dummy.py"
        else:
            # Inspect the call stack to locate the DAG file that invoked us.
            calling_dag = Path(_get_calling_dag(offset=_offset))
        folder = calling_dag.parent.resolve()
        exists = (folder / config_dir).exists() if not config_name else (folder / config_dir / f"{config_name}").exists()
        while not exists:
            folder = folder.parent
            if str(folder) == os.path.abspath(os.sep):
                # Reached the filesystem root without finding a config folder.
                raise ConfigNotFoundError(config_dir=config_dir, dagfile=calling_dag)
            exists = (folder / config_dir).exists() if not config_name else (folder / config_dir / f"{config_name}").exists()
        if not exists and (folder / config_dir / f"{config_name}.yml").exists():
            # NOTE(review): unreachable — the while loop above only exits when
            # ``exists`` is truthy. Also ``config_name`` already ends in
            # ".yaml" here, so this tests "<name>.yaml.yml"; presumably the
            # intent was to catch a stray "<name>.yml" file — verify.
            raise Exception(f"Config file {config_name}.yml exists in {config_dir} but must be .yaml!")
        config_dir = (folder / config_dir).resolve()
        if not config_name:
            return folder.resolve(), config_dir, ""
        elif (folder / config_dir / f"{config_name}.yml").exists():
            # NOTE(review): same "<name>.yaml.yml" suffix concern as above;
            # ``config_dir`` is absolute here so ``folder / config_dir`` is
            # just ``config_dir``.
            raise Exception(f"Config file {config_name}.yml exists in {config_dir} but must be .yaml!")
        return folder.resolve(), config_dir, (folder / config_dir / f"{config_name}").resolve()
    @staticmethod
    def load(
        config_dir: str = "config",
        config_name: str = "",
        overrides: Optional[list[str]] = None,
        *,
        basepath: str = "",
        _offset: int = 3,
    ) -> "Configuration":
        """Load a ``Configuration`` from YAML via Hydra.

        Locates the nearest ``config_dir`` above the calling DAG file (or
        ``basepath``), adds it to Hydra's search path, composes the config
        (optionally selecting ``+config=<config_name>`` plus any extra
        ``overrides``), instantiates it, and fills in dag/task ids from their
        mapping keys when unset.

        ``_offset`` is the stack depth forwarded to the config-folder search —
        it accounts for this extra call frame.
        """
        overrides = overrides or []
        # Initialize Hydra against this package's bundled "hydra" folder,
        # which provides the base "conf" config.
        with initialize_config_dir(config_dir=str(Path(__file__).resolve().parent / "hydra"), version_base=None):
            if config_dir:
                hydra_folder, config_dir, _ = Configuration._find_parent_config_folder(
                    config_dir=config_dir, config_name=config_name, basepath=basepath, _offset=_offset
                )
                # First compose only to read the current hydra searchpath,
                # then extend it with the discovered folders.
                cfg = compose(config_name="conf", overrides=[], return_hydra_config=True)
                searchpaths = cfg["hydra"]["searchpath"]
                searchpaths.extend([hydra_folder, config_dir])
                if config_name:
                    overrides = [f"+config={config_name}", *overrides.copy(), f"hydra.searchpath=[{','.join(searchpaths)}]"]
                else:
                    overrides = [*overrides.copy(), f"hydra.searchpath=[{','.join(searchpaths)}]"]
            cfg = compose(config_name="conf", overrides=overrides)
            config = instantiate(cfg)
            if not isinstance(config, Configuration):
                # Hydra may return a plain dict-like config; coerce it.
                config = Configuration(**config)
            # Populate dag ids and task ids as a convenience
            for dag_id, dag in config.dags.items():
                if not dag.dag_id:
                    dag.dag_id = dag_id
                for task_id, task in dag.tasks.items():
                    if not task.task_id:
                        task.task_id = task_id
            return config
    def pre_apply(self, dag, dag_kwargs):
        """Merge this configuration into ``dag_kwargs`` before DAG construction.

        Precedence for ``default_args`` (lowest to highest): global config
        defaults, per-dag config defaults, args hard-coded in the DAG file.
        Other per-dag options and global dag args fill in only keys not already
        present in ``dag_kwargs``. Calls ``sys.exit(0)`` when the dag (or the
        global default) is disabled, so a disabled DAG file stops parsing.
        """
        # update options in config based on hard-coded overrides
        # in the DAG file itself
        per_dag_default_args = None

        # look up per-dag options
        if dag_kwargs.get("dag_id", None) in self.dags:
            # first try to see if per-dag options have default_args for subtasks
            per_dag_kwargs = self.dags[dag_kwargs["dag_id"]]
            per_dag_default_args = per_dag_kwargs.default_args

            # if dag is disabled directly, quit right away
            if per_dag_kwargs.enabled is False or (per_dag_kwargs.enabled is None and self.default_dag_args.enabled is False):
                sys.exit(0)

            for attr, val in per_dag_kwargs.model_dump(exclude_unset=True, exclude={"default_args", "enabled", "tasks"}).items():
                if attr == "dag_id":
                    # set the dag_id back on the per-dag config rather than
                    # overwriting the one hard-coded in the DAG file
                    per_dag_kwargs.dag_id = dag_kwargs["dag_id"]
                    continue
                # only fill in options not already hard-coded in the DAG file
                if attr not in dag_kwargs:
                    dag_kwargs[attr] = val
        elif self.default_dag_args.enabled is False:
            # if dag has no per-dag-config, but default dag args is disabled, quit right away
            sys.exit(0)

        # start with empty default_args
        default_args = {}

        # First, override with global defaults specified in config
        for attr, val in self.default_args.model_dump(exclude_unset=True).items():
            default_args[attr] = val

        # Next, update with per-dag defaults specified in config
        if per_dag_default_args is not None:
            for attr, val in per_dag_default_args.model_dump(exclude_unset=True).items():
                default_args[attr] = val

        # Finally, override with args hardcoded in the DAG file
        if "default_args" in dag_kwargs:
            default_args.update(dag_kwargs.get("default_args", {}))

        # update the dag_kwargs with the final default_args
        dag_kwargs["default_args"] = default_args

        # fill in remaining global dag args ("enabled" is config-only, not a
        # real DAG kwarg)
        for attr, val in self.default_dag_args.model_dump(exclude_unset=True, exclude={"enabled"}).items():
            if attr not in dag_kwargs:
                dag_kwargs[attr] = val

        # append a rendered code-equivalent section to the DAG's doc_md when
        # tasks are statically configured for this dag
        doc_md = dag_kwargs.get("doc_md", "")
        if dag_kwargs.get("dag_id", None) in self.dags and self.dags[dag_kwargs["dag_id"]].tasks:
            doc_md += f"\n# Code Equivalent\n```python\n{self.dags[dag_kwargs['dag_id']].render()}\n```"
            # doc_md += f"\n# DAG Config\n```json\n{self.dags[dag_kwargs['dag_id']].model_dump_json(indent=2, serialize_as_any=True)}\n```"
            # doc_md += f"\n# Global Config\n```json\n{self.model_dump_json(indent=2, serialize_as_any=True)}\n```"
        if doc_md:
            dag_kwargs["doc_md"] = doc_md
[docs] def apply(self, dag, dag_kwargs): # update the options in the dag if necessary, instantiate tasks if dag.dag_id in self.dags: # Instantiate self against the dag self.dags[dag.dag_id].instantiate(dag=dag)
    @provide_session
    def generate_in_mem(self, dir: Path | str = None, session=NEW_SESSION, placeholder_dag_id: str = "airflow-config-generate-dags"):
        """Materialize configured dags directly into the caller's globals.

        Creates a placeholder DAG (with one no-op BashOperator) plus one DAG
        instance per statically-configured dag that has tasks, injecting each
        into the calling frame's globals so Airflow's parser discovers them.
        Each generated DAG's ``fileloc`` is pointed at ``<dir>/<dag_id>.py``.

        ``session`` is an Airflow DB session injected by ``@provide_session``;
        it is committed at the end.
        """
        from ..dag import DAG

        # Grab the caller's frame so DAG objects can be injected into its globals.
        cur_frame = currentframe().f_back
        _log.info("Generating Placeholder DAG for in-memory generation")
        placeholder_dag = DAG(dag_id=placeholder_dag_id, config=self, schedule=None)
        with placeholder_dag:
            BashOperator(task_id=f"{placeholder_dag_id}-placeholder-task", bash_command='echo "Placeholder task"')
        cur_frame.f_globals[placeholder_dag_id] = placeholder_dag
        dir = dir or Path.cwd()
        dir_path = Path(dir)
        for dag_id, dag in self.dags.items():
            if dag.tasks:
                _log.info(f"Generating DAG: {dag_id} in memory")
                # rendered = dag.render()
                dag_path = dir_path / f"{dag_id}.py"
                dag_instance = DAG(dag_id=dag_id, config=self)
                # dag_instance.doc_md = f"# Code\n```python\n{rendered}\n```"
                _log.info(f"Updating DAG fileloc from {dag_instance.fileloc} to {str(dag_path)}")
                dag_instance.fileloc = str(dag_path)
                cur_frame.f_globals[dag_id] = dag_instance

                # Swap out DagCode
                # First, grab DAG code
                # _log.info(f"Updating DagModel for {dag_id} in memory")
                # query = select(DagModel).where(DagModel.dag_id == dag_id)
                # query = with_row_locks(query, of=DagModel, session=session)
                # orm_dag: DagModel = session.scalars(query).unique().first()
                # if not orm_dag:
                #     _log.info(f"No existing DagModel found for {dag_id}, check logs!")
                #     continue
                # else:
                #     orm_dag.fileloc = str(dag_path)
                #     orm_dag.last_parsed_time = timezone.utcnow()
                #     session.merge(orm_dag)
                # _log.info(f"Updating DagCode for {dag_id} in memory")
                # dag_code = DagCode(full_filepath=str(dag_path), source_code=rendered)
                # _log.info(f"Querying DB for {dag_code.fileloc_hash}")
                # existing_orm_dag_code = session.scalars(select(DagCode).where(DagCode.fileloc_hash == dag_code.fileloc_hash)).first()
                # if existing_orm_dag_code:
                #     _log.info(f"Updating existing DagCode for {dag_id} in memory")
                #     existing_orm_dag_code.last_updated = timezone.utcnow()
                #     existing_orm_dag_code.source_code = rendered
                #     session.merge(existing_orm_dag_code)
                # else:
                #     _log.info(f"Adding new DagCode for {dag_id} in memory")
                #     session.add(dag_code)
                _log.info(f"Adding DAG {dag_id} complete")
        session.commit()
[docs] def generate(self, dir: Path | str = None): dir = dir or Path.cwd() dir_path = Path(dir) dir_path.mkdir(parents=True, exist_ok=True) for dag_id, dag in self.dags.items(): if dag.tasks: rendered = dag.render() dag_path = dir_path / f"{dag_id}.py" if dag_path.exists() and dag_path.read_text() == rendered: continue dag_path.write_text(rendered)
# Module-level convenience alias: ``load_config(...)`` is ``Configuration.load(...)``.
load_config = Configuration.load