Commit e1025f2

Fixes #27150: Bulk-fetch TaskInstances per DAG to eliminate N+1 in yield_pipeline_status
1 parent c379214 commit e1025f2

2 files changed

Lines changed: 108 additions & 32 deletions
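
For context on the fix itself: the N+1 shape being removed issued one TaskInstance query per DagRun, while the bulk shape issues a single query with an IN (...) filter and groups the rows by run_id in memory. The snippet below is a minimal, self-contained sketch of that difference using the stdlib sqlite3 module; the table and data are hypothetical stand-ins for Airflow's task_instance table, not the OpenMetadata code itself.

import sqlite3
from collections import defaultdict

# Hypothetical stand-in for Airflow's task_instance table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE task_instance (dag_id TEXT, run_id TEXT, task_id TEXT, state TEXT)"
)
conn.executemany(
    "INSERT INTO task_instance VALUES (?, ?, ?, ?)",
    [
        ("my_dag", "run_1", "task_a", "success"),
        ("my_dag", "run_2", "task_b", "failed"),
    ],
)

run_ids = ["run_1", "run_2"]

# N+1 shape: one round-trip per DagRun.
for run_id in run_ids:
    conn.execute(
        "SELECT task_id, state FROM task_instance WHERE dag_id = ? AND run_id = ?",
        ("my_dag", run_id),
    ).fetchall()

# Bulk shape: one round-trip for all runs, grouped in Python afterwards.
placeholders = ", ".join("?" for _ in run_ids)
rows = conn.execute(
    "SELECT run_id, task_id, state FROM task_instance "
    f"WHERE dag_id = ? AND run_id IN ({placeholders})",
    ["my_dag", *run_ids],
).fetchall()

tasks_by_run_id = defaultdict(list)
for run_id, task_id, state in rows:
    tasks_by_run_id[run_id].append((task_id, state))
print(dict(tasks_by_run_id))  # {'run_1': [('task_a', 'success')], 'run_2': [('task_b', 'failed')]}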

File tree

ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py
ingestion/tests/unit/topology/pipeline/test_airflow.py

ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py

Lines changed: 38 additions & 32 deletions

@@ -320,18 +320,15 @@ def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
         return []

     def get_task_instances(
-        self, dag_id: str, run_id: str, serialized_tasks: List[AirflowTask]
-    ) -> List[OMTaskInstance]:
+        self, dag_id: str, run_ids: List[str], serialized_tasks: List[AirflowTask]
+    ) -> Dict[str, List[OMTaskInstance]]:
         """
-        We are building our own scoped TaskInstance
-        class to only focus on core properties required
-        by the metadata ingestion.
-
-        This makes the versioning more flexible on which Airflow
-        sources we support.
+        Fetch all TaskInstances for the given DAG and run IDs in a single query,
+        returning a dict keyed by run_id. This avoids an N+1 pattern where a
+        separate query was previously fired for each DagRun.
         """
-        task_instance_list = None
         serialized_tasks_ids = {task.task_id for task in serialized_tasks}
+        result: Dict[str, List[OMTaskInstance]] = defaultdict(list)

         try:
             task_instance_list = (

@@ -344,50 +341,59 @@ class to only focus on core properties required
                 )
                 .filter(
                     TaskInstance.dag_id == dag_id,
-                    TaskInstance.run_id == run_id,
+                    TaskInstance.run_id.in_(run_ids),
                     # updating old runs flag deleted tasks as `removed`
                     TaskInstance.state != AirflowTaskStatus.REMOVED.value,
                 )
                 .all()
             )
+            for elem in task_instance_list:
+                row = elem._asdict()
+                if row.get("task_id") in serialized_tasks_ids:
+                    result[row["run_id"]].append(
+                        OMTaskInstance(
+                            task_id=row.get("task_id"),
+                            state=row.get("state"),
+                            start_date=row.get("start_date"),
+                            end_date=row.get("end_date"),
+                        )
+                    )
         except Exception as exc:
             logger.debug(traceback.format_exc())
             logger.warning(
-                f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}."
+                f"Tried to get TaskInstances for run_ids. The run_id column might not be available in older Airflow DB schemas - {exc}."
             )

-        task_instance_dict = (
-            [elem._asdict() for elem in task_instance_list]
-            if task_instance_list
-            else []
-        )
-
-        return [
-            OMTaskInstance(
-                task_id=elem.get("task_id"),
-                state=elem.get("state"),
-                start_date=elem.get("start_date"),
-                end_date=elem.get("end_date"),
-            )
-            for elem in task_instance_dict
-            if elem.get("task_id") in serialized_tasks_ids
-        ]
+        return result

     def yield_pipeline_status(
         self, pipeline_details: AirflowDagDetails
     ) -> Iterable[Either[OMetaPipelineStatus]]:
         try:
             dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)

+            # Collect all run_ids up front so we can fetch all TaskInstances in
+            # one query instead of one per DagRun (N+1 avoidance).
+            run_ids = [
+                dag_run.run_id
+                for dag_run in dag_run_list
+                if dag_run.run_id and self.context.get().task_names
+            ]
+            tasks_by_run_id = (
+                self.get_task_instances(
+                    dag_id=pipeline_details.dag_id,
+                    run_ids=run_ids,
+                    serialized_tasks=pipeline_details.tasks,
+                )
+                if run_ids
+                else {}
+            )
+
             for dag_run in dag_run_list:
                 if (
                     dag_run.run_id and self.context.get().task_names
                 ):  # Airflow dags can have old task which are turned off/commented out in code
-                    tasks = self.get_task_instances(
-                        dag_id=dag_run.dag_id,
-                        run_id=dag_run.run_id,
-                        serialized_tasks=pipeline_details.tasks,
-                    )
+                    tasks = tasks_by_run_id.get(dag_run.run_id, [])

                     task_statuses = [
                         TaskStatus(
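
A note on the query shape above: SQLAlchemy's in_() operator folds every run ID into one SELECT ... WHERE run_id IN (...), and querying individual columns yields Row objects, which is what makes the elem._asdict() call in the new loop work. Here is a minimal standalone sketch of both behaviors; it assumes SQLAlchemy 1.4+ is installed, and the TI model is a hypothetical stand-in for Airflow's real TaskInstance, not OpenMetadata code.

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class TI(Base):
    """Hypothetical stand-in for Airflow's TaskInstance model."""
    __tablename__ = "task_instance"
    run_id = Column(String, primary_key=True)
    task_id = Column(String, primary_key=True)
    state = Column(String)

engine = create_engine("sqlite://", echo=True)  # echo=True logs the emitted SQL
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([TI(run_id="run_1", task_id="task_a", state="success")])
    session.commit()
    # A single statement: SELECT ... WHERE task_instance.run_id IN (?, ?)
    rows = (
        session.query(TI.run_id, TI.task_id, TI.state)
        .filter(TI.run_id.in_(["run_1", "run_2"]))
        .all()
    )
    # Column-tuple queries return Row objects, so _asdict() is available.
    print([row._asdict() for row in rows])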

ingestion/tests/unit/topology/pipeline/test_airflow.py

Lines changed: 70 additions & 0 deletions

@@ -31,6 +31,7 @@
 from metadata.ingestion.source.pipeline.airflow.models import (
     AirflowDag,
     AirflowDagDetails,
+    AirflowTask,
 )
 from metadata.ingestion.source.pipeline.airflow.utils import get_schedule_interval

@@ -998,3 +999,72 @@ def test_task_source_url_airflow_3x_format(self):
         assert f"/dags/{quote(dag_id)}/tasks/{quote(task_id)}" in task_url
         assert "/taskinstance/list/" not in task_url
         assert "_flt_3_dag_id=" not in task_url
+
+    def test_get_task_instances_bulk_query(self):
+        """
+        Verify that get_task_instances fires a single DB query for all run_ids
+        (no N+1 per DagRun) and groups the returned rows by run_id.
+        Tasks not present in serialized_tasks are excluded from the result.
+        """
+        from unittest.mock import MagicMock
+
+        serialized_tasks = [
+            AirflowTask(task_id="task_a"),
+            AirflowTask(task_id="task_b"),
+        ]
+
+        row_run1 = MagicMock()
+        row_run1._asdict.return_value = {
+            "task_id": "task_a",
+            "state": "success",
+            "start_date": None,
+            "end_date": None,
+            "run_id": "run_1",
+        }
+        row_run2 = MagicMock()
+        row_run2._asdict.return_value = {
+            "task_id": "task_b",
+            "state": "failed",
+            "start_date": None,
+            "end_date": None,
+            "run_id": "run_2",
+        }
+        unknown_task_row = MagicMock()
+        unknown_task_row._asdict.return_value = {
+            "task_id": "task_unknown",
+            "state": "success",
+            "start_date": None,
+            "end_date": None,
+            "run_id": "run_1",
+        }
+
+        mock_query = MagicMock()
+        mock_query.filter.return_value = mock_query
+        mock_query.all.return_value = [row_run1, row_run2, unknown_task_row]
+        mock_session = MagicMock()
+        mock_session.query.return_value = mock_query
+
+        original_session = getattr(self.airflow, "_session", None)
+        self.airflow._session = mock_session
+        try:
+            result = self.airflow.get_task_instances(
+                "my_dag", ["run_1", "run_2"], serialized_tasks
+            )
+        finally:
+            self.airflow._session = original_session
+
+        # Single DB query, not one per run_id
+        mock_session.query.assert_called_once()
+
+        # Results grouped correctly by run_id
+        self.assertIn("run_1", result)
+        self.assertIn("run_2", result)
+
+        # task_unknown is not in serialized_tasks so it must be excluded
+        self.assertEqual(len(result["run_1"]), 1)
+        self.assertEqual(result["run_1"][0].task_id, "task_a")
+        self.assertEqual(result["run_1"][0].state, "success")
+
+        self.assertEqual(len(result["run_2"]), 1)
+        self.assertEqual(result["run_2"][0].task_id, "task_b")
+        self.assertEqual(result["run_2"][0].state, "failed")
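
One design note on the test: it swaps the connector's _session attribute for a MagicMock and restores the original in a finally block, so a mid-test failure cannot leak the mock into other tests. Assuming the repository's standard pytest setup, it should also run in isolation with something like pytest ingestion/tests/unit/topology/pipeline/test_airflow.py -k test_get_task_instances_bulk_query.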
