@@ -320,18 +320,15 @@ def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
320320 return []
321321
322322 def get_task_instances (
323- self , dag_id : str , run_id : str , serialized_tasks : List [AirflowTask ]
324- ) -> List [OMTaskInstance ]:
323+ self , dag_id : str , run_ids : List [ str ] , serialized_tasks : List [AirflowTask ]
324+ ) -> Dict [ str , List [OMTaskInstance ] ]:
325325 """
326- We are building our own scoped TaskInstance
327- class to only focus on core properties required
328- by the metadata ingestion.
329-
330- This makes the versioning more flexible on which Airflow
331- sources we support.
326+ Fetch all TaskInstances for the given DAG and run IDs in a single query,
327+ returning a dict keyed by run_id. This avoids an N+1 pattern where a
328+ separate query was previously fired for each DagRun.
332329 """
333- task_instance_list = None
334330 serialized_tasks_ids = {task .task_id for task in serialized_tasks }
331+ result : Dict [str , List [OMTaskInstance ]] = defaultdict (list )
335332
336333 try :
337334 task_instance_list = (
@@ -344,50 +341,59 @@ class to only focus on core properties required
344341 )
345342 .filter (
346343 TaskInstance .dag_id == dag_id ,
347- TaskInstance .run_id == run_id ,
344+ TaskInstance .run_id . in_ ( run_ids ) ,
348345 # updating old runs flag deleted tasks as `removed`
349346 TaskInstance .state != AirflowTaskStatus .REMOVED .value ,
350347 )
351348 .all ()
352349 )
350+ for elem in task_instance_list :
351+ row = elem ._asdict ()
352+ if row .get ("task_id" ) in serialized_tasks_ids :
353+ result [row ["run_id" ]].append (
354+ OMTaskInstance (
355+ task_id = row .get ("task_id" ),
356+ state = row .get ("state" ),
357+ start_date = row .get ("start_date" ),
358+ end_date = row .get ("end_date" ),
359+ )
360+ )
353361 except Exception as exc :
354362 logger .debug (traceback .format_exc ())
355363 logger .warning (
356- f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - { exc } ."
364+ f"Tried to get TaskInstances for run_ids. The run_id column might not be available in older Airflow DB schemas - { exc } ."
357365 )
358366
359- task_instance_dict = (
360- [elem ._asdict () for elem in task_instance_list ]
361- if task_instance_list
362- else []
363- )
364-
365- return [
366- OMTaskInstance (
367- task_id = elem .get ("task_id" ),
368- state = elem .get ("state" ),
369- start_date = elem .get ("start_date" ),
370- end_date = elem .get ("end_date" ),
371- )
372- for elem in task_instance_dict
373- if elem .get ("task_id" ) in serialized_tasks_ids
374- ]
367+ return result
375368
376369 def yield_pipeline_status (
377370 self , pipeline_details : AirflowDagDetails
378371 ) -> Iterable [Either [OMetaPipelineStatus ]]:
379372 try :
380373 dag_run_list = self .get_pipeline_status (pipeline_details .dag_id )
381374
375+ # Collect all run_ids up front so we can fetch all TaskInstances in
376+ # one query instead of one per DagRun (N+1 avoidance).
377+ run_ids = [
378+ dag_run .run_id
379+ for dag_run in dag_run_list
380+ if dag_run .run_id and self .context .get ().task_names
381+ ]
382+ tasks_by_run_id = (
383+ self .get_task_instances (
384+ dag_id = pipeline_details .dag_id ,
385+ run_ids = run_ids ,
386+ serialized_tasks = pipeline_details .tasks ,
387+ )
388+ if run_ids
389+ else {}
390+ )
391+
382392 for dag_run in dag_run_list :
383393 if (
384394 dag_run .run_id and self .context .get ().task_names
385395 ): # Airflow dags can have old task which are turned off/commented out in code
386- tasks = self .get_task_instances (
387- dag_id = dag_run .dag_id ,
388- run_id = dag_run .run_id ,
389- serialized_tasks = pipeline_details .tasks ,
390- )
396+ tasks = tasks_by_run_id .get (dag_run .run_id , [])
391397
392398 task_statuses = [
393399 TaskStatus (
0 commit comments