-
Notifications
You must be signed in to change notification settings - Fork 17k
Add Jinja template rendering and richer context for async deadline callbacks #64984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
72a347c
131ccca
04d524e
9630ad9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,11 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import inspect | ||
| import logging | ||
| import traceback | ||
| from collections.abc import AsyncIterator | ||
| from typing import Any | ||
| from typing import Any, cast | ||
|
|
||
| from airflow._shared.module_loading import import_string, qualname | ||
| from airflow.models.callback import CallbackState, _accepts_context | ||
|
|
@@ -32,6 +33,38 @@ | |
| PAYLOAD_BODY_KEY = "body" | ||
|
|
||
|
|
||
| def _is_notifier_class(callback: Any) -> bool: | ||
| """ | ||
| Check if the callback is a BaseNotifier subclass (not an instance). | ||
|
|
||
| Uses duck-typing (checks for ``async_notify`` and ``template_fields``) | ||
| to avoid importing ``airflow.sdk`` in core. | ||
| """ | ||
| return ( | ||
| inspect.isclass(callback) | ||
| and hasattr(callback, "async_notify") | ||
| and hasattr(callback, "template_fields") | ||
| and hasattr(callback, "__await__") | ||
| ) | ||
|
|
||
|
|
||
| def _render_callback_kwargs(kwargs: dict[str, Any], context: dict) -> dict[str, Any]: | ||
| """ | ||
| Render Jinja2 templates in callback kwargs using the provided context. | ||
|
|
||
| Uses ``Templater.render_template`` to recursively render all string values | ||
| in the kwargs dict. Non-string values (int, float, datetime, …) pass | ||
| through unchanged. | ||
| """ | ||
| # Use CallbackTrigger (which inherits Templater via BaseTrigger) to access | ||
| # render_template without importing airflow.sdk directly in core. | ||
| from jinja2.sandbox import SandboxedEnvironment | ||
|
|
||
| trigger = CallbackTrigger(callback_path="", callback_kwargs={}) | ||
| jinja_env = SandboxedEnvironment(cache_size=0) | ||
| return trigger.render_template(kwargs, cast("Any", context), jinja_env) | ||
|
Comment on lines
+51
to
+65
|
||
|
|
||
|
|
||
| class CallbackTrigger(BaseTrigger): | ||
| """Trigger that executes a callback function asynchronously.""" | ||
|
|
||
|
|
@@ -52,9 +85,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: | |
| try: | ||
| yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) | ||
| callback = import_string(self.callback_path) | ||
| # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` | ||
| context = self.callback_kwargs.pop("context", None) | ||
|
|
||
| # Render Jinja templates in kwargs for plain function callbacks. | ||
| # Notifiers handle their own template rendering in __await__ via | ||
| # render_template_fields(), so we skip rendering here for them. | ||
| if context is not None and not _is_notifier_class(callback): | ||
| self.callback_kwargs = _render_callback_kwargs(self.callback_kwargs, context) | ||
|
|
||
| if _accepts_context(callback) and context is not None: | ||
| result = await callback(**self.callback_kwargs, context=context) | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -234,9 +234,29 @@ def test_handle_miss(self, dagrun, session): | |
| context = callback_kwargs.pop("context") | ||
| assert callback_kwargs == TEST_CALLBACK_KWARGS | ||
|
|
||
| # Verify enriched context — dag_run and deadline info | ||
| assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json") | ||
| assert context["deadline"]["id"] == deadline_orm.id | ||
| assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp() | ||
| assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json") | ||
| assert context["deadline"]["alert_name"] is None # no deadline_alert in this test | ||
|
Comment on lines
+237
to
+241
|
||
|
|
||
| # Verify top-level convenience keys | ||
| assert context["dag_id"] == dagrun.dag_id | ||
| assert context["run_id"] == dagrun.run_id | ||
| assert context["logical_date"] == dagrun.logical_date | ||
| assert context["data_interval_start"] == dagrun.data_interval_start | ||
| assert context["data_interval_end"] == dagrun.data_interval_end | ||
| assert context["run_type"] == dagrun.run_type | ||
| assert context["conf"] == (dagrun.conf or {}) | ||
|
|
||
| # Verify derived template variables | ||
| assert context["ds"] == dagrun.logical_date.strftime("%Y-%m-%d") | ||
| assert context["ds_nodash"] == dagrun.logical_date.strftime("%Y%m%d") | ||
| assert context["ts"] == dagrun.logical_date.isoformat() | ||
| assert context["ts_nodash"] == dagrun.logical_date.strftime("%Y%m%dT%H%M%S") | ||
| assert context["ts_nodash_with_tz"] == dagrun.logical_date.isoformat().replace("-", "").replace( | ||
| ":", "" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.db_test | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alert_nameis derived viaself.deadline_alert.name, butDeadline.deadline_alertis not eager-loaded in the scheduler’s deadline query (it currently selectinloads onlycallbackanddagrun). Sincehandle_miss()is called in a loop, this will trigger a per-deadline lazy-load query (N+1). Either eager-loadDeadline.deadline_alertin the scheduler query, or avoid relationship access here by fetching the name in bulk/alongside the deadline rows.