Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions airflow-core/src/airflow/models/deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,28 +216,50 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) ->
def handle_miss(self, session: Session):
"""Handle a missed deadline by queueing the callback."""

def get_simple_context():
def _build_deadline_context():
from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun

# TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context
# from the scheduler

# Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields
# are not in the current session.
# Fetch the DagRun from the database again to avoid errors when self.dagrun's
# relationship fields are not in the current session.
dagrun = session.get(DagRun, self.dagrun_id)
logical_date = dagrun.logical_date

return {
context: dict[str, Any] = {
# Full DAGRunResponse as a JSON-serializable dict
"dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
"deadline": {"id": self.id, "deadline_time": self.deadline_time},
# Top-level convenience keys for Jinja templates (match standard context naming)
"dag_id": dagrun.dag_id,
"run_id": dagrun.run_id,
"logical_date": logical_date,
"data_interval_start": dagrun.data_interval_start,
"data_interval_end": dagrun.data_interval_end,
"run_type": dagrun.run_type,
"conf": dagrun.conf or {},
# Deadline-specific information
"deadline": {
"id": self.id,
"deadline_time": self.deadline_time,
"alert_name": self.deadline_alert.name if self.deadline_alert else None,
},
Comment on lines +239 to +244
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alert_name is derived via self.deadline_alert.name, but Deadline.deadline_alert is not eager-loaded in the scheduler’s deadline query (it currently selectinloads only callback and dagrun). Since handle_miss() is called in a loop, this will trigger a per-deadline lazy-load query (N+1). Either eager-load Deadline.deadline_alert in the scheduler query, or avoid relationship access here by fetching the name in bulk/alongside the deadline rows.

Copilot uses AI. Check for mistakes.
}

# Derived date/time template variables
if logical_date is not None:
context["ds"] = logical_date.strftime("%Y-%m-%d")
context["ds_nodash"] = logical_date.strftime("%Y%m%d")
context["ts"] = logical_date.isoformat()
context["ts_nodash"] = logical_date.strftime("%Y%m%dT%H%M%S")
context["ts_nodash_with_tz"] = logical_date.isoformat().replace("-", "").replace(":", "")

return context

if isinstance(self.callback, TriggererCallback):
# Update the callback with context before queuing
if "kwargs" not in self.callback.data:
self.callback.data["kwargs"] = {}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
"context": get_simple_context()
"context": _build_deadline_context()
}

self.callback.queue()
Expand All @@ -248,7 +270,7 @@ def get_simple_context():
if "kwargs" not in self.callback.data:
self.callback.data["kwargs"] = {}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
"context": get_simple_context()
"context": _build_deadline_context()
}
self.callback.data["deadline_id"] = str(self.id)
self.callback.data["dag_run_id"] = str(self.dagrun.id)
Expand Down
42 changes: 40 additions & 2 deletions airflow-core/src/airflow/triggers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

from __future__ import annotations

import inspect
import logging
import traceback
from collections.abc import AsyncIterator
from typing import Any
from typing import Any, cast

from airflow._shared.module_loading import import_string, qualname
from airflow.models.callback import CallbackState, _accepts_context
Expand All @@ -32,6 +33,38 @@
PAYLOAD_BODY_KEY = "body"


def _is_notifier_class(callback: Any) -> bool:
"""
Check if the callback is a BaseNotifier subclass (not an instance).

Uses duck-typing (checks for ``async_notify`` and ``template_fields``)
to avoid importing ``airflow.sdk`` in core.
"""
return (
inspect.isclass(callback)
and hasattr(callback, "async_notify")
and hasattr(callback, "template_fields")
and hasattr(callback, "__await__")
)


def _render_callback_kwargs(kwargs: dict[str, Any], context: dict) -> dict[str, Any]:
"""
Render Jinja2 templates in callback kwargs using the provided context.

Uses ``Templater.render_template`` to recursively render all string values
in the kwargs dict. Non-string values (int, float, datetime, …) pass
through unchanged.
"""
# Use CallbackTrigger (which inherits Templater via BaseTrigger) to access
# render_template without importing airflow.sdk directly in core.
from jinja2.sandbox import SandboxedEnvironment

trigger = CallbackTrigger(callback_path="", callback_kwargs={})
jinja_env = SandboxedEnvironment(cache_size=0)
return trigger.render_template(kwargs, cast("Any", context), jinja_env)
Comment on lines +51 to +65
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_render_callback_kwargs() builds a raw jinja2.sandbox.SandboxedEnvironment, which bypasses Airflow’s templating environment (custom sandbox behavior, filters like ds/ts, extensions, etc.). This can make template rendering for plain function callbacks behave differently than Notifier rendering (which uses Templater.get_template_env()). Consider using CallbackTrigger.get_template_env() (or importing Airflow’s SDK SandboxedEnvironment from airflow.sdk.definitions._internal.templater) instead of the raw Jinja environment, and drop the string-based cast("Any", ...) in favor of a proper type (or no cast).

Copilot uses AI. Check for mistakes.


class CallbackTrigger(BaseTrigger):
"""Trigger that executes a callback function asynchronously."""

Expand All @@ -52,9 +85,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
try:
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
callback = import_string(self.callback_path)
# TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs`
context = self.callback_kwargs.pop("context", None)

# Render Jinja templates in kwargs for plain function callbacks.
# Notifiers handle their own template rendering in __await__ via
# render_template_fields(), so we skip rendering here for them.
if context is not None and not _is_notifier_class(callback):
self.callback_kwargs = _render_callback_kwargs(self.callback_kwargs, context)

if _accepts_context(callback) and context is not None:
result = await callback(**self.callback_kwargs, context=context)
else:
Expand Down
22 changes: 21 additions & 1 deletion airflow-core/tests/unit/models/test_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,29 @@ def test_handle_miss(self, dagrun, session):
context = callback_kwargs.pop("context")
assert callback_kwargs == TEST_CALLBACK_KWARGS

# Verify enriched context — dag_run and deadline info
assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
assert context["deadline"]["id"] == deadline_orm.id
assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp()
assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
assert context["deadline"]["alert_name"] is None # no deadline_alert in this test
Comment on lines +237 to +241
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test validates the enriched context shape, but it mocks deadline_orm.callback.queue(), so it doesn’t exercise the new context through TriggererCallback.queue()Trigger.from_object()airflow.sdk.serde.serialize(). Since the enriched context now includes additional types (UUIDs/datetimes/nested dicts), a regression test should ensure the callback can be queued successfully and the trigger kwargs can be serialized/deserialized without error.

Copilot generated this review using guidance from repository custom instructions.

# Verify top-level convenience keys
assert context["dag_id"] == dagrun.dag_id
assert context["run_id"] == dagrun.run_id
assert context["logical_date"] == dagrun.logical_date
assert context["data_interval_start"] == dagrun.data_interval_start
assert context["data_interval_end"] == dagrun.data_interval_end
assert context["run_type"] == dagrun.run_type
assert context["conf"] == (dagrun.conf or {})

# Verify derived template variables
assert context["ds"] == dagrun.logical_date.strftime("%Y-%m-%d")
assert context["ds_nodash"] == dagrun.logical_date.strftime("%Y%m%d")
assert context["ts"] == dagrun.logical_date.isoformat()
assert context["ts_nodash"] == dagrun.logical_date.strftime("%Y%m%dT%H%M%S")
assert context["ts_nodash_with_tz"] == dagrun.logical_date.isoformat().replace("-", "").replace(
":", ""
)


@pytest.mark.db_test
Expand Down
176 changes: 171 additions & 5 deletions airflow-core/tests/unit/triggers/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,32 @@

from airflow.models.callback import CallbackState
from airflow.sdk import BaseNotifier
from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, CallbackTrigger
from airflow.triggers.callback import (
PAYLOAD_BODY_KEY,
PAYLOAD_STATUS_KEY,
CallbackTrigger,
_is_notifier_class,
_render_callback_kwargs,
)

TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback"
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}}
TEST_CONTEXT = {
"dag_run": {"dag_id": "test_dag"},
"dag_id": "test_dag",
"run_id": "test_run",
"ds": "2024-01-01",
"ts": "2024-01-01T00:00:00+00:00",
"deadline": {"id": "abc-123", "deadline_time": "2024-01-01T01:00:00+00:00", "alert_name": "my_alert"},
}
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": TEST_CONTEXT}


class ExampleAsyncNotifier(BaseNotifier):
"""Example of a properly implemented async notifier."""

template_fields = ("message",)

def __init__(self, message, **kwargs):
super().__init__(**kwargs)
self.message = message
Expand Down Expand Up @@ -93,7 +109,7 @@ async def test_run_success_with_async_function(self, trigger, mock_import_string
success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT)
assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value

Expand All @@ -112,7 +128,7 @@ async def test_run_success_with_notifier(self, trigger, mock_import_string):
assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
assert (
success_event.payload[PAYLOAD_BODY_KEY]
== f"Async notification: {TEST_MESSAGE}, context: {{'dag_run': 'test'}}"
== f"Async notification: {TEST_MESSAGE}, context: {TEST_CONTEXT}"
)

@pytest.mark.asyncio
Expand All @@ -129,6 +145,156 @@ async def test_run_failure(self, trigger, mock_import_string):
failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT)
assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])


class TestTemplateRendering:
"""Tests for Jinja2 template rendering in callback kwargs."""

@pytest.mark.asyncio
async def test_run_renders_jinja_templates_in_function_kwargs(self):
"""Plain async function callbacks get their kwargs rendered."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "DAG {{ dag_id }} missed deadline at {{ ds }}",
"context": context,
},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(
message="DAG my_dag missed deadline at 2024-06-15",
context=context,
)

@pytest.mark.asyncio
async def test_run_does_not_double_render_notifier_kwargs(self):
"""Notifier classes should NOT have kwargs pre-rendered — they handle it themselves."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "DAG {{ dag_id }}",
"context": context,
},
)
with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
# The notifier's __await__ renders template_fields, so the final output
# should show the rendered message (rendered by the notifier, not pre-rendered).
assert "DAG my_dag" in events[-1].payload[PAYLOAD_BODY_KEY]

@pytest.mark.asyncio
async def test_run_renders_nested_kwargs(self):
"""Template rendering works recursively on nested dicts and lists."""
context = {"dag_id": "etl_pipeline"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"recipients": ["{{ dag_id }}-team@example.com"],
"metadata": {"dag": "{{ dag_id }}"},
"context": context,
},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(
recipients=["etl_pipeline-team@example.com"],
metadata={"dag": "etl_pipeline"},
context=context,
)

@pytest.mark.asyncio
async def test_run_skips_rendering_when_no_context(self):
"""Without context, kwargs pass through unrendered."""
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={"message": "{{ dag_id }}"},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(message="{{ dag_id }}")

@pytest.mark.asyncio
async def test_notifier_template_fields_rendered_with_context(self):
"""Notifier template_fields are rendered using the provided context."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "Alert for {{ dag_id }} on {{ ds }}",
"context": context,
},
)
with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
# The notifier's __await__ renders template_fields (self.message), so the
# notification body contains the rendered message.
assert "Alert for my_dag on 2024-06-15" in events[-1].payload[PAYLOAD_BODY_KEY]


class TestHelpers:
"""Tests for module-level helper functions."""

def test_is_notifier_class_with_notifier(self):
assert _is_notifier_class(ExampleAsyncNotifier) is True

def test_is_notifier_class_with_function(self):
async def my_func():
pass

assert _is_notifier_class(my_func) is False

def test_is_notifier_class_with_non_notifier_class(self):
class MyClass:
pass

assert _is_notifier_class(MyClass) is False

def test_is_notifier_class_with_notifier_instance(self):
"""Instances are not classes — should return False."""
instance = ExampleAsyncNotifier(message="hi")
assert _is_notifier_class(instance) is False

def test_render_callback_kwargs_renders_strings(self):
result = _render_callback_kwargs(
{"message": "Hello {{ name }}", "count": 5},
{"name": "World"},
)
assert result == {"message": "Hello World", "count": 5}

def test_render_callback_kwargs_handles_nested_structures(self):
result = _render_callback_kwargs(
{"items": ["{{ x }}", "{{ y }}"], "meta": {"key": "{{ x }}"}},
{"x": "a", "y": "b"},
)
assert result == {"items": ["a", "b"], "meta": {"key": "a"}}

def test_render_callback_kwargs_missing_key_renders_empty(self):
result = _render_callback_kwargs(
{"message": "Hello {{ nonexistent }}"},
{"name": "World"},
)
assert result == {"message": "Hello "}

def test_render_callback_kwargs_no_templates_is_noop(self):
kwargs = {"message": "plain text", "count": 42}
result = _render_callback_kwargs(kwargs, {"dag_id": "test"})
assert result == kwargs
Loading