Merged
127 changes: 92 additions & 35 deletions ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py
@@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=too-many-lines
"""
Airflow source to extract metadata from OM UI
"""
@@ -102,6 +103,13 @@ class AirflowTaskStatus(Enum):
AirflowTaskStatus.SKIPPED.value: StatusType.Skipped.value,
}

# Upper bound on run_ids sent in a single TaskInstance bulk query. Keeps peak
# memory bounded and stays well below common DB driver IN(...) parameter caps
# (SQLite 999, some MySQL configs 1000). yield_pipeline_status chunks the
# eligible DagRuns by this size and yields statuses per chunk, so a failure in
# one chunk does not wipe out the whole DAG's status ingestion.
_TASK_INSTANCE_RUN_ID_CHUNK_SIZE = 50
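
A minimal standalone sketch of the chunking pattern this constant drives; `chunked` is a hypothetical helper for illustration, not part of this diff:

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunked(seq: Sequence[T], size: int) -> Iterator[List[T]]:
    """Yield successive fixed-size slices of seq; the last may be shorter."""
    for start in range(0, len(seq), size):
        yield list(seq[start : start + size])


# With a chunk size of 50, 120 run_ids turn into three IN (...) queries
# of 50, 50, and 20 bound parameters respectively.
run_ids = [f"run_{i}" for i in range(120)]
assert [len(chunk) for chunk in chunked(run_ids, 50)] == [50, 50, 20]
```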


class OMTaskInstance(BaseModel):
"""
@@ -320,18 +328,22 @@ def get_pipeline_status(self, dag_id: str) -> List[DagRun]:
return []

def get_task_instances(
self, dag_id: str, run_id: str, serialized_tasks: List[AirflowTask]
) -> List[OMTaskInstance]:
self, dag_id: str, run_ids: List[str], serialized_tasks: List[AirflowTask]
) -> Dict[str, List[OMTaskInstance]]:
"""
We are building our own scoped TaskInstance
class to only focus on core properties required
by the metadata ingestion.

This makes the versioning more flexible on which Airflow
sources we support.
Fetch all TaskInstances for the given DAG and run IDs in a single query,
returning a dict keyed by run_id. This avoids an N+1 pattern where a
separate query was previously fired for each DagRun.
"""
task_instance_list = None
serialized_tasks_ids = {task.task_id for task in serialized_tasks}
result: Dict[str, List[OMTaskInstance]] = defaultdict(list)

# Short-circuit: avoid building and executing a query with an empty
# IN(...) list - unnecessary DB round-trip and rejected by some SQL
# dialects. Caller (yield_pipeline_status) already guards this, but
# defend at the boundary as well.
if not run_ids:
return result
Comment from Copilot AI (Apr 17, 2026) on lines +339 to +346:
get_task_instances is annotated to return Dict[str, List[OMTaskInstance]], but it actually returns a defaultdict(list). Returning a defaultdict can introduce subtle side effects for callers (e.g., result[missing_key] will create keys instead of raising), and it's inconsistent with the declared return type. Consider using DefaultDict[...] internally and converting to a plain dict on return (or change the return annotation if you intend to expose defaultdict).

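A runnable sketch of the reviewer's suggestion; `group_by_run_id` and its row shape are invented stand-ins for the diff's query results:

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List


def group_by_run_id(rows: List[dict]) -> Dict[str, List[str]]:
    """Group task_ids by run_id with a defaultdict internally, but hand
    callers a plain dict that matches the declared return type."""
    grouped: DefaultDict[str, List[str]] = defaultdict(list)
    for row in rows:
        grouped[row["run_id"]].append(row["task_id"])
    # dict(...) drops defaultdict's auto-insert behaviour: looking up a
    # missing run_id now raises KeyError (or goes through .get) instead of
    # silently creating an empty entry.
    return dict(grouped)


rows = [
    {"run_id": "r1", "task_id": "extract"},
    {"run_id": "r1", "task_id": "load"},
]
grouped = group_by_run_id(rows)
assert grouped == {"r1": ["extract", "load"]}
assert "r2" not in grouped  # lookups no longer create phantom keys
```
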
try:
task_instance_list = (
@@ -344,50 +356,95 @@ class to only focus on core properties required
)
.filter(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
TaskInstance.run_id.in_(run_ids),
# updating old runs flag deleted tasks as `removed`
Comment from Copilot AI (Apr 17, 2026) on lines 330 to 360:
get_task_instances will still build and execute the SQLAlchemy query even when run_ids is empty, which is both an unnecessary DB round-trip and undermines the stated goal of avoiding empty IN (...) filters in some backends/drivers. Consider short-circuiting early (e.g., return {} when not run_ids) before constructing the query.
TaskInstance.state != AirflowTaskStatus.REMOVED.value,
)
.all()
)
for elem in task_instance_list:
# Be defensive per-row: a single malformed/missing value must
# not abort the whole batch. Log and continue so the rest of
# the DAG's task instances still get ingested.
try:
row = elem._asdict()
task_id = row.get("task_id")
run_id = row.get("run_id")
if not task_id or not run_id:
logger.debug(
f"Skipping TaskInstance row with missing "
f"task_id/run_id for dag_id={dag_id}: {row}"
)
continue
if task_id not in serialized_tasks_ids:
continue
result[run_id].append(
OMTaskInstance(
task_id=task_id,
state=row.get("state"),
start_date=row.get("start_date"),
end_date=row.get("end_date"),
)
)
except Exception as row_exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Skipping malformed TaskInstance row for "
f"dag_id={dag_id}: {row_exc}"
)
continue
except Exception as exc:
logger.debug(traceback.format_exc())
logger.warning(
f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}."
f"Tried to get TaskInstances for run_ids. The run_id column "
f"might not be available in older Airflow DB schemas - {exc}."
)

task_instance_dict = (
[elem._asdict() for elem in task_instance_list]
if task_instance_list
else []
)

return [
OMTaskInstance(
task_id=elem.get("task_id"),
state=elem.get("state"),
start_date=elem.get("start_date"),
end_date=elem.get("end_date"),
)
for elem in task_instance_dict
if elem.get("task_id") in serialized_tasks_ids
]
return result
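
Picking up the second reviewer note above: the diff does guard the empty case before building the query. Isolated, the pattern looks like this (the function name and stub body are illustrative, not the diff's code):

```python
from typing import Dict, List


def fetch_task_instances(dag_id: str, run_ids: List[str]) -> Dict[str, list]:
    # An empty IN (...) list is rejected outright by some SQL dialects and is
    # a wasted round-trip everywhere else, so bail out before touching the DB.
    if not run_ids:
        return {}
    raise NotImplementedError("the run_id.in_(run_ids) query goes here")


assert fetch_task_instances("my_dag", []) == {}
```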

def yield_pipeline_status(
self, pipeline_details: AirflowDagDetails
) -> Iterable[Either[OMetaPipelineStatus]]:
try:
dag_run_list = self.get_pipeline_status(pipeline_details.dag_id)

for dag_run in dag_run_list:
if (
dag_run.run_id and self.context.get().task_names
): # Airflow dags can have old task which are turned off/commented out in code
tasks = self.get_task_instances(
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
# Filter eligible DagRuns once. task_names is empty when the DAG
# has no tasks in the current context, in which case we skip the
# DB round trip entirely.
task_names = self.context.get().task_names
eligible_runs = [
dag_run for dag_run in dag_run_list if dag_run.run_id and task_names
]

# Chunk run_ids so we never send an unbounded IN(...) list to the
# DB and so we can stream per-run statuses without buffering every
# TaskInstance for the whole DAG in memory at once. A failure in
# one chunk is logged and the remaining chunks still emit.
for start in range(0, len(eligible_runs), _TASK_INSTANCE_RUN_ID_CHUNK_SIZE):
chunk = eligible_runs[start : start + _TASK_INSTANCE_RUN_ID_CHUNK_SIZE]
try:
tasks_by_run_id = self.get_task_instances(
dag_id=pipeline_details.dag_id,
run_ids=[dag_run.run_id for dag_run in chunk],
serialized_tasks=pipeline_details.tasks,
)
except Exception as chunk_exc:
# Preserve pre-PR safe-fallback behaviour: if the bulk
# TaskInstance fetch fails for this chunk, still yield a
# PipelineStatus per DagRun with an empty task list
# instead of silently dropping whole runs. This matches
# the prior per-run loop where a DB error produced empty
# tasks but runs were still emitted.
logger.debug(traceback.format_exc())
logger.warning(
f"Failed TaskInstance chunk for "
f"{pipeline_details.dag_id} "
f"(runs {start}-{start + len(chunk)}) - {chunk_exc}"
)
tasks_by_run_id = {}

for dag_run in chunk:
tasks = tasks_by_run_id.get(dag_run.run_id, [])

task_statuses = [
TaskStatus(
@@ -401,7 +458,7 @@ def yield_pipeline_status(
), # Might be None for running tasks
) # Log link might not be present in all Airflow versions
for task in tasks
if task.task_id in self.context.get().task_names
if task.task_id in task_names
]

# DagRun objects are built with logical_date (SDK is Airflow 3.x)
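The hunk is truncated here, but the comment above and the tests below point at a logical_date/start_date fallback. A guess at the shape of that logic, reconstructed from the test names rather than the diff's actual code:

```python
from datetime import datetime
from typing import Optional


def resolve_run_date(
    logical_date: Optional[datetime], start_date: Optional[datetime]
) -> Optional[datetime]:
    """Prefer logical_date (the Airflow 3.x SDK builds DagRuns with it), fall
    back to start_date, and return None so callers can skip the run when both
    are missing."""
    return logical_date or start_date


assert resolve_run_date(None, datetime(2024, 1, 1)) == datetime(2024, 1, 1)
assert resolve_run_date(None, None) is None
```
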
4 changes: 2 additions & 2 deletions ingestion/tests/unit/airflow/test_airflow_metadata.py
@@ -365,7 +365,7 @@ def test_uses_start_date_when_logical_date_is_none(self, mock_init):
source.context.get.return_value = mock_context

source.get_pipeline_status = MagicMock(return_value=[dag_run])
source.get_task_instances = MagicMock(return_value=[])
source.get_task_instances = MagicMock(return_value={})
source.metadata = MagicMock()

mock_pipeline_details = MagicMock()
@@ -397,7 +397,7 @@ def test_skips_run_when_both_dates_are_none(self, mock_init):
source.context.get.return_value = mock_context

source.get_pipeline_status = MagicMock(return_value=[dag_run])
source.get_task_instances = MagicMock(return_value=[])
source.get_task_instances = MagicMock(return_value={})
source.metadata = MagicMock()

mock_pipeline_details = MagicMock()
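
Both tests above only need the empty-dict shape; for a run that should carry tasks, the mock would pair each run_id with its task instances. A hedged sketch with a trimmed stand-in model (the real OMTaskInstance in the diff may require different field types):

```python
from typing import Optional
from unittest.mock import MagicMock

from pydantic import BaseModel


class OMTaskInstance(BaseModel):  # trimmed stand-in for the class in the diff
    task_id: str
    state: Optional[str] = None
    start_date: Optional[str] = None
    end_date: Optional[str] = None


source = MagicMock()
# The bulk fetch now returns {run_id: [OMTaskInstance, ...]} rather than a
# flat list, so a populated mock looks like this:
source.get_task_instances = MagicMock(
    return_value={
        "scheduled__2024-01-01T00:00:00+00:00": [
            OMTaskInstance(task_id="extract", state="success")
        ]
    }
)
assert "scheduled__2024-01-01T00:00:00+00:00" in source.get_task_instances()
```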