Fixes #27150: Bulk-fetch TaskInstances per DAG to eliminate N+1 in yield_pipeline_status #27152
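For context, a minimal sketch of the query shape this PR removes versus the one it introduces. The filters mirror the diff below, but the `session` argument, the simplified return values, and the helper names `fetch_per_run` and `fetch_bulk` are illustrative assumptions, not code from this PR.

```python
# Hypothetical before/after of the N+1 fix; not the PR's actual code.
from collections import defaultdict
from typing import Dict, List

from airflow.models import TaskInstance  # Airflow metadata DB model


def fetch_per_run(session, dag_id: str, run_ids: List[str]) -> Dict[str, list]:
    # Before: one SELECT per DagRun, i.e. N+1 round trips for N runs.
    return {
        run_id: session.query(TaskInstance)
        .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id == run_id)
        .all()
        for run_id in run_ids
    }


def fetch_bulk(session, dag_id: str, run_ids: List[str]) -> Dict[str, list]:
    # After: a single SELECT with IN(...), grouped in Python by run_id.
    grouped: Dict[str, list] = defaultdict(list)
    rows = (
        session.query(TaskInstance)
        .filter(TaskInstance.dag_id == dag_id, TaskInstance.run_id.in_(run_ids))
        .all()
    )
    for task_instance in rows:
        grouped[task_instance.run_id].append(task_instance)
    return dict(grouped)
```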
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -8,6 +8,7 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # pylint: disable=too-many-lines | ||
| """ | ||
| Airflow source to extract metadata from OM UI | ||
| """ | ||
@@ -102,6 +103,13 @@ class AirflowTaskStatus(Enum): | |
| AirflowTaskStatus.SKIPPED.value: StatusType.Skipped.value, | ||
| } | ||
| | ||
| # Upper bound on run_ids sent in a single TaskInstance bulk query. Keeps peak | ||
| # memory bounded and stays well below common DB driver IN(...) parameter caps | ||
| # (SQLite 999, some MySQL configs 1000). yield_pipeline_status chunks the | ||
| # eligible DagRuns by this size and yields statuses per chunk, so a failure in | ||
| # one chunk does not wipe out the whole DAG's status ingestion. | ||
| _TASK_INSTANCE_RUN_ID_CHUNK_SIZE = 50 | ||
| | ||
| | ||
| class OMTaskInstance(BaseModel): | ||
| """ | ||
@@ -320,18 +328,22 @@ def get_pipeline_status(self, dag_id: str) -> List[DagRun]: | |
| return [] | ||
| | ||
| def get_task_instances( | ||
| self, dag_id: str, run_id: str, serialized_tasks: List[AirflowTask] | ||
| ) -> List[OMTaskInstance]: | ||
| self, dag_id: str, run_ids: List[str], serialized_tasks: List[AirflowTask] | ||
| ) -> Dict[str, List[OMTaskInstance]]: | ||
| """ | ||
| We are building our own scoped TaskInstance | ||
| class to only focus on core properties required | ||
| by the metadata ingestion. | ||
| | ||
| This makes the versioning more flexible on which Airflow | ||
| sources we support. | ||
| Fetch all TaskInstances for the given DAG and run IDs in a single query, | ||
| returning a dict keyed by run_id. This avoids an N+1 pattern where a | ||
| separate query was previously fired for each DagRun. | ||
| """ | ||
| task_instance_list = None | ||
| serialized_tasks_ids = {task.task_id for task in serialized_tasks} | ||
| result: Dict[str, List[OMTaskInstance]] = defaultdict(list) | ||
| | ||
| # Short-circuit: avoid building and executing a query with an empty | ||
| # IN(...) list - unnecessary DB round-trip and rejected by some SQL | ||
| # dialects. Caller (yield_pipeline_status) already guards this, but | ||
| # defend at the boundary as well. | ||
| if not run_ids: | ||
| return result | ||
| | ||
| try: | ||
| task_instance_list = ( | ||
@@ -344,50 +356,95 @@ class to only focus on core properties required | |
| ) | ||
| .filter( | ||
| TaskInstance.dag_id == dag_id, | ||
| TaskInstance.run_id == run_id, | ||
| TaskInstance.run_id.in_(run_ids), | ||
| # updating old runs flag deleted tasks as `removed` | ||
| TaskInstance.state != AirflowTaskStatus.REMOVED.value, | ||
| ) | ||
| .all() | ||
| ) | ||
| for elem in task_instance_list: | ||
| # Be defensive per-row: a single malformed/missing value must | ||
| # not abort the whole batch. Log and continue so the rest of | ||
| # the DAG's task instances still get ingested. | ||
| try: | ||
| row = elem._asdict() | ||
| task_id = row.get("task_id") | ||
| run_id = row.get("run_id") | ||
| if not task_id or not run_id: | ||
| logger.debug( | ||
| f"Skipping TaskInstance row with missing " | ||
| f"task_id/run_id for dag_id={dag_id}: {row}" | ||
| ) | ||
| continue | ||
| if task_id not in serialized_tasks_ids: | ||
| continue | ||
| result[run_id].append( | ||
| OMTaskInstance( | ||
| task_id=task_id, | ||
| state=row.get("state"), | ||
| start_date=row.get("start_date"), | ||
| end_date=row.get("end_date"), | ||
| ) | ||
| ) | ||
| except Exception as row_exc: | ||
| logger.debug(traceback.format_exc()) | ||
| logger.warning( | ||
| f"Skipping malformed TaskInstance row for " | ||
| f"dag_id={dag_id}: {row_exc}" | ||
| ) | ||
| continue | ||
| except Exception as exc: | ||
| logger.debug(traceback.format_exc()) | ||
| logger.warning( | ||
| f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}." | ||
| f"Tried to get TaskInstances for run_ids. The run_id column " | ||
| f"might not be available in older Airflow DB schemas - {exc}." | ||
| ) | ||
| | ||
| task_instance_dict = ( | ||
| [elem._asdict() for elem in task_instance_list] | ||
| if task_instance_list | ||
| else [] | ||
| ) | ||
| | ||
| return [ | ||
| OMTaskInstance( | ||
| task_id=elem.get("task_id"), | ||
| state=elem.get("state"), | ||
| start_date=elem.get("start_date"), | ||
| end_date=elem.get("end_date"), | ||
| ) | ||
| for elem in task_instance_dict | ||
| if elem.get("task_id") in serialized_tasks_ids | ||
| ] | ||
| return result | ||
| | ||
| def yield_pipeline_status( | ||
| self, pipeline_details: AirflowDagDetails | ||
| ) -> Iterable[Either[OMetaPipelineStatus]]: | ||
| try: | ||
| dag_run_list = self.get_pipeline_status(pipeline_details.dag_id) | ||
| | ||
| for dag_run in dag_run_list: | ||
| if ( | ||
| dag_run.run_id and self.context.get().task_names | ||
| ): # Airflow dags can have old task which are turned off/commented out in code | ||
| tasks = self.get_task_instances( | ||
| dag_id=dag_run.dag_id, | ||
| run_id=dag_run.run_id, | ||
| # Filter eligible DagRuns once. task_names is empty when the DAG | ||
| # has no tasks in the current context, in which case we skip the | ||
| # DB round trip entirely. | ||
| task_names = self.context.get().task_names | ||
| eligible_runs = [ | ||
| dag_run for dag_run in dag_run_list if dag_run.run_id and task_names | ||
| ] | ||
| | ||
| # Chunk run_ids so we never send an unbounded IN(...) list to the | ||
| # DB and so we can stream per-run statuses without buffering every | ||
| # TaskInstance for the whole DAG in memory at once. A failure in | ||
| # one chunk is logged and the remaining chunks still emit. | ||
| for start in range(0, len(eligible_runs), _TASK_INSTANCE_RUN_ID_CHUNK_SIZE): | ||
| chunk = eligible_runs[start : start + _TASK_INSTANCE_RUN_ID_CHUNK_SIZE] | ||
| try: | ||
| tasks_by_run_id = self.get_task_instances( | ||
| dag_id=pipeline_details.dag_id, | ||
| run_ids=[dag_run.run_id for dag_run in chunk], | ||
| serialized_tasks=pipeline_details.tasks, | ||
| ) | ||
| except Exception as chunk_exc: | ||
| # Preserve pre-PR safe-fallback behaviour: if the bulk | ||
| # TaskInstance fetch fails for this chunk, still yield a | ||
| # PipelineStatus per DagRun with an empty task list | ||
| # instead of silently dropping whole runs. This matches | ||
| # the prior per-run loop where a DB error produced empty | ||
| # tasks but runs were still emitted. | ||
| logger.debug(traceback.format_exc()) | ||
| logger.warning( | ||
| f"Failed TaskInstance chunk for " | ||
| f"{pipeline_details.dag_id} " | ||
| f"(runs {start}-{start + len(chunk)}) - {chunk_exc}" | ||
| ) | ||
| tasks_by_run_id = {} | ||
| | ||
| for dag_run in chunk: | ||
| tasks = tasks_by_run_id.get(dag_run.run_id, []) | ||
| | ||
| task_statuses = [ | ||
| TaskStatus( | ||
@@ -401,7 +458,7 @@ def yield_pipeline_status( | |
| ), # Might be None for running tasks | ||
| ) # Log link might not be present in all Airflow versions | ||
| for task in tasks | ||
| if task.task_id in self.context.get().task_names | ||
| if task.task_id in task_names | ||
| ] | ||
| | ||
| # DagRun objects are built with logical_date (SDK is Airflow 3.x) | ||
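As a usage note, the chunking loop in the diff above follows a generic chunk-and-recover pattern. Here is a standalone sketch of it under assumed names (`CHUNK_SIZE`, the `fetch` callable, and the `run_id`-bearing `runs` elements are illustrative, not the PR's API):

```python
# Sketch of the chunked, failure-isolated iteration that
# yield_pipeline_status performs. CHUNK_SIZE mirrors
# _TASK_INSTANCE_RUN_ID_CHUNK_SIZE from the diff.
import logging
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

logger = logging.getLogger(__name__)

CHUNK_SIZE = 50


def iter_run_tasks(
    runs: Sequence,
    fetch: Callable[[List[str]], Dict[str, list]],
) -> Iterator[Tuple[object, list]]:
    for start in range(0, len(runs), CHUNK_SIZE):
        chunk = runs[start : start + CHUNK_SIZE]
        try:
            # One bounded IN(...) query per chunk.
            tasks_by_run_id = fetch([run.run_id for run in chunk])
        except Exception as exc:
            # A failed chunk degrades to empty task lists; runs in this
            # chunk and all later chunks are still emitted.
            logger.warning("TaskInstance chunk at offset %s failed: %s", start, exc)
            tasks_by_run_id = {}
        for run in chunk:
            yield run, tasks_by_run_id.get(run.run_id, [])
```

Peak memory stays at one chunk's worth of TaskInstances, and the except branch reproduces the empty-task fallback described in the diff's comments.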
`get_task_instances` is annotated to return `Dict[str, List[OMTaskInstance]]`, but it actually returns a `defaultdict(list)`. Returning a `defaultdict` can introduce subtle side effects for callers (e.g., `result[missing_key]` will create keys instead of raising), and it is inconsistent with the declared return type. Consider using `DefaultDict[...]` internally and converting to a plain `dict` on return (or change the return annotation if you intend to expose `defaultdict`).
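One hedged way to apply this suggestion: build with a `DefaultDict` for ergonomic appends, then convert at the boundary. The body below is a simplified stand-in for the PR's grouping step, not its full query logic.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List


def group_rows_by_run_id(rows: List[dict]) -> Dict[str, List[dict]]:
    grouped: DefaultDict[str, List[dict]] = defaultdict(list)
    for row in rows:
        grouped[row["run_id"]].append(row)
    # Return a plain dict so result[missing_key] raises KeyError,
    # matching the Dict[...] annotation that callers see.
    return dict(grouped)
```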