
Commit ea24232

Reduce per-DAG queries during DAG serialization with bulk prefetch (#64929) (#65208)

* Reduce per-DAG queries during DAG serialization with bulk prefetch

  Replaces 3 SELECTs per DAG in write_dag (update-interval check, hash
  comparison, version fetch) with 2 bulk queries via a new
  _prefetch_dag_write_metadata classmethod. Also fixes
  DagCode.update_source_code to reuse the caller's session and eagerly
  loads dag_owner_links to prevent N+1 queries.

* fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch

* fixup! fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch

(cherry picked from commit ef00040)
1 parent 91cca11 · commit ea24232

4 files changed: 146 additions & 20 deletions
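
The change is an instance of a general pattern: replace per-row SELECTs issued inside a loop with one IN-clause query keyed by `dag_id`, then do plain dict lookups in the loop. A minimal sketch of the before/after against a toy table (model and column names are illustrative, not Airflow's actual schema):

```python
from sqlalchemy import Column, DateTime, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class SerializedDagDemo(Base):
    __tablename__ = "serialized_dag_demo"
    dag_id = Column(String, primary_key=True)
    dag_hash = Column(String)
    last_updated = Column(DateTime)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

dag_ids = ["dag_a", "dag_b", "dag_c"]

with Session(engine) as session:
    # Before: one SELECT per DAG inside the serialization loop (N round trips).
    per_dag = {
        dag_id: session.execute(
            select(SerializedDagDemo.last_updated, SerializedDagDemo.dag_hash).where(
                SerializedDagDemo.dag_id == dag_id
            )
        ).first()
        for dag_id in dag_ids
    }

    # After: one SELECT with an IN clause up front, then plain dict lookups.
    rows = session.execute(
        select(
            SerializedDagDemo.dag_id,
            SerializedDagDemo.last_updated,
            SerializedDagDemo.dag_hash,
        ).where(SerializedDagDemo.dag_id.in_(dag_ids))
    ).all()
    prefetched = {row.dag_id: (row.last_updated, row.dag_hash) for row in rows}
```

Database round trips drop from O(N) to O(1) no matter how many DAGs the bundle contains.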


airflow-core/src/airflow/dag_processing/collection.py

Lines changed: 18 additions & 3 deletions
```diff
@@ -53,6 +53,7 @@
 from airflow.models.dagrun import DagRun
 from airflow.models.dagwarning import DagWarningType
 from airflow.models.errors import ParseImportError
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.trigger import Trigger
 from airflow.serialization.definitions.assets import (
     SerializedAsset,
@@ -75,6 +76,7 @@
     from sqlalchemy.sql import Select

     from airflow.models.dagwarning import DagWarning
+    from airflow.models.serialized_dag import DagWriteMetadata
     from airflow.typing_compat import Self, Unpack

 AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
```
```diff
@@ -256,15 +258,18 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se


 def _serialize_dag_capturing_errors(
-    dag: LazyDeserializedDAG, bundle_name, session: Session, bundle_version: str | None
+    dag: LazyDeserializedDAG,
+    bundle_name,
+    session: Session,
+    bundle_version: str | None,
+    _prefetched: DagWriteMetadata | None = None,
 ):
     """
     Try to serialize the dag to the DB, but make a note of any errors.

     We can't place them directly in import_errors, as this may be retried, and work the next time
     """
     from airflow.models.dagcode import DagCode
-    from airflow.models.serialized_dag import SerializedDagModel

     # Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
     MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
@@ -279,10 +284,11 @@ def _serialize_dag_capturing_errors(
         bundle_version=bundle_version,
         min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
         session=session,
+        _prefetched=_prefetched,
     )
     if not dag_was_updated:
         # Check and update DagCode
-        DagCode.update_source_code(dag.dag_id, dag.fileloc)
+        DagCode.update_source_code(dag.dag_id, dag.fileloc, session=session)
     if "FabAuthManager" in conf.get("core", "auth_manager"):
         _sync_dag_perms(dag, session=session)

```
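The `DagCode.update_source_code(..., session=session)` fix is about Airflow's `@provide_session` convention: functions declare a keyword-only `session` defaulting to `NEW_SESSION`, and the decorator opens a fresh session only when the caller omits one. Passing the caller's session keeps the write in the same transaction instead of opening a new session per DAG. A rough sketch of the convention (simplified; not Airflow's actual implementation):

```python
import functools
from contextlib import contextmanager

NEW_SESSION = object()  # stand-in sentinel; Airflow uses its own typed sentinel

@contextmanager
def _create_session():
    # Stand-in for a real sessionmaker; a real version would commit or roll back.
    yield object()

def provide_session(func):
    """Inject a fresh session only when the caller did not pass one."""
    @functools.wraps(func)
    def wrapper(*args, session=NEW_SESSION, **kwargs):
        if session is not NEW_SESSION:
            # Reuse the caller's session: same transaction, no extra session per call.
            return func(*args, session=session, **kwargs)
        with _create_session() as fresh:
            return func(*args, session=fresh, **kwargs)
    return wrapper

@provide_session
def update_source_code(dag_id, fileloc, *, session=NEW_SESSION):
    print(f"updating {dag_id} in session {id(session)}")

update_source_code("demo_dag", "/dags/demo.py")                     # opens its own session
update_source_code("demo_dag", "/dags/demo.py", session=object())  # reuses the caller's
```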
```diff
@@ -473,6 +479,13 @@ update_dag_parsing_results_in_db(
         SerializedDAG.bulk_write_to_db(
             bundle_name, bundle_version, dags, parse_duration, session=session
         )
+        # Bulk prefetch metadata for all DAGs to avoid the standard per-DAG
+        # metadata lookups in write_dag. This replaces the update-interval,
+        # hash, and version queries with 2 bulk queries total; DAGs with
+        # deadlines may still do an additional lookup for deadline UUID reuse.
+        prefetched_metadata = SerializedDagModel._prefetch_dag_write_metadata(
+            [dag.dag_id for dag in dags], session=session
+        )
         # Write Serialized DAGs to DB, capturing errors
         for dag in dags:
             serialize_errors.extend(
@@ -481,6 +494,7 @@
                 bundle_name=bundle_name,
                 bundle_version=bundle_version,
                 session=session,
+                _prefetched=prefetched_metadata.get(dag.dag_id),
             )
         )
     except OperationalError:
```
```diff
@@ -526,6 +540,7 @@ def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
                 .options(joinedload(DagModel.schedule_asset_references))
                 .options(joinedload(DagModel.schedule_asset_alias_references))
                 .options(joinedload(DagModel.task_outlet_asset_references))
+                .options(joinedload(DagModel.dag_owner_links))
             ),
             of=DagModel,
             session=session,
```
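
The added `joinedload(DagModel.dag_owner_links)` targets a classic ORM N+1: with the default lazy loading, every access to `dm.dag_owner_links` in `_update_dag_owner_links` triggers its own SELECT. Eager loading folds the collection into the original query via a LEFT OUTER JOIN. A self-contained sketch with a toy parent/child pair (names are illustrative):

```python
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base, joinedload, relationship

Base = declarative_base()

class DagDemo(Base):
    __tablename__ = "dag_demo"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    owner_links = relationship("OwnerLinkDemo", back_populates="dag")

class OwnerLinkDemo(Base):
    __tablename__ = "owner_link_demo"
    id = Column(Integer, primary_key=True)
    dag_fk = Column(Integer, ForeignKey("dag_demo.id"))
    dag = relationship("DagDemo", back_populates="owner_links")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DagDemo(dag_id="a", owner_links=[OwnerLinkDemo(), OwnerLinkDemo()]))
    session.flush()

    # Lazy (default): one SELECT for the DAGs, plus one more per DAG the
    # first time .owner_links is touched -- the classic N+1.
    for dag in session.scalars(select(DagDemo)).all():
        _ = dag.owner_links

    # Eager: the owner links arrive via LEFT OUTER JOIN in the same query,
    # so the loop below issues no additional SQL.
    eager = session.scalars(
        select(DagDemo).options(joinedload(DagDemo.owner_links))
    ).unique().all()
    for dag in eager:
        _ = dag.owner_links  # already loaded
```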

airflow-core/src/airflow/models/serialized_dag.py

Lines changed: 85 additions & 17 deletions
```diff
@@ -23,7 +23,7 @@
 import zlib
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple
 from uuid import UUID

 import uuid6
@@ -70,6 +70,14 @@
 _COMPRESS_SERIALIZED_DAGS = conf.getboolean("core", "compress_serialized_dags", fallback=False)


+class DagWriteMetadata(NamedTuple):
+    """Pre-fetched metadata for write_dag to avoid per-DAG queries."""
+
+    last_updated: datetime | None
+    dag_hash: str | None
+    dag_version: DagVersion | None
+
+
 class _DagDependenciesResolver:
     """Resolver that resolves dag dependencies to include asset id and assets link to asset aliases."""

@@ -508,6 +516,70 @@ def _create_deadline_alert_records(
             )
             serialized_dag.deadline_alerts.append(alert)

+    @classmethod
+    def _prefetch_dag_write_metadata(
+        cls, dag_ids: Iterable[str], *, session: Session
+    ) -> dict[str, DagWriteMetadata]:
+        """
+        Bulk-fetch metadata needed by write_dag for multiple DAGs in two queries.
+
+        Instead of running 3 SELECTs per DAG in write_dag (update interval check,
+        hash comparison, version fetch), this fetches all needed data upfront.
+
+        :param dag_ids: DAG IDs to prefetch metadata for
+        :param session: ORM Session
+        :returns: dict mapping dag_id to DagWriteMetadata
+        """
+        dag_id_list = list(set(dag_ids))
+        if not dag_id_list:
+            return {}
+
+        # Fetch latest serialized_dag (last_updated, dag_hash) per dag_id
+        # using a window function to pick the most recent row.
+        sd_subq = (
+            select(
+                cls.dag_id.label("dag_id"),
+                cls.last_updated.label("last_updated"),
+                cls.dag_hash.label("dag_hash"),
+                func.row_number().over(partition_by=cls.dag_id, order_by=cls.created_at.desc()).label("rn"),
+            )
+            .where(cls.dag_id.in_(dag_id_list))
+            .subquery()
+        )
+        sd_rows = session.execute(
+            select(sd_subq.c.dag_id, sd_subq.c.last_updated, sd_subq.c.dag_hash).where(sd_subq.c.rn == 1)
+        ).all()
+        sd_by_dag_id: dict[str, tuple[datetime, str]] = {
+            row.dag_id: (row.last_updated, row.dag_hash) for row in sd_rows
+        }
+
+        # Fetch latest DagVersion per dag_id using a window function,
+        # matching the original write_dag ordering (ORDER BY created_at DESC).
+        dv_subq = (
+            select(
+                DagVersion.id.label("id"),
+                DagVersion.dag_id.label("dag_id"),
+                func.row_number()
+                .over(partition_by=DagVersion.dag_id, order_by=DagVersion.created_at.desc())
+                .label("rn"),
+            )
+            .where(DagVersion.dag_id.in_(dag_id_list))
+            .subquery()
+        )
+        dag_versions = session.scalars(
+            select(DagVersion).join(dv_subq, DagVersion.id == dv_subq.c.id).where(dv_subq.c.rn == 1)
+        ).all()
+        dv_by_dag_id: dict[str, DagVersion] = {dv.dag_id: dv for dv in dag_versions}
+
+        return {
+            dag_id: DagWriteMetadata(
+                last_updated=sd_by_dag_id[dag_id][0] if dag_id in sd_by_dag_id else None,
+                dag_hash=sd_by_dag_id[dag_id][1] if dag_id in sd_by_dag_id else None,
+                dag_version=dv_by_dag_id.get(dag_id),
+            )
+            for dag_id in dag_id_list
+        }
+
     @classmethod
     @provide_session
     def write_dag(
```
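
Both bulk queries above use the standard greatest-row-per-group recipe: rank rows with `ROW_NUMBER() OVER (PARTITION BY dag_id ORDER BY created_at DESC)` in a subquery, then keep only rank 1. A self-contained sketch of the recipe on a toy table (runs on SQLite 3.25+, which supports window functions):

```python
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class VersionDemo(Base):
    __tablename__ = "version_demo"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    created_at = Column(DateTime)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            VersionDemo(dag_id="a", created_at=datetime(2025, 1, 1)),
            VersionDemo(dag_id="a", created_at=datetime(2025, 1, 2)),  # latest for "a"
            VersionDemo(dag_id="b", created_at=datetime(2025, 1, 1)),  # latest for "b"
        ]
    )
    session.flush()

    # Rank each dag_id's rows newest-first...
    subq = (
        select(
            VersionDemo.id.label("id"),
            func.row_number()
            .over(partition_by=VersionDemo.dag_id, order_by=VersionDemo.created_at.desc())
            .label("rn"),
        ).subquery()
    )
    # ...then join back and keep only rank 1: the latest row per dag_id,
    # for every dag_id, in a single round trip.
    latest = session.scalars(
        select(VersionDemo).join(subq, VersionDemo.id == subq.c.id).where(subq.c.rn == 1)
    ).all()
    assert {(v.dag_id, v.created_at.day) for v in latest} == {("a", 2), ("b", 1)}
```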
```diff
@@ -517,6 +589,7 @@ def write_dag(
         bundle_version: str | None = None,
         min_update_interval: int | None = None,
         session: Session = NEW_SESSION,
+        _prefetched: DagWriteMetadata | None = None,
     ) -> bool:
         """
         Serialize a DAG and writes it into database.
@@ -529,33 +602,28 @@ def write_dag(
         :param bundle_version: bundle version of the DAG
         :param min_update_interval: minimal interval in seconds to update serialized DAG
         :param session: ORM Session
+        :param _prefetched: Pre-fetched metadata to skip per-DAG queries; used by bulk callers

         :returns: Boolean indicating if the DAG was written to the DB
         """
+        if _prefetched is None:
+            _prefetched = cls._prefetch_dag_write_metadata([dag.dag_id], session=session).get(
+                dag.dag_id, DagWriteMetadata(last_updated=None, dag_hash=None, dag_version=None)
+            )
+
         # Checks if (Current Time - Time when the DAG was written to DB) < min_update_interval
         # If Yes, does nothing
         # If No or the DAG does not exists, updates / writes Serialized DAG to DB
         if min_update_interval is not None:
-            if session.scalar(
-                select(literal(True))
-                .where(
-                    cls.dag_id == dag.dag_id,
-                    (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
-                )
-                .select_from(cls)
+            if (
+                _prefetched.last_updated is not None
+                and (timezone.utcnow() - timedelta(seconds=min_update_interval)) < _prefetched.last_updated
             ):
                 return False

         log.debug("Checking if DAG (%s) changed", dag.dag_id)
-        serialized_dag_hash = session.scalars(
-            select(cls.dag_hash).where(cls.dag_id == dag.dag_id).order_by(cls.created_at.desc())
-        ).first()
-        dag_version = session.scalar(
-            select(DagVersion)
-            .where(DagVersion.dag_id == dag.dag_id)
-            .order_by(DagVersion.created_at.desc())
-            .limit(1)
-        )
+        serialized_dag_hash = _prefetched.dag_hash
+        dag_version = _prefetched.dag_version

         if dag.data.get("dag", {}).get("deadline"):
             # Try to reuse existing deadline UUIDs if the deadline definitions haven't changed.
```
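
With the metadata prefetched, the update-interval throttle becomes a pure in-memory comparison with the same semantics as the old SQL predicate: skip the write when the last write is newer than now minus `min_update_interval`. A sketch of just that check, using the stdlib clock as a stand-in for Airflow's `timezone.utcnow()` (the function name is hypothetical, for illustration):

```python
from datetime import datetime, timedelta, timezone

def should_skip_write(last_updated: datetime | None, min_update_interval: int | None) -> bool:
    """Mirror of the throttle in write_dag: True means 'written recently, skip'."""
    if min_update_interval is None or last_updated is None:
        return False
    return datetime.now(timezone.utc) - timedelta(seconds=min_update_interval) < last_updated

# A DAG written 10 seconds ago is skipped under a 30-second interval...
recent = datetime.now(timezone.utc) - timedelta(seconds=10)
assert should_skip_write(recent, 30) is True
# ...but not under a 5-second interval.
assert should_skip_write(recent, 5) is False
```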

airflow-core/tests/unit/dag_processing/test_collection.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -537,6 +537,7 @@ def test_sync_to_db_is_retried(
                     bundle_version=None,
                     min_update_interval=mock.ANY,
                     session=mock_session,
+                    _prefetched=mock.ANY,
                 ),
             ]
         )
```

airflow-core/tests/unit/models/test_serialized_dag.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -524,6 +524,48 @@ def test_min_update_interval_is_respected(self, provide_interval, new_task, shou
         )
         assert did_write is should_write

+    def test_prefetch_dag_write_metadata_multiple_dags(self, dag_maker, session):
+        """Test that _prefetch_dag_write_metadata returns correct metadata for multiple DAGs."""
+        with dag_maker("prefetch_multi_dag1"):
+            EmptyOperator(task_id="task1")
+        with dag_maker("prefetch_multi_dag2"):
+            EmptyOperator(task_id="task1")
+
+        result = SDM._prefetch_dag_write_metadata(
+            ["prefetch_multi_dag1", "prefetch_multi_dag2"], session=session
+        )
+
+        assert len(result) == 2
+        for dag_id in ("prefetch_multi_dag1", "prefetch_multi_dag2"):
+            metadata = result[dag_id]
+            assert metadata.last_updated is not None
+            assert metadata.dag_hash is not None
+            assert metadata.dag_version is not None
+            assert metadata.dag_version.dag_id == dag_id
+
+    def test_prefetch_dag_write_metadata_returns_latest_version(self, dag_maker, session):
+        """Test that _prefetch_dag_write_metadata returns the latest DagVersion."""
+        with dag_maker("prefetch_version_dag") as dag:
+            PythonOperator(task_id="task1", python_callable=lambda: None)
+        # Create a dagrun so that writing a changed DAG creates a new version
+        dag_maker.create_dagrun(run_id="run1", logical_date=pendulum.datetime(2025, 1, 1))
+
+        # Modify the DAG (add a task) and write again to create version 2
+        PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
+        SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")
+
+        assert (
+            session.scalar(
+                select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
+            )
+            == 2
+        )
+
+        result = SDM._prefetch_dag_write_metadata([dag.dag_id], session=session)
+        metadata = result[dag.dag_id]
+        assert metadata.dag_version is not None
+        assert metadata.dag_version.version_number == 2
+
     def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(self, dag_maker, session):
         """Test that new dag_version is created if bundle_name changes but DAG is unchanged."""
         # Create and write initial DAG
```