import os
import sys
from pathlib import Path
from typing import Dict, Optional
from airflow_pydantic import Dag, DagArgs, Task, TaskArgs
from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from pydantic import BaseModel, Field
from airflow_config.exceptions import ConfigNotFoundError
from airflow_config.utils import _get_calling_dag
# Public API of this module: the configuration model and its loader alias.
__all__ = ("Configuration", "load_config")
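# NOTE: a stray "[docs]" marker (documentation-link residue) was removed here; it is not part of the module.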
class Configuration(BaseModel):
    """Pydantic model holding global task/DAG defaults plus statically configured DAGs and tasks.

    Instances are normally produced by :meth:`load`, which composes a hydra
    configuration and instantiates it. :meth:`pre_apply` merges the configured
    options into the keyword arguments of a DAG, and :meth:`apply` instantiates
    the tasks configured for an existing DAG.
    """

    default_task_args: TaskArgs = Field(default_factory=TaskArgs, description="Global default default_args (task arguments)", alias="default_args")
    default_dag_args: DagArgs = Field(default_factory=DagArgs, description="Global default dag arguments")
    dags: Optional[Dict[str, Dag]] = Field(default_factory=dict, description="List of dags statically configured via Pydantic")
    tasks: Optional[Dict[str, Task]] = Field(default_factory=dict, description="List of tasks statically configured via Pydantic")
    extensions: Optional[Dict[str, BaseModel]] = Field(default_factory=dict, description="Any user-defined extensions")

    # Generic options
    env: Optional[str] = Field(default="", description="Environment to use for this configuration")
    name: Optional[str] = Field(default="", description="Name of the configuration")
    root: Optional[Path] = Field(default=None, description="Root path")
    tags: Optional[Dict[str, str]] = Field(default_factory=dict, description="Generic Tags for config. NOTE: Not related to dag tags")

    @property
    def default_args(self):
        """Return ``default_task_args`` under the name used by its field alias."""
        return self.default_task_args

    @staticmethod
    def _find_parent_config_folder(config_dir: str = "config", config_name: str = "", *, basepath: str = "", _offset: int = 2):
        """Walk up the directory tree until a folder containing the configuration is found.

        Args:
            config_dir: Name of the configuration directory to look for.
            config_name: Optional config file stem. When given, a folder only
                matches if ``<config_dir>/<config_name>.yml`` or ``.yaml`` exists;
                otherwise the mere existence of ``<config_dir>`` is enough.
            basepath: Optional starting point. A value ending in ``.py``, ``.cfg``,
                ``.yml`` or ``.yaml`` is treated as a file and the search starts in
                its parent; anything else is treated as a directory.
            _offset: Forwarded to ``_get_calling_dag`` when ``basepath`` is empty.

        Returns:
            A 3-tuple ``(folder, config_dir_path, config_file)`` of resolved paths,
            where ``config_file`` is ``""`` when ``config_name`` is empty.

        Raises:
            ConfigNotFoundError: If the filesystem root is reached without a match.
                The root directory itself is not searched.
        """
        if basepath:
            # Tolerate path-like objects even though the annotation says ``str``.
            basepath = str(basepath)
            if basepath.endswith((".py", ".cfg", ".yml", ".yaml")):
                calling_dag = Path(basepath)
            else:
                # Append a placeholder file name so that ``.parent`` below is the directory itself.
                calling_dag = Path(basepath) / "dummy.py"
        else:
            # Must be called directly from this frame: ``_offset`` is handed to the helper as-is.
            calling_dag = Path(_get_calling_dag(offset=_offset))

        def _match(candidate_folder):
            """Return the matching config dir/file under ``candidate_folder``, or ``None``."""
            candidate = candidate_folder / config_dir
            if not config_name:
                return candidate if candidate.exists() else None
            # ``.yml`` takes precedence over ``.yaml``.
            for suffix in (".yml", ".yaml"):
                config_file = candidate / f"{config_name}{suffix}"
                if config_file.exists():
                    return config_file
            return None

        folder = calling_dag.parent.resolve()
        found = _match(folder)
        while found is None:
            parent = folder.parent
            # A path that is its own parent is a filesystem root (on any drive),
            # so the walk is guaranteed to terminate.
            if parent == parent.parent:
                raise ConfigNotFoundError(config_dir=config_dir, dagfile=calling_dag)
            folder = parent
            found = _match(folder)

        resolved_config_dir = (folder / config_dir).resolve()
        if not config_name:
            return folder.resolve(), resolved_config_dir, ""
        return folder.resolve(), resolved_config_dir, found.resolve()

    @staticmethod
    def load(
        config_dir: str = "config",
        config_name: str = "",
        overrides: Optional[list[str]] = None,
        *,
        basepath: str = "",
        _offset: int = 3,
    ) -> "Configuration":
        """Compose the hydra configuration and return it as a ``Configuration``.

        Hydra is initialized with the ``hydra`` directory that sits next to this
        module. When ``config_dir`` is non-empty, the nearest parent folder that
        contains it is located and both that folder and the config directory are
        appended to ``hydra.searchpath``; when ``config_name`` is also given, a
        ``+config=<config_name>`` override is added.

        Args:
            config_dir: Name of the configuration directory to search for.
            config_name: Optional name of the config file to load.
            overrides: Extra hydra overrides; the passed list is not modified.
            basepath: Optional starting point for the search (see
                ``_find_parent_config_folder``).
            _offset: Forwarded to ``_find_parent_config_folder``.

        Returns:
            The instantiated configuration.

        Raises:
            ConfigNotFoundError: If ``config_dir`` cannot be located.
        """
        overrides = list(overrides) if overrides else []

        with initialize_config_dir(config_dir=str(Path(__file__).resolve().parent / "hydra"), version_base=None):
            if config_dir:
                # Called directly (not through a helper) so the ``_offset`` frame count stays valid.
                hydra_folder, config_dir, _ = Configuration._find_parent_config_folder(
                    config_dir=config_dir, config_name=config_name, basepath=basepath, _offset=_offset
                )

                # First compose only to read the search path already configured in "base".
                cfg = compose(config_name="base", overrides=[], return_hydra_config=True)
                # Work on plain strings: ``str.join`` below rejects ``Path`` objects.
                searchpaths = [str(path) for path in cfg["hydra"]["searchpath"]]
                searchpaths.extend([str(hydra_folder), str(config_dir)])
                searchpath_override = f"hydra.searchpath=[{','.join(searchpaths)}]"

                if config_name:
                    overrides = [f"+config={config_name}", *overrides, searchpath_override]
                else:
                    overrides = [*overrides, searchpath_override]

            cfg = compose(config_name="base", overrides=overrides)
            config = instantiate(cfg)

            if not isinstance(config, Configuration):
                config = Configuration(**config)
            return config

    def pre_apply(self, dag, dag_kwargs) -> None:
        """Merge configured options into ``dag_kwargs`` in place.

        Precedence, lowest to highest: global defaults from this configuration,
        per-DAG options from ``self.dags``, values already present in
        ``dag_kwargs`` (i.e. hard-coded in the DAG file).

        Calls ``sys.exit(0)`` when the DAG is disabled, either via its per-DAG
        ``enabled`` flag or, absent a per-DAG decision, via
        ``default_dag_args.enabled``.

        Args:
            dag: Unused here; kept for signature symmetry with ``apply``.
            dag_kwargs: Keyword arguments of the DAG; updated in place.
        """
        # ``dags`` is Optional, so guard against an explicit ``None``.
        dags = self.dags or {}
        per_dag_default_args = None
        dag_id = dag_kwargs.get("dag_id", None)

        # look up per-dag options
        if dag_id in dags:
            per_dag_kwargs = dags[dag_id]
            # per-dag default_args for subtasks, merged further below
            per_dag_default_args = per_dag_kwargs.default_args

            # if dag is disabled directly, quit right away
            if per_dag_kwargs.enabled is False or (per_dag_kwargs.enabled is None and self.default_dag_args.enabled is False):
                sys.exit(0)

            for attr in DagArgs.model_fields:
                if attr in ("default_args", "enabled"):
                    # handled separately
                    continue
                if attr == "dag_id":
                    # keep the per-dag model in sync with the id it was looked up by
                    per_dag_kwargs.dag_id = dag_id
                    continue

                val = getattr(per_dag_kwargs, attr, None)
                # ``is not None`` (not truthiness) so explicit falsy settings such as
                # ``False`` or ``0`` are honoured instead of falling back to the global default.
                if attr not in dag_kwargs and val is not None:
                    dag_kwargs[attr] = val
        elif self.default_dag_args.enabled is False:
            # if dag has no per-dag-config, but default dag args is disabled, quit right away
            sys.exit(0)

        # Build default_args: global defaults first, then per-dag defaults on top.
        default_args = {}
        for source in (self.default_args, per_dag_default_args):
            for attr in TaskArgs.model_fields:
                value = getattr(source, attr, None)
                if value is not None:
                    default_args[attr] = value

        # Finally, override with args hardcoded in the DAG file (tolerating ``None``).
        default_args.update(dag_kwargs.get("default_args") or {})

        # update the dag_kwargs with the final default_args
        dag_kwargs["default_args"] = default_args

        # Fill any DAG argument still missing from the global DAG defaults.
        for attr in DagArgs.model_fields:
            if attr in dag_kwargs or attr == "enabled":
                continue
            val = getattr(self.default_dag_args, attr, None)
            if val is not None:
                dag_kwargs[attr] = val

    def apply(self, dag, dag_kwargs) -> None:
        """Instantiate the tasks configured for ``dag`` and wire their dependencies.

        Does nothing when ``dag.dag_id`` has no entry in ``self.dags`` or the
        entry has no tasks.

        Args:
            dag: The DAG object; its ``dag_id`` selects the per-DAG configuration.
            dag_kwargs: Unused here; kept for signature symmetry with ``pre_apply``.

        Raises:
            KeyError: If a task lists a dependency that is not a configured task
                of the same DAG.
        """
        # ``dags`` is Optional, so guard against an explicit ``None``.
        dags = self.dags or {}
        if dag.dag_id not in dags:
            return

        tasks = dags[dag.dag_id].tasks
        if not tasks:
            return

        # First pass: create every task so dependencies can reference any of them.
        task_insts = {}
        for task_id, task_model in tasks.items():
            task_model.task_id = task_id
            task_insts[task_id] = task_model.instantiate(dag=dag)

        # Second pass: set each dependency upstream of the task that lists it.
        for task_id, task_inst in task_insts.items():
            for dep in tasks[task_id].dependencies or ():
                if dep not in task_insts:
                    raise KeyError(f"Task {task_id!r} in dag {dag.dag_id!r} depends on unknown task {dep!r}")
                task_insts[dep] >> task_inst
# Module-level alias so callers can write ``load_config(...)`` instead of ``Configuration.load(...)``.
load_config = Configuration.load