
Commit 8a5e277

Fixes #27148: Eliminate N+1 is_paused queries in AirflowSource.get_pipelines_list
1 parent c379214 commit 8a5e277

2 files changed: 84 additions & 51 deletions
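In short: the old implementation listed serialized DAGs and then issued one extra DagModel.is_paused lookup per DAG; the fix selects is_paused in the listing query itself, via a LEFT OUTER JOIN on the Airflow 2.x path (the Airflow 3.x path already joins DagModel for fileloc). Below is a minimal, self-contained sketch of the before/after pattern, using toy tables and an in-memory SQLite database rather than the real Airflow models.

    # Toy reproduction of the N+1 pattern and the joined fix, assuming
    # SQLAlchemy >= 1.4. "SerializedDag" and "Dag" are illustrative
    # stand-ins, not the actual Airflow models.
    from sqlalchemy import Boolean, Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class SerializedDag(Base):
        __tablename__ = "serialized_dag"
        dag_id = Column(String, primary_key=True)
        fileloc = Column(String)

    class Dag(Base):
        __tablename__ = "dag"
        dag_id = Column(String, primary_key=True)
        is_paused = Column(Boolean, default=False)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([
            SerializedDag(dag_id="a", fileloc="/dags/a.py"),
            SerializedDag(dag_id="b", fileloc="/dags/b.py"),
            Dag(dag_id="a", is_paused=False),
            Dag(dag_id="b", is_paused=True),
        ])
        session.commit()

        # Before: one listing query, then one is_paused query per row (N+1).
        for dag_id, _fileloc in session.query(SerializedDag.dag_id, SerializedDag.fileloc):
            is_paused = session.query(Dag.is_paused).filter(Dag.dag_id == dag_id).scalar()
            print("n+1:", dag_id, is_paused)

        # After: a single round-trip; is_paused comes back as None for DAGs
        # with no Dag row (outer-join miss), which the ingestion code treats
        # as Active.
        rows = (
            session.query(SerializedDag.dag_id, SerializedDag.fileloc, Dag.is_paused)
            .outerjoin(Dag, SerializedDag.dag_id == Dag.dag_id)
            .all()
        )
        for dag_id, _fileloc, is_paused in rows:
            print("joined:", dag_id, is_paused)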

ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py

Lines changed: 39 additions & 51 deletions
@@ -24,7 +24,7 @@
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.serialization.serialized_objects import SerializedDAG
 from pydantic import BaseModel, ValidationError
-from sqlalchemy import and_, column, func, inspect, join
+from sqlalchemy import and_, column, func, inspect
 from sqlalchemy.orm import Session

 from metadata.generated.schema.api.data.createPipeline import CreatePipelineRequest
@@ -478,26 +478,37 @@ def get_pipelines_list(self) -> Iterable[AirflowDagDetails]:
         # In Airflow 3.x, fileloc is not available on SerializedDagModel
         # We need to get it from DagModel instead
         if hasattr(SerializedDagModel, "fileloc"):
-            # Airflow 2.x: fileloc is on SerializedDagModel
-            # Use tuple IN clause to get only the latest version of each DAG
-            session_query = self.session.query(
-                SerializedDagModel.dag_id,
-                json_data_column,
-                SerializedDagModel.fileloc,
-            ).join(
-                latest_dag_subquery,
-                and_(
-                    SerializedDagModel.dag_id == latest_dag_subquery.c.dag_id,
-                    timestamp_column == latest_dag_subquery.c.max_timestamp,
-                ),
+            # Airflow 2.x: fileloc is on SerializedDagModel.
+            # Always LEFT OUTER JOIN DagModel so we can select is_paused in the
+            # main query and avoid an extra DB round-trip per DAG (N+1).
+            session_query = (
+                self.session.query(
+                    SerializedDagModel.dag_id,
+                    json_data_column,
+                    SerializedDagModel.fileloc,
+                    DagModel.is_paused,
+                )
+                .join(
+                    latest_dag_subquery,
+                    and_(
+                        SerializedDagModel.dag_id == latest_dag_subquery.c.dag_id,
+                        timestamp_column == latest_dag_subquery.c.max_timestamp,
+                    ),
+                )
+                .outerjoin(
+                    DagModel,
+                    SerializedDagModel.dag_id == DagModel.dag_id,
+                )
             )
         else:
-            # Airflow 3.x: fileloc is only on DagModel, we need to join
+            # Airflow 3.x: fileloc is only on DagModel, already joined.
+            # Add is_paused to the column list; no extra join needed.
             session_query = (
                 self.session.query(
                     SerializedDagModel.dag_id,
                     json_data_column,
                     DagModel.fileloc,
+                    DagModel.is_paused,
                 )
                 .join(
                     latest_dag_subquery,
@@ -513,19 +524,9 @@ def get_pipelines_list(self) -> Iterable[AirflowDagDetails]:
             )

         if not self.source_config.includeUnDeployedPipelines:
-            # If we haven't already joined with DagModel (Airflow 2.x case)
-            if hasattr(SerializedDagModel, "fileloc"):
-                session_query = session_query.select_from(
-                    join(
-                        SerializedDagModel,
-                        DagModel,
-                        SerializedDagModel.dag_id == DagModel.dag_id,
-                    )
-                )
-            # Add the is_paused filter
-            session_query = session_query.filter(
-                DagModel.is_paused == False  # pylint: disable=singleton-comparison
-            )
+            # DagModel is already joined in both paths above, so we can filter
+            # directly without an extra select_from().
+            session_query = session_query.filter(DagModel.is_paused.is_(False))
         limit = 100  # Number of records per batch
         offset = 0  # Start

@@ -540,32 +541,19 @@ def get_pipelines_list(self) -> Iterable[AirflowDagDetails]:
                 break
             for serialized_dag in results:
                 try:
-                    # Query only the is_paused column from DagModel
-                    try:
-                        is_paused_result = (
-                            self.session.query(DagModel.is_paused)
-                            .filter(DagModel.dag_id == serialized_dag[0])
-                            .scalar()
-                        )
-                        pipeline_state = (
-                            PipelineState.Active.value
-                            if not is_paused_result
-                            else PipelineState.Inactive.value
-                        )
-                    except Exception as exc:
-                        logger.debug(traceback.format_exc())
-                        logger.warning(
-                            f"Could not query DagModel.is_paused for {serialized_dag[0]}. "
-                            f"Using default pipeline state - {exc}"
-                        )
-                        # If we can't query is_paused, assume the pipeline is active
-                        pipeline_state = PipelineState.Active.value
+                    # Unpack by name so future column list changes are explicit.
+                    dag_id, payload, fileloc, is_paused = serialized_dag
+                    pipeline_state = (
+                        PipelineState.Inactive.value
+                        if is_paused
+                        else PipelineState.Active.value
+                    )

-                    data = serialized_dag[1]["dag"]
+                    data = payload["dag"]
                     dag = AirflowDagDetails(
-                        dag_id=serialized_dag[0],
-                        fileloc=serialized_dag[2],
-                        data=AirflowDag.model_validate(serialized_dag[1]),
+                        dag_id=dag_id,
+                        fileloc=fileloc,
+                        data=AirflowDag.model_validate(payload),
                         max_active_runs=data.get("max_active_runs", None),
                         description=data.get("_description", None),
                         start_date=data.get("start_date", None),
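
One behavioral detail worth spelling out (my reading of the diff, not text from the commit): with the LEFT OUTER JOIN in place, DAGs that have no DagModel row come back with is_paused as NULL. Both the old == False comparison and the new .is_(False) filter are not-true for NULL in SQL, so with includeUnDeployedPipelines disabled the query still drops undeployed DAGs, matching what the old inner select_from(join(...)) produced. Continuing the toy sketch above:

    # Rows whose Dag entry is missing (is_paused NULL) are excluded by the
    # filter, mirroring the old inner join; only unpaused, deployed DAGs remain.
    active_only = (
        session.query(SerializedDag.dag_id, Dag.is_paused)
        .outerjoin(Dag, SerializedDag.dag_id == Dag.dag_id)
        .filter(Dag.is_paused.is_(False))
        .all()
    )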

ingestion/tests/unit/topology/pipeline/test_airflow.py

Lines changed: 45 additions & 0 deletions
@@ -398,6 +398,51 @@ def test_get_schedule_interval_with_custom_timetable(self):
         self.assertIn("Custom Timetable", result)
         self.assertIn("CustomTimetable", result)

+    def test_get_pipelines_list_derives_state_from_row(self):
+        """
+        Verify that get_pipelines_list derives pipeline_state from the is_paused
+        column selected in the main query, without issuing a separate per-DAG lookup.
+
+        Rows: (dag_id, payload, fileloc, is_paused)
+        - False -> Active
+        - True -> Inactive
+        - None -> Active (LEFT OUTER JOIN miss for undeployed DAGs)
+        """
+        from unittest.mock import MagicMock
+
+        from metadata.generated.schema.entity.data.pipeline import PipelineState
+
+        active_row = ("dag_active", SERIALIZED_DAG, "/dags/active.py", False)
+        inactive_row = ("dag_inactive", SERIALIZED_DAG, "/dags/inactive.py", True)
+        null_row = ("dag_null", SERIALIZED_DAG, "/dags/null.py", None)
+
+        # Build a mock that chains through any SQLAlchemy query method and returns
+        # our fake rows on the first .all() call, then [] to stop pagination.
+        mock_q = MagicMock()
+        for method in ("join", "outerjoin", "filter", "order_by", "limit", "offset", "group_by"):
+            getattr(mock_q, method).return_value = mock_q
+        mock_q.subquery.return_value = MagicMock()
+        mock_q.all.side_effect = [
+            [active_row, inactive_row, null_row],
+            [],
+        ]
+
+        mock_session = MagicMock()
+        mock_session.query.return_value = mock_q
+
+        original_session = getattr(self.airflow, "_session", None)
+        self.airflow._session = mock_session
+        try:
+            dags = list(self.airflow.get_pipelines_list())
+        finally:
+            self.airflow._session = original_session
+
+        self.assertEqual(3, len(dags))
+        by_id = {d.dag_id: d for d in dags}
+        self.assertEqual(PipelineState.Active.value, by_id["dag_active"].state)
+        self.assertEqual(PipelineState.Inactive.value, by_id["dag_inactive"].state)
+        self.assertEqual(PipelineState.Active.value, by_id["dag_null"].state)
+
     def test_get_schedule_interval_with_import_error(self):
         """
         Test handling of timetable classes that can't be imported
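
To run just the new test locally, something like the following should work (assuming pytest is the project's test runner):

    pytest ingestion/tests/unit/topology/pipeline/test_airflow.py -k derives_state_from_row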
