diff --git a/ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py b/ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py index 7052687ee8b8..342e24c71970 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/airflow/metadata.py @@ -8,6 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-lines """ Airflow source to extract metadata from OM UI """ @@ -102,6 +103,13 @@ class AirflowTaskStatus(Enum): AirflowTaskStatus.SKIPPED.value: StatusType.Skipped.value, } +# Upper bound on run_ids sent in a single TaskInstance bulk query. Keeps peak +# memory bounded and stays well below common DB driver IN(...) parameter caps +# (SQLite 999, some MySQL configs 1000). yield_pipeline_status chunks the +# eligible DagRuns by this size and yields statuses per chunk, so a failure in +# one chunk does not wipe out the whole DAG's status ingestion. +_TASK_INSTANCE_RUN_ID_CHUNK_SIZE = 50 + class OMTaskInstance(BaseModel): """ @@ -320,18 +328,22 @@ def get_pipeline_status(self, dag_id: str) -> List[DagRun]: return [] def get_task_instances( - self, dag_id: str, run_id: str, serialized_tasks: List[AirflowTask] - ) -> List[OMTaskInstance]: + self, dag_id: str, run_ids: List[str], serialized_tasks: List[AirflowTask] + ) -> Dict[str, List[OMTaskInstance]]: """ - We are building our own scoped TaskInstance - class to only focus on core properties required - by the metadata ingestion. - - This makes the versioning more flexible on which Airflow - sources we support. + Fetch all TaskInstances for the given DAG and run IDs in a single query, + returning a dict keyed by run_id. This avoids an N+1 pattern where a + separate query was previously fired for each DagRun. """ - task_instance_list = None serialized_tasks_ids = {task.task_id for task in serialized_tasks} + result: Dict[str, List[OMTaskInstance]] = defaultdict(list) + + # Short-circuit: avoid building and executing a query with an empty + # IN(...) list - unnecessary DB round-trip and rejected by some SQL + # dialects. Caller (yield_pipeline_status) already guards this, but + # defend at the boundary as well. + if not run_ids: + return result try: task_instance_list = ( @@ -344,34 +356,51 @@ class to only focus on core properties required ) .filter( TaskInstance.dag_id == dag_id, - TaskInstance.run_id == run_id, + TaskInstance.run_id.in_(run_ids), # updating old runs flag deleted tasks as `removed` TaskInstance.state != AirflowTaskStatus.REMOVED.value, ) .all() ) + for elem in task_instance_list: + # Be defensive per-row: a single malformed/missing value must + # not abort the whole batch. Log and continue so the rest of + # the DAG's task instances still get ingested. + try: + row = elem._asdict() + task_id = row.get("task_id") + run_id = row.get("run_id") + if not task_id or not run_id: + logger.debug( + f"Skipping TaskInstance row with missing " + f"task_id/run_id for dag_id={dag_id}: {row}" + ) + continue + if task_id not in serialized_tasks_ids: + continue + result[run_id].append( + OMTaskInstance( + task_id=task_id, + state=row.get("state"), + start_date=row.get("start_date"), + end_date=row.get("end_date"), + ) + ) + except Exception as row_exc: + logger.debug(traceback.format_exc()) + logger.warning( + f"Skipping malformed TaskInstance row for " + f"dag_id={dag_id}: {row_exc}" + ) + continue except Exception as exc: logger.debug(traceback.format_exc()) logger.warning( - f"Tried to get TaskInstances with run_id. It might not be available in older Airflow versions - {exc}." + f"Tried to get TaskInstances for run_ids. The run_id column " + f"might not be available in older Airflow DB schemas - {exc}." ) - task_instance_dict = ( - [elem._asdict() for elem in task_instance_list] - if task_instance_list - else [] - ) - - return [ - OMTaskInstance( - task_id=elem.get("task_id"), - state=elem.get("state"), - start_date=elem.get("start_date"), - end_date=elem.get("end_date"), - ) - for elem in task_instance_dict - if elem.get("task_id") in serialized_tasks_ids - ] + return result def yield_pipeline_status( self, pipeline_details: AirflowDagDetails @@ -379,15 +408,43 @@ def yield_pipeline_status( try: dag_run_list = self.get_pipeline_status(pipeline_details.dag_id) - for dag_run in dag_run_list: - if ( - dag_run.run_id and self.context.get().task_names - ): # Airflow dags can have old task which are turned off/commented out in code - tasks = self.get_task_instances( - dag_id=dag_run.dag_id, - run_id=dag_run.run_id, + # Filter eligible DagRuns once. task_names is empty when the DAG + # has no tasks in the current context, in which case we skip the + # DB round trip entirely. + task_names = self.context.get().task_names + eligible_runs = [ + dag_run for dag_run in dag_run_list if dag_run.run_id and task_names + ] + + # Chunk run_ids so we never send an unbounded IN(...) list to the + # DB and so we can stream per-run statuses without buffering every + # TaskInstance for the whole DAG in memory at once. A failure in + # one chunk is logged and the remaining chunks still emit. + for start in range(0, len(eligible_runs), _TASK_INSTANCE_RUN_ID_CHUNK_SIZE): + chunk = eligible_runs[start : start + _TASK_INSTANCE_RUN_ID_CHUNK_SIZE] + try: + tasks_by_run_id = self.get_task_instances( + dag_id=pipeline_details.dag_id, + run_ids=[dag_run.run_id for dag_run in chunk], serialized_tasks=pipeline_details.tasks, ) + except Exception as chunk_exc: + # Preserve pre-PR safe-fallback behaviour: if the bulk + # TaskInstance fetch fails for this chunk, still yield a + # PipelineStatus per DagRun with an empty task list + # instead of silently dropping whole runs. This matches + # the prior per-run loop where a DB error produced empty + # tasks but runs were still emitted. + logger.debug(traceback.format_exc()) + logger.warning( + f"Failed TaskInstance chunk for " + f"{pipeline_details.dag_id} " + f"(runs {start}-{start + len(chunk)}) - {chunk_exc}" + ) + tasks_by_run_id = {} + + for dag_run in chunk: + tasks = tasks_by_run_id.get(dag_run.run_id, []) task_statuses = [ TaskStatus( @@ -401,7 +458,7 @@ def yield_pipeline_status( ), # Might be None for running tasks ) # Log link might not be present in all Airflow versions for task in tasks - if task.task_id in self.context.get().task_names + if task.task_id in task_names ] # DagRun objects are built with logical_date (SDK is Airflow 3.x) diff --git a/ingestion/tests/unit/airflow/test_airflow_metadata.py b/ingestion/tests/unit/airflow/test_airflow_metadata.py index 00c146b9ea7b..3b7268e0b13a 100644 --- a/ingestion/tests/unit/airflow/test_airflow_metadata.py +++ b/ingestion/tests/unit/airflow/test_airflow_metadata.py @@ -365,7 +365,7 @@ def test_uses_start_date_when_logical_date_is_none(self, mock_init): source.context.get.return_value = mock_context source.get_pipeline_status = MagicMock(return_value=[dag_run]) - source.get_task_instances = MagicMock(return_value=[]) + source.get_task_instances = MagicMock(return_value={}) source.metadata = MagicMock() mock_pipeline_details = MagicMock() @@ -397,7 +397,7 @@ def test_skips_run_when_both_dates_are_none(self, mock_init): source.context.get.return_value = mock_context source.get_pipeline_status = MagicMock(return_value=[dag_run]) - source.get_task_instances = MagicMock(return_value=[]) + source.get_task_instances = MagicMock(return_value={}) source.metadata = MagicMock() mock_pipeline_details = MagicMock() diff --git a/ingestion/tests/unit/topology/pipeline/test_airflow.py b/ingestion/tests/unit/topology/pipeline/test_airflow.py index 6c302aeb6d37..56520e4712c7 100644 --- a/ingestion/tests/unit/topology/pipeline/test_airflow.py +++ b/ingestion/tests/unit/topology/pipeline/test_airflow.py @@ -27,10 +27,14 @@ OpenMetadataWorkflowConfig, ) from metadata.ingestion.ometa.ometa_api import OpenMetadata -from metadata.ingestion.source.pipeline.airflow.metadata import AirflowSource +from metadata.ingestion.source.pipeline.airflow.metadata import ( + AirflowSource, + OMTaskInstance, +) from metadata.ingestion.source.pipeline.airflow.models import ( AirflowDag, AirflowDagDetails, + AirflowTask, ) from metadata.ingestion.source.pipeline.airflow.utils import get_schedule_interval @@ -998,3 +1002,572 @@ def test_task_source_url_airflow_3x_format(self): assert f"/dags/{quote(dag_id)}/tasks/{quote(task_id)}" in task_url assert "/taskinstance/list/" not in task_url assert "_flt_3_dag_id=" not in task_url + + def test_get_task_instances_bulk_query(self): + """ + Verify that get_task_instances fires a single DB query for all run_ids + (no N+1 per DagRun) and groups the returned rows by run_id. + Tasks not present in serialized_tasks are excluded from the result. + """ + from unittest.mock import MagicMock + + serialized_tasks = [ + AirflowTask(task_id="task_a"), + AirflowTask(task_id="task_b"), + ] + + row_run1 = MagicMock() + row_run1._asdict.return_value = { + "task_id": "task_a", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_1", + } + row_run2 = MagicMock() + row_run2._asdict.return_value = { + "task_id": "task_b", + "state": "failed", + "start_date": None, + "end_date": None, + "run_id": "run_2", + } + unknown_task_row = MagicMock() + unknown_task_row._asdict.return_value = { + "task_id": "task_unknown", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_1", + } + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [row_run1, row_run2, unknown_task_row] + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances( + "my_dag", ["run_1", "run_2"], serialized_tasks + ) + finally: + self.airflow._session = original_session + + # Single DB query — not one per run_id + mock_session.query.assert_called_once() + + # Results grouped correctly by run_id + self.assertIn("run_1", result) + self.assertIn("run_2", result) + + # task_unknown is not in serialized_tasks so it must be excluded + self.assertEqual(len(result["run_1"]), 1) + self.assertEqual(result["run_1"][0].task_id, "task_a") + self.assertEqual(result["run_1"][0].state, "success") + + self.assertEqual(len(result["run_2"]), 1) + self.assertEqual(result["run_2"][0].task_id, "task_b") + self.assertEqual(result["run_2"][0].state, "failed") + + def test_get_task_instances_no_regression_vs_old_per_run_loop(self): + """ + Behavioural-equivalence test against the previous per-run_id loop. + + Reconstructs a realistic mixed dataset (multiple DAG runs, multiple + tasks per run, some renamed/removed tasks, one run with no surviving + tasks) and asserts that the new bulk get_task_instances produces the + same per-run mapping a per-run_id loop over the old single-run filter + would have produced. This is the no-regression check the maintainer + asked for, performed without needing a live Airflow DB. + """ + from unittest.mock import MagicMock + + serialized_tasks = [ + AirflowTask(task_id="extract"), + AirflowTask(task_id="transform"), + AirflowTask(task_id="load"), + ] + + def make_row(task_id, run_id, state): + row = MagicMock() + row._asdict.return_value = { + "task_id": task_id, + "state": state, + "start_date": None, + "end_date": None, + "run_id": run_id, + } + return row + + all_rows = [ + make_row("extract", "scheduled__1", "success"), + make_row("transform", "scheduled__1", "success"), + make_row("load", "scheduled__1", "success"), + make_row("extract", "scheduled__2", "success"), + make_row("transform", "scheduled__2", "failed"), + make_row("legacy_step", "scheduled__2", "success"), + make_row("extract", "manual__3", "running"), + make_row("only_old_task", "scheduled__4", "success"), + ] + run_ids = ["scheduled__1", "scheduled__2", "manual__3", "scheduled__4"] + + def expected_per_run(): + grouped = {} + allowed = {t.task_id for t in serialized_tasks} + for run_id in run_ids: + grouped[run_id] = [ + OMTaskInstance( + task_id=r._asdict()["task_id"], + state=r._asdict()["state"], + start_date=None, + end_date=None, + ) + for r in all_rows + if r._asdict()["run_id"] == run_id + and r._asdict()["task_id"] in allowed + ] + return grouped + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = all_rows + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + actual = self.airflow.get_task_instances( + "etl_dag", run_ids, serialized_tasks + ) + finally: + self.airflow._session = original_session + + expected = expected_per_run() + + # Single bulk query, not one per run_id + mock_session.query.assert_called_once() + self.assertEqual( + set(actual.keys()), {"scheduled__1", "scheduled__2", "manual__3"} + ) + for run_id in actual: + self.assertEqual( + [(t.task_id, t.state) for t in actual[run_id]], + [(t.task_id, t.state) for t in expected[run_id]], + f"Bulk query result for {run_id} diverges from per-run loop output", + ) + # scheduled__4 had only a legacy task: equivalent to old loop returning [] + self.assertEqual(actual.get("scheduled__4", []), expected["scheduled__4"]) + + def test_get_task_instances_returns_empty_dict_on_db_exception(self): + """ + On any DB error (e.g. older Airflow schemas without run_id column) the + method must swallow the exception and return an empty dict so that + yield_pipeline_status keeps emitting per-run statuses with empty task + lists - matching the pre-change safe-fallback behaviour. + """ + from unittest.mock import MagicMock + + mock_session = MagicMock() + mock_session.query.side_effect = RuntimeError("simulated DB failure") + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances( + "any_dag", + ["run_a", "run_b"], + [AirflowTask(task_id="t1")], + ) + finally: + self.airflow._session = original_session + + self.assertEqual(result, {}) + + def test_get_task_instances_handles_empty_run_ids(self): + """ + If get_task_instances is ever called with no run_ids it must not throw + (some SQL dialects reject `IN ()`). yield_pipeline_status guards this + upstream, but the method itself should still degrade gracefully. + """ + from unittest.mock import MagicMock + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [] + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances("any_dag", [], []) + finally: + self.airflow._session = original_session + + self.assertEqual(result, {}) + + def test_get_task_instances_skips_rows_with_missing_fields(self): + """ + Negative-data test: if the DB returns rows with missing task_id or + run_id (e.g. NULLs from a partial/corrupt Airflow schema), the + method must log-and-continue - the rest of the batch must still be + ingested. It must NOT raise and abort the whole DAG. + """ + from unittest.mock import MagicMock + + serialized_tasks = [ + AirflowTask(task_id="task_a"), + AirflowTask(task_id="task_b"), + ] + + good_row = MagicMock() + good_row._asdict.return_value = { + "task_id": "task_a", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_1", + } + missing_task_id = MagicMock() + missing_task_id._asdict.return_value = { + "task_id": None, + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_1", + } + missing_run_id = MagicMock() + missing_run_id._asdict.return_value = { + "task_id": "task_b", + "state": "failed", + "start_date": None, + "end_date": None, + "run_id": None, + } + second_good_row = MagicMock() + second_good_row._asdict.return_value = { + "task_id": "task_b", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_2", + } + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [ + good_row, + missing_task_id, + missing_run_id, + second_good_row, + ] + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances( + "my_dag", ["run_1", "run_2"], serialized_tasks + ) + finally: + self.airflow._session = original_session + + # Bad rows skipped, good rows kept - no exception propagated + self.assertEqual(set(result.keys()), {"run_1", "run_2"}) + self.assertEqual([t.task_id for t in result["run_1"]], ["task_a"]) + self.assertEqual([t.task_id for t in result["run_2"]], ["task_b"]) + + def test_get_task_instances_continues_on_malformed_row(self): + """ + Negative-data test: if a single row raises while being processed + (e.g. ._asdict() explodes for one element), the method must log the + offending row and keep going for the remaining rows in the batch. + Preferred behaviour per maintainer review: log and move forward, + do NOT interrupt processing of the whole DAG. + """ + from unittest.mock import MagicMock + + serialized_tasks = [AirflowTask(task_id="task_a")] + + good_row_before = MagicMock() + good_row_before._asdict.return_value = { + "task_id": "task_a", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_1", + } + broken_row = MagicMock() + broken_row._asdict.side_effect = RuntimeError("corrupt row") + good_row_after = MagicMock() + good_row_after._asdict.return_value = { + "task_id": "task_a", + "state": "failed", + "start_date": None, + "end_date": None, + "run_id": "run_2", + } + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [good_row_before, broken_row, good_row_after] + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances( + "my_dag", ["run_1", "run_2"], serialized_tasks + ) + finally: + self.airflow._session = original_session + + # Both surrounding good rows must be present despite the bad one + self.assertEqual(set(result.keys()), {"run_1", "run_2"}) + self.assertEqual(result["run_1"][0].state, "success") + self.assertEqual(result["run_2"][0].state, "failed") + + def test_get_task_instances_stray_run_id_grouped_separately(self): + """ + Negative-data test: if the DB returns a TaskInstance whose run_id is + not in the requested run_ids list (e.g. stale cache / race with a + delete), it is grouped under its own key in the returned dict. + yield_pipeline_status then safely ignores it via + tasks_by_run_id.get(run_id, []) so no data for the requested runs is + lost and no exception propagates. + """ + from unittest.mock import MagicMock + + serialized_tasks = [AirflowTask(task_id="task_a")] + + requested_row = MagicMock() + requested_row._asdict.return_value = { + "task_id": "task_a", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_requested", + } + stray_row = MagicMock() + stray_row._asdict.return_value = { + "task_id": "task_a", + "state": "success", + "start_date": None, + "end_date": None, + "run_id": "run_stray", + } + + mock_query = MagicMock() + mock_query.filter.return_value = mock_query + mock_query.all.return_value = [requested_row, stray_row] + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + original_session = getattr(self.airflow, "_session", None) + self.airflow._session = mock_session + try: + result = self.airflow.get_task_instances( + "my_dag", ["run_requested"], serialized_tasks + ) + finally: + self.airflow._session = original_session + + # Requested run is populated + self.assertIn("run_requested", result) + self.assertEqual(len(result["run_requested"]), 1) + self.assertEqual(result["run_requested"][0].task_id, "task_a") + + # Stray run is grouped under its own key (not merged with a requested + # run, not dropped silently). yield_pipeline_status's + # tasks_by_run_id.get(run_id, []) lookup means it's safely ignored + # by the caller. + self.assertIn("run_stray", result) + self.assertEqual(len(result["run_stray"]), 1) + self.assertEqual(result["run_stray"][0].task_id, "task_a") + + def test_yield_pipeline_status_chunks_run_ids(self): + """ + Defense-in-depth: even though run_ids is already bounded by + numberOfStatus upstream, yield_pipeline_status must chunk the calls + to get_task_instances by _TASK_INSTANCE_RUN_ID_CHUNK_SIZE so that + we never send an unbounded IN(...) list to the DB and so that a + failed chunk does not wipe out the rest of the DAG's statuses. + + With 125 eligible runs and a chunk size of 50 we expect exactly + 3 calls (50 + 50 + 25) to get_task_instances and 125 yielded + pipeline statuses. + """ + from unittest.mock import MagicMock, patch + + from metadata.ingestion.source.pipeline.airflow import ( + metadata as airflow_module, + ) + + total_runs = 125 + chunk_size = 50 + expected_calls = 3 + + dag_runs = [] + for i in range(total_runs): + dag_run = MagicMock() + dag_run.dag_id = "my_dag" + dag_run.run_id = f"run_{i}" + dag_run.state = "success" + dag_run.logical_date = None + dag_run.start_date = None + dag_runs.append(dag_run) + + pipeline_details = MagicMock() + pipeline_details.dag_id = "my_dag" + pipeline_details.tasks = [AirflowTask(task_id="t1")] + + context_value = MagicMock() + context_value.task_names = ["t1"] + context_value.pipeline_service = "svc" + context_value.pipeline = "my_dag" + + bulk_call_log = [] + + def fake_get_task_instances(dag_id, run_ids, serialized_tasks): + bulk_call_log.append(list(run_ids)) + return {run_id: [] for run_id in run_ids} + + with patch.object( + airflow_module, "_TASK_INSTANCE_RUN_ID_CHUNK_SIZE", chunk_size + ), patch.object( + self.airflow, "get_pipeline_status", return_value=dag_runs + ), patch.object( + self.airflow, "get_task_instances", side_effect=fake_get_task_instances + ), patch.object( + self.airflow, + "context", + MagicMock(get=MagicMock(return_value=context_value)), + ), patch.object( + self.airflow, "metadata", MagicMock() + ), patch( + "metadata.ingestion.source.pipeline.airflow.metadata.fqn.build", + return_value="svc.my_dag", + ), patch( + "metadata.ingestion.source.pipeline.airflow.metadata.datetime_to_ts", + return_value=1, + ): + results = list(self.airflow.yield_pipeline_status(pipeline_details)) + + # Exactly ceil(total_runs / chunk_size) bulk queries + self.assertEqual(len(bulk_call_log), expected_calls) + + # Every chunk respects the configured bound + for chunk in bulk_call_log: + self.assertLessEqual(len(chunk), chunk_size) + + # Chunk sizes for 125 with chunk_size=50 are 50, 50, 25 + self.assertEqual([len(c) for c in bulk_call_log], [50, 50, 25]) + + # Every eligible run_id is covered exactly once, in order + flattened = [run_id for chunk in bulk_call_log for run_id in chunk] + self.assertEqual(flattened, [f"run_{i}" for i in range(total_runs)]) + + # One PipelineStatus is yielded per eligible DagRun + self.assertEqual(len(results), total_runs) + for either in results: + self.assertIsNone(either.left) + self.assertIsNotNone(either.right) + + def test_yield_pipeline_status_chunk_failure_does_not_block_other_chunks(self): + """ + If one chunk's get_task_instances call raises, yield_pipeline_status + must log the failure and keep processing the remaining chunks. To + preserve the pre-PR safe-fallback behaviour, the failed chunk's runs + still produce PipelineStatus objects with empty task lists (instead + of being silently dropped) - matching the prior per-run loop where a + DB error produced empty tasks but runs were still emitted. + """ + from unittest.mock import MagicMock, patch + + from metadata.ingestion.source.pipeline.airflow import ( + metadata as airflow_module, + ) + + total_runs = 30 + chunk_size = 10 # -> 3 chunks of 10 + + dag_runs = [] + for i in range(total_runs): + dag_run = MagicMock() + dag_run.dag_id = "my_dag" + dag_run.run_id = f"run_{i}" + dag_run.state = "success" + dag_run.logical_date = None + dag_run.start_date = None + dag_runs.append(dag_run) + + pipeline_details = MagicMock() + pipeline_details.dag_id = "my_dag" + pipeline_details.tasks = [AirflowTask(task_id="t1")] + + context_value = MagicMock() + context_value.task_names = ["t1"] + context_value.pipeline_service = "svc" + context_value.pipeline = "my_dag" + + call_counter = {"n": 0} + + def fake_get_task_instances(dag_id, run_ids, serialized_tasks): + call_counter["n"] += 1 + # Fail the middle chunk only + if call_counter["n"] == 2: + raise RuntimeError("simulated chunk failure") + return {run_id: [] for run_id in run_ids} + + with patch.object( + airflow_module, "_TASK_INSTANCE_RUN_ID_CHUNK_SIZE", chunk_size + ), patch.object( + self.airflow, "get_pipeline_status", return_value=dag_runs + ), patch.object( + self.airflow, "get_task_instances", side_effect=fake_get_task_instances + ), patch.object( + self.airflow, + "context", + MagicMock(get=MagicMock(return_value=context_value)), + ), patch.object( + self.airflow, "metadata", MagicMock() + ), patch( + "metadata.ingestion.source.pipeline.airflow.metadata.fqn.build", + return_value="svc.my_dag", + ), patch( + "metadata.ingestion.source.pipeline.airflow.metadata.datetime_to_ts", + return_value=1, + ): + results = list(self.airflow.yield_pipeline_status(pipeline_details)) + + # All 3 chunks were attempted even though the middle one raised + self.assertEqual(call_counter["n"], 3) + + # All 30 statuses are emitted: good chunks with whatever tasks they + # returned, failed chunk with empty task lists. None dropped. + self.assertEqual(len(results), total_runs) + for either in results: + self.assertIsNone(either.left) + self.assertIsNotNone(either.right) + + yielded_run_ids = { + either.right.pipeline_status.executionId for either in results + } + self.assertEqual(yielded_run_ids, {f"run_{i}" for i in range(total_runs)}) + + # Runs in the failed middle chunk have empty taskStatus lists + failed_chunk_runs = {f"run_{i}" for i in range(10, 20)} + failed_statuses = [ + e.right.pipeline_status + for e in results + if e.right.pipeline_status.executionId in failed_chunk_runs + ] + self.assertEqual(len(failed_statuses), 10) + for status in failed_statuses: + self.assertEqual(status.taskStatus, [])