diff --git a/airflow-core/src/airflow/jobs/queues.py b/airflow-core/src/airflow/jobs/queues.py new file mode 100644 index 0000000000000..54f2a12e0d303 --- /dev/null +++ b/airflow-core/src/airflow/jobs/queues.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from asyncio import Lock as AsyncLock, Queue +from collections import OrderedDict, defaultdict, deque +from collections.abc import Iterable, Iterator +from threading import Lock +from typing import Generic, TypeVar + +K = TypeVar("K") +V = TypeVar("V") +KV = TypeVar("KV", bound=tuple) + + +class KeyedHeadQueue(Generic[K, KV]): + """ + A keyed queue that manages values per key in insertion order. + + Features: + - `popleft()` returns only the *first value* per key (in insertion order of keys). + - Once a key's first value is popped, that key will never yield in `popleft()` again. + - Remaining values for consumed keys are preserved. + - Iteration yields those leftover (key, value) pairs. 
+ + Example: + q = FirstValueQueue() + q.append(("task1", "event1")) + q.append(("task1", "event2")) + q.append(("task2", "eventA")) + + q.popleft() # ('task1', 'event1') + q.popleft() # ('task2', 'eventA') + + list(q) # [('task1', 'event2')] + """ + + def __init__(self) -> None: + self.__map: OrderedDict[K, deque[KV]] = OrderedDict() # key -> deque of values + self.__popped_keys: set[K] = set() # keys whose first value has been consumed + self._lock = Lock() + + @property + def _map(self) -> OrderedDict[K, list[KV]]: + with self._lock: + return OrderedDict((key, list(value)) for key, value in self.__map.items()) + + @property + def _popped_keys(self) -> set[K]: + with self._lock: + return set(self.__popped_keys) + + def get(self, key: K, default_value: list[KV] | None = None) -> list[KV] | None: + return list(self._map.get(key, default_value or [])) + + def extend(self, elements: Iterable[KV]) -> None: + for element in elements: + self.append(element) + + def append(self, element: KV) -> None: + """Append a (key, value) pair unless key already consumed.""" + key = element[0] + with self._lock: + if key not in self.__map: + self.__map[key] = deque() + self.__map[key].append(element) + + def popleft(self) -> KV: + """ + Pop the *first inserted value* for the next key in order. + + Raises IndexError if all first values have been popped. + """ + with self._lock: + for key, values in self.__map.items(): + if key not in self.__popped_keys: + value = values.popleft() + self.__popped_keys.add(key) + if not values: + del self.__map[key] + return value + raise IndexError("pop from empty KeyedHeadQueue") + + def popall(self) -> tuple[K, list[KV]]: + """ + Pop all values for the first unconsumed key (in insertion order). + + Marks the key as consumed. + Raises IndexError if no keys remain. 
+ """ + with self._lock: + for key in self.__map.keys(): + if key not in self.__popped_keys: + values = list(self.__map.pop(key, [])) + self.__popped_keys.add(key) + return key, values + + raise IndexError("pop from empty KeyedHeadQueue") + + def __contains__(self, key: K) -> bool: + return key in self._map + + def __iter__(self) -> Iterator[tuple[K, KV]]: + """Iterate over leftover (key, value) pairs in a snapshot, so concurrent appends during iteration are not visible.""" + for key, values in self._map.items(): + for value in values: + yield key, value + + def __len__(self) -> int: + """Count remaining values available.""" + with self._lock: + return sum(len(value) for value in self.__map.values()) + + def __bool__(self) -> bool: + """Count of keys that still have their first value available.""" + with self._lock: + if not sum(1 for key in self.__map if key not in self.__popped_keys) > 0: + self.__popped_keys.clear() + return False + return True + + def keys(self) -> list[K]: + """Keys still waiting for their first value to be popped.""" + with self._lock: + return [key for key in self.__map.keys() if key not in self.__popped_keys] + + +class PartitionedQueue(Generic[K, V], defaultdict[K, Queue[tuple[K, V]]]): + """ + Dict-like container where each key maps to an asyncio.Queue. + + Tracks sizes safely for concurrent access. + Provides put(item) and popleft(). + Uses a total counter to make __bool__ O(1). + Supports both async and threading locks. 
+ """ + + def __init__(self, maxsize: int = 0) -> None: + super().__init__(lambda: Queue(maxsize=maxsize)) + self.maxsize = maxsize + self._async_locks: dict[K, AsyncLock] = defaultdict(AsyncLock) + self._locks: dict[K, Lock] = defaultdict(Lock) + self._sizes: dict[K, int] = defaultdict(int) # track sizes per key + self._total_size: int = 0 # total items across all queues + + def __bool__(self) -> bool: + return self._total_size > 0 + + async def put(self, item: tuple[K, V]) -> None: + key = item[0] + queue = self[key] + async with self._async_locks[key]: + await queue.put(item) + with self._locks[key]: + self._sizes[key] += 1 + self._total_size += 1 + + def popleft(self) -> tuple[K, V]: + """Pop an item from the first non-empty queue synchronously (non-blocking) using thread lock.""" + for key, queue in list(self.items()): + with self._locks[key]: + if self._sizes[key] > 0: + item = queue.get_nowait() # won't raise if size > 0 + self._sizes[key] -= 1 + self._total_size -= 1 + return item + raise StopIteration diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 44f965890423f..b9709dc78526a 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -51,6 +51,7 @@ from airflow.executors.workloads.task import TaskInstanceDTO from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat +from airflow.jobs.queues import KeyedHeadQueue, PartitionedQueue from airflow.models.dagbag import DBDagBag from airflow.models.trigger import Trigger from airflow.observability.metrics import stats_utils @@ -104,6 +105,7 @@ from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI logger = logging.getLogger(__name__) +maxsize = conf.getint("triggerer", "max_number_of_events_per_trigger", fallback=1) tracer = trace.get_tracer(__name__) @@ -256,7 +258,7 @@ class 
TriggerStateChanges(BaseModel): Field(default=None), ] # Format of list[str] is the exc traceback format - failures: list[tuple[int, list[str] | None]] | None = None + failures: list[tuple[int, tuple[str, dict[str, Any]] | None, list[str] | None]] | None = None finished: list[int] | None = None class TriggerStateSync(BaseModel): @@ -413,10 +415,15 @@ class TriggerRunnerSupervisor(WatchedSubprocess): creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False) # Outbound queue of events - events: deque[tuple[int, TriggerEvent]] = attrs.field(factory=deque, init=False) + events: KeyedHeadQueue[int, tuple[int, TriggerEvent]] = attrs.field(factory=KeyedHeadQueue, init=False) # Outbound queue of failed triggers - failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False) + failed_triggers: KeyedHeadQueue[int, tuple[int, tuple[str, dict[str, Any]] | None, list[str] | None]] = ( + attrs.field(factory=KeyedHeadQueue, init=False) + ) + + # Outbound queue of finished triggers + finished_triggers: set = attrs.field(factory=set, init=False) def is_alive(self) -> bool: # Set by `_service_subprocess` in the loop @@ -464,6 +471,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r for id in msg.finished or (): self.running_triggers.discard(id) self.cancelling_triggers.discard(id) + self.finished_triggers.add(id) if factory := self.logger_cache.pop(id, None): factory.upload_to_remote() # Need to close the FD explicitly, as it is not closed when logger is removed. 
@@ -621,17 +629,35 @@ def load_triggers(self): def handle_events(self): """Dispatch outbound events to the Trigger model which pushes them to the relevant task instances.""" - while self.events: - # Get the event and its trigger ID - trigger_id, event = self.events.popleft() - # Tell the model to wake up its tasks - Trigger.submit_event(trigger_id=trigger_id, event=event) - # Emit stat event - Stats.incr("triggers.succeeded") + if self.events: + with create_session() as session: + while self.events: + trigger_id, event = self.events.popleft() + is_last_event = trigger_id not in self.events + remaining_events = len(self.events.get(trigger_id, [])) + log.info( + "Trigger %s has %s remaining events and %s running triggers: %s", + trigger_id, + remaining_events, + len(self.running_triggers), + len(self.running_triggers), + ) + + # Tell the model to wake up its tasks + if Trigger.submit_event( + trigger_id=trigger_id, event=event, is_last_event=is_last_event, session=session + ): + # This is temporary logging to ease debugging, will be omitted in Airflow code base + log.info("Event %s handled for trigger %s", event.payload, trigger_id) + # Emit stat event + Stats.incr("triggers.succeeded") + else: + self.events.append((trigger_id, event)) def clean_unused(self): """Clean out unused or finished triggers.""" - Trigger.clean_unused() + Trigger.clean_unused(self.finished_triggers.copy()) + self.finished_triggers.clear() def handle_failed_triggers(self): """ @@ -639,12 +665,27 @@ def handle_failed_triggers(self): Task Instances that depend on them need failing. 
""" - while self.failed_triggers: - # Tell the model to fail this trigger's deps - trigger_id, saved_exc = self.failed_triggers.popleft() - Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc) - # Emit stat event - Stats.incr("triggers.failed") + if self.failed_triggers: + log.info("handle_failed_triggers: %d", len(self.failed_triggers)) + with create_session() as session: + while self.failed_triggers: + trigger_id, trigger, saved_exc = self.failed_triggers.popleft() + + # Tell the model to fail this trigger's deps + if trigger_id not in self.events and Trigger.submit_failure( + trigger_id=trigger_id, trigger=trigger, exc=saved_exc, session=session + ): + log.warning("Trigger %s has failed: %s", trigger_id, saved_exc) + # Emit stat event + Stats.incr("triggers.failed") + else: + log.warning( + "Trigger %s has failed but is still processing %d remaining events, so we waiting a bit...", + trigger_id, + len(self.events.get(trigger_id)), + ) + self.failed_triggers.append((trigger_id, trigger, saved_exc)) + session.flush() def emit_metrics(self): DualStatsManager.gauge( @@ -858,6 +899,7 @@ class TriggerDetails(TypedDict): is_watcher: bool name: str events: int + trigger: tuple[str, dict[str, Any]] | None @attrs.define(kw_only=True) @@ -930,10 +972,10 @@ class TriggerRunner: to_cancel: deque[int] # Outbound queue of events - events: deque[tuple[int, TriggerEvent]] + events: PartitionedQueue[int, TriggerEvent] # Outbound queue of failed triggers - failed_triggers: deque[tuple[int, BaseException | None]] + failed_triggers: KeyedHeadQueue[int, tuple[int, tuple[str, dict[str, Any]] | None, BaseException | None]] # Should-we-stop flag stop: bool = False @@ -950,8 +992,8 @@ def __init__(self): self.trigger_cache = {} self.to_create = deque() self.to_cancel = deque() - self.events = deque() - self.failed_triggers = deque() + self.events = PartitionedQueue(maxsize=maxsize) + self.failed_triggers = KeyedHeadQueue() self.job_id = None self._stop_event = None @@ -1061,7 
+1103,7 @@ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance: except BaseException as e: # Either the trigger code or the path to it is bad. Fail the trigger. self.log.error("Trigger failed to load code", error=e, classpath=workload.classpath) - self.failed_triggers.append((trigger_id, e)) + self.failed_triggers.append((trigger_id, None, e)) continue # Loading the trigger class could have been expensive. Lets give other things a chance to run! @@ -1093,7 +1135,7 @@ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance: trigger_instance = trigger_class(**deserialised_kwargs) except TypeError as err: self.log.error("Trigger failed to inflate", error=err) - self.failed_triggers.append((trigger_id, err)) + self.failed_triggers.append((trigger_id, None, err)) continue trigger_instance.trigger_id = trigger_id trigger_instance.triggerer_job_id = self.job_id @@ -1118,8 +1160,13 @@ async def cancel_triggers(self): while self.to_cancel: trigger_id = self.to_cancel.popleft() if trigger_id in self.triggers: - # We only delete if it did not exit already - self.triggers[trigger_id]["task"].cancel() + # We only cancel if it did not exit already + if trigger_id not in self.failed_triggers: + await self.log.ainfo("No need to cancel trigger %s yet...", trigger_id) + elif not self.triggers[trigger_id]["task"].done(): + await self.log.ainfo("Cancelling trigger %s", trigger_id) + self.triggers[trigger_id]["task"].cancel() + pass await asyncio.sleep(0) async def cleanup_finished_triggers(self) -> list[int]: @@ -1130,13 +1177,19 @@ async def cleanup_finished_triggers(self) -> list[int]: """ finished_ids: list[int] = [] for trigger_id, details in list(self.triggers.items()): - if details["task"].done(): + await self.log.ainfo( + "trigger_id %s is %s.", trigger_id, "done" if details["task"].done() else "not done" + ) + if details["task"].done() and trigger_id not in self.events: finished_ids.append(trigger_id) # Check to see if it exited for good reasons saved_exc = 
None try: result = details["task"].result() - except (asyncio.CancelledError, SystemExit, KeyboardInterrupt): + except (asyncio.CancelledError, SystemExit, KeyboardInterrupt) as e: + await self.log.aexception( + "Trigger %s exited with cancelled error %s", details["name"], e, trigger_id=trigger_id + ) # These are "expected" exceptions and we stop processing here # If we don't, then the system requesting a trigger be removed - # which turns into CancelledError - results in a failure. @@ -1144,14 +1197,15 @@ async def cleanup_finished_triggers(self) -> list[int]: continue except BaseException as e: # This is potentially bad, so log it. - self.log.exception( + await self.log.aexception( "Trigger %s exited with error %s", details["name"], e, trigger_id=trigger_id ) saved_exc = e + self.failed_triggers.append((trigger_id, details.get("trigger"), saved_exc)) else: # See if they foolishly returned a TriggerEvent if isinstance(result, TriggerEvent): - self.log.error( + await self.log.aerror( "Trigger returned a TriggerEvent rather than yielding it", trigger=details["name"], trigger_id=trigger_id, @@ -1159,13 +1213,13 @@ async def cleanup_finished_triggers(self) -> list[int]: # See if this exited without sending an event, in which case # any task instances depending on it need to be failed if details["events"] == 0: - self.log.error( + await self.log.aerror( "Trigger exited without sending an event. Dependent tasks will be failed.", name=details["name"], trigger_id=trigger_id, ) # TODO: better formatting of the exception? 
- self.failed_triggers.append((trigger_id, saved_exc)) + self.failed_triggers.append((trigger_id, details.get("trigger"), saved_exc)) del self.triggers[trigger_id] await asyncio.sleep(0) return finished_ids @@ -1173,16 +1227,16 @@ async def cleanup_finished_triggers(self) -> list[int]: def process_trigger_events(self, finished_ids: list[int]) -> messages.TriggerStateChanges: # Copy out of our dequeues in threadsafe manner to sync state with parent events_to_send: list[tuple[int, DiscrimatedTriggerEvent]] = [] - failures_to_send: list[tuple[int, list[str] | None]] = [] + failures_to_send: list[tuple[int, tuple[str, dict[str, Any]] | None, list[str] | None]] = [] while self.events: trigger_id, trigger_event = self.events.popleft() events_to_send.append((trigger_id, trigger_event)) while self.failed_triggers: - trigger_id, exc = self.failed_triggers.popleft() + trigger_id, trigger, exc = self.failed_triggers.popleft() tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None - failures_to_send.append((trigger_id, tb)) + failures_to_send.append((trigger_id, trigger, tb)) return messages.TriggerStateChanges( events=events_to_send if events_to_send else None, @@ -1204,7 +1258,7 @@ def sanitize_trigger_events(self, msg: messages.TriggerStateChanges) -> messages trigger_id, trigger_event, ) - self.failed_triggers.append((trigger_id, e)) + self.failed_triggers.append((trigger_id, None, e)) else: events_to_send.append((trigger_id, trigger_event)) @@ -1285,7 +1339,7 @@ async def run_trigger( bind_log_contextvars(trigger_id=trigger_id) name = self.triggers[trigger_id]["name"] - self.log.info("trigger %s starting", name) + await self.log.ainfo("trigger %s starting", name) with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span: try: if context is not None: @@ -1295,10 +1349,17 @@ async def run_trigger( await self.log.ainfo( "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event ) + await self.log.ainfo( + 
"%s size: %d / %d", + trigger_id, + self.events[trigger_id].qsize(), + self.events[trigger_id].maxsize, + ) self.triggers[trigger_id]["events"] += 1 - self.events.append((trigger_id, event)) + await self.events.put((trigger_id, event)) span.set_status(Status(StatusCode.OK)) except asyncio.CancelledError as e: + await self.log.aexception("trigger %s failed due to cancelled error", trigger_id) # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error # message about it if timeout := timeout_after: @@ -1310,6 +1371,10 @@ async def run_trigger( span.set_status(Status(StatusCode.OK), description=str(e)) raise except Exception as e: + await self.log.aexception("trigger %s failed", trigger_id) + # We serialize the trigger first before raising the exception, so that when the trigger is retryable, + # we can resume from the point where it failed when the scheduler recreates the trigger. + self.triggers[trigger_id]["trigger"] = trigger.serialize() span.set_status(Status(StatusCode.ERROR), description=str(e)) raise finally: diff --git a/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_next_trigger_id_to_task_instance_table.py b/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_next_trigger_id_to_task_instance_table.py new file mode 100644 index 0000000000000..262f1a9cbebec --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_next_trigger_id_to_task_instance_table.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add ``next_trigger_id`` column to ``task_instance`` table. + +Revision ID: 658517c60c7f +Revises: 1d6611b6ab7c +Create Date: 2025-12-26 12:07:05.849152 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "658517c60c7f" +down_revision = "1d6611b6ab7c" +branch_labels = None +depends_on = None +airflow_version = "3.2.0" + + +def upgrade(): + """Add ``next_trigger_id`` column to ``task_instance`` table.""" + op.add_column("task_instance", sa.Column("next_trigger_id", sa.Integer(), nullable=True)) + + +def downgrade(): + """Remove ``next_trigger_id`` column from ``task_instance`` table.""" + op.drop_column("task_instance", "next_trigger_id") diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index c93f0ed8e1e13..742800feff55f 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1257,6 +1257,15 @@ def recalculate(self) -> _UnfinishedStates: return schedulable_tis, callback + @classmethod + def get_dag_run(cls, dag_id: str, run_id: str, session: Session) -> DagRun | None: + return session.scalars( + select(DagRun).where( + DagRun.dag_id == dag_id, + DagRun.run_id == run_id, + ) + ).one_or_none() + @provide_session def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision: tis = self.get_task_instances(session=session, state=State.task_states) @@ -1277,8 +1286,13 @@ def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterab tis = 
list(_filter_tis_and_exclude_removed(self.get_dag(), tis)) - unfinished_tis = [t for t in tis if t.state in State.unfinished] finished_tis = [t for t in tis if t.state in State.finished] + uncompleted_tis = [ + t for t in finished_tis if t.next_trigger_id + ] # TODO: this was added to make AIP-88 work + unfinished_tis = [t for t in tis if t.state in State.unfinished] + unfinished_tis.extend(uncompleted_tis) + if unfinished_tis: schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES] self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(schedulable_tis)) @@ -1292,7 +1306,9 @@ def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterab # states, so we need to re-compute. if expansion_happened: changed_tis = True - new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished] + new_unfinished_tis = [ + t for t in unfinished_tis if t.state in State.unfinished and not t.next_trigger_id + ] finished_tis.extend(t for t in unfinished_tis if t.state in State.finished) unfinished_tis = new_unfinished_tis else: @@ -1477,6 +1493,12 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: return expanded_tis return () + def is_unmapped_task(ti: TI) -> bool: + from airflow.sdk.definitions.mappedoperator import MappedOperator + + # TODO: AIP-88 check why task is still MappedOperator even when not an unmapped task anymore + return isinstance(ti.task, MappedOperator) and ti.map_index == -1 + # Check dependencies. 
expansion_happened = False # Set of task ids for which was already done _revise_map_indexes_if_mapped @@ -1500,7 +1522,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if new_tis is not None: additional_tis.extend(new_tis) expansion_happened = True - if new_tis is None and schedulable.state in SCHEDULEABLE_STATES: + if not new_tis and schedulable.state in SCHEDULEABLE_STATES: # It's enough to revise map index once per task id, # checking the map index for each mapped task significantly slows down scheduling if schedulable.task.task_id not in revised_map_index_task_ids: @@ -1514,7 +1536,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: # _revise_map_indexes_if_mapped might mark the current task as REMOVED # after calculating mapped task length, so we need to re-check # the task state to ensure it's still schedulable - if schedulable.state in SCHEDULEABLE_STATES: + if not is_unmapped_task(schedulable) and schedulable.state in SCHEDULEABLE_STATES: ready_tis.append(schedulable) # Check if any ti changed state diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 4c2137a5343cf..c4658a8cf9481 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -557,6 +557,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload): # The trigger to resume on if we are in state DEFERRED trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + next_trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True) # Optional timeout utcdatetime for the trigger (past this, we'll fail) trigger_timeout: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) @@ -1615,7 +1616,21 @@ def defer_task(self, session: Session = NEW_SESSION) -> bool: assert self.start_date assert isinstance(self.task, Operator) - if start_trigger_args := self.start_trigger_args: + # Remaining task expansion still 
running from previous triggerer so reschedule + if self.context_carrier and "trigger" in self.context_carrier: + trigger_classpath, trigger_kwargs = self.context_carrier.pop("trigger", (None, None)) + + self.log.info( + "Creating trigger from context_carrier for task_id %s: %s", self.task_id, trigger_kwargs + ) + trigger_row = Trigger( + classpath=trigger_classpath, + kwargs=trigger_kwargs or {}, + ) + elif not (not self.next_trigger_id and (start_trigger_args := self.start_trigger_args)): + # self.log.warning("Couldn't create trigger from start_from_trigger for task_id %s thus could not be deferred!", self.task_id) + return False + else: trigger_kwargs = start_trigger_args.trigger_kwargs or {} timeout = start_trigger_args.timeout @@ -1630,32 +1645,31 @@ def defer_task(self, session: Session = NEW_SESSION) -> bool: kwargs=trigger_kwargs, ) - # First, make the trigger entry - session.add(trigger_row) - session.flush() - - # Then, update ourselves so it matches the deferral request - # Keep an eye on the logic in `check_and_change_state_before_execution()` - # depending on self.next_method semantics - self.state = TaskInstanceState.DEFERRED - self.trigger_id = trigger_row.id - self.next_method = start_trigger_args.next_method - self.next_kwargs = start_trigger_args.next_kwargs or {} - - # If an execution_timeout is set, set the timeout to the minimum of - # it and the trigger timeout - if execution_timeout := self.task.execution_timeout: - if self.trigger_timeout: - self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) - else: - self.trigger_timeout = self.start_date + execution_timeout - self.start_date = timezone.utcnow() - if self.state != TaskInstanceState.UP_FOR_RESCHEDULE: - self.try_number += 1 - if self.test_mode: - _add_log(event=self.state, task_instance=self, session=session) - return True - return False + # First, make the trigger entry + session.add(trigger_row) + session.flush() + + # Then, update ourselves so it matches 
the deferral request + # Keep an eye on the logic in `check_and_change_state_before_execution()` + # depending on self.next_method semantics + self.state = TaskInstanceState.DEFERRED + self.trigger_id = trigger_row.id + self.next_method = start_trigger_args.next_method + self.next_kwargs = start_trigger_args.next_kwargs or {} + + # If an execution_timeout is set, set the timeout to the minimum of + # it and the trigger timeout + if execution_timeout := self.task.execution_timeout: + if self.trigger_timeout: + self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) + else: + self.trigger_timeout = self.start_date + execution_timeout + self.start_date = timezone.utcnow() + if self.state != TaskInstanceState.UP_FOR_RESCHEDULE: + self.try_number += 1 + if self.test_mode: + _add_log(event=self.state, task_instance=self, session=session) + return True @classmethod def fetch_handle_failure_context( @@ -1694,7 +1708,11 @@ def fetch_handle_failure_context( if not test_mode: session.add(Log(TaskInstanceState.FAILED.value, ti)) - ti.clear_next_method_args() + # Only clear next method args if first invocation on triggerer failed + if ( + not ti.next_trigger_id + ): # TODO: this check is very important, otherwise failed triggers will clear the XCom's + ti.clear_next_method_args() # Set state correctly and figure out how to log it and decide whether # to email diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index 60486b8ce864b..cfd03497de91b 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any from opentelemetry import trace -from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select +from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, or_, select from sqlalchemy.orm import Mapped, mapped_column from 
airflow._shared.observability.traces import new_task_run_carrier @@ -125,6 +125,16 @@ def variant(self) -> TaskMapVariant: return TaskMapVariant.LIST return TaskMapVariant.DICT + @classmethod + def get_task_map_length(cls, dag_id: str, task_id: str, run_id: str, session: Session) -> int | None: + return session.scalar( + select(TaskMap.length).where( + TaskMap.dag_id == dag_id, + TaskMap.task_id == task_id, + TaskMap.run_id == run_id, + ) + ) + @classmethod def expand_mapped_task( cls, @@ -155,7 +165,13 @@ def expand_mapped_task( ) try: - total_length: int | None = get_mapped_ti_count(task, run_id, session=session) + total_length: int | None = TaskMap.get_task_map_length( + dag_id=task.dag_id, task_id=task.task_id, run_id=run_id, session=session + ) + if not total_length: + total_length = get_mapped_ti_count(task, run_id, session=session) + else: + task = next((op for op in task.get_direct_relatives(upstream=False) if op.is_mapped), task) except NotFullyPopulated as e: if not task.dag or not task.dag.partial: task.log.error( @@ -167,16 +183,19 @@ def expand_mapped_task( total_length = None state: str | None = None - unmapped_ti: TaskInstance | None = session.scalars( - select(TaskInstance).where( - TaskInstance.dag_id == task.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == run_id, - TaskInstance.map_index == -1, - or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), - ) - ).one_or_none() - + unmapped_ti: TaskInstance | None = ( + session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == -1, + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), + ) + ).one_or_none() + if task and task.is_mapped + else None + ) all_expanded_tis: list[TaskInstance] = [] if unmapped_ti: @@ -226,15 +245,8 @@ def expand_mapped_task( indexes_to_map: Iterable[int] = () else: # Only create 
"missing" ones. - current_max_mapping = ( - session.scalar( - select(func.max(TaskInstance.map_index)).where( - TaskInstance.dag_id == task.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == run_id, - ) - ) - or 0 + current_max_mapping = TaskInstance.get_current_max_mapping( + dag_id=task.dag_id, task_id=task.task_id, run_id=run_id, session=session ) indexes_to_map = range(current_max_mapping + 1, total_length) diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index da78eede343dd..1625d92f09184 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -24,7 +24,7 @@ from traceback import format_exception from typing import TYPE_CHECKING, Any -from sqlalchemy import Integer, String, Text, delete, func, or_, select, update +from sqlalchemy import Integer, String, Text, delete, exists, func, or_, select, update from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Mapped, Session, mapped_column, relationship, selectinload from sqlalchemy.sql.functions import coalesce @@ -40,7 +40,7 @@ from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, with_row_locks -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: from sqlalchemy import Row @@ -218,13 +218,26 @@ def fetch_trigger_ids_with_non_task_associations(cls, session: Session = NEW_SES @classmethod @provide_session - def clean_unused(cls, session: Session = NEW_SESSION) -> None: + def clean_unused(cls, finished_triggers: set | None = None, session: Session = NEW_SESSION) -> None: """ Delete all triggers that have no tasks dependent on them and are not associated to an asset. 
Triggers have a one-to-many relationship to task instances, so we need to clean those up first. Afterward we can drop the triggers not referenced by anyone. """ + # TODO: AIP-88 should be moved into dedicated method called clean_finished + if finished_triggers: + session.execute( + update(TaskInstance) + .where( + or_( + TaskInstance.trigger_id.in_(finished_triggers), + TaskInstance.next_trigger_id.in_(finished_triggers), + ) + ) + .values(trigger_id=None, next_trigger_id=None) + ) + # Update all task instances with trigger IDs that are not DEFERRED to remove them for attempt in run_with_db_retries(): with attempt: @@ -237,12 +250,15 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: ) # Get all triggers that have no task instances, assets, or callbacks depending on them and delete them - ids = ( - select(cls.id) - .where(~cls.assets.any(), ~cls.callback.has()) - .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) - .group_by(cls.id) - .having(func.count(TaskInstance.trigger_id) == 0) + ids = select(Trigger.id).where( + # no TIs referencing trigger_id (any state; stale references were cleared above) + ~exists().where(TaskInstance.trigger_id == Trigger.id), + # no TIs referencing next_trigger_id (any state) + ~exists().where(TaskInstance.next_trigger_id == Trigger.id), + # no assets + ~cls.assets.any(), + # no callback + ~cls.callback.has(), ) if get_dialect_name(session) == "mysql": # MySQL doesn't support DELETE with JOIN, so we need to do it in two steps @@ -257,38 +273,79 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: @classmethod @provide_session - def submit_event(cls, trigger_id, event: TriggerEvent, session: Session = NEW_SESSION) -> None: + def submit_event( + cls, + trigger_id, + event: TriggerEvent, + is_last_event: bool = True, + session: Session = NEW_SESSION, + ) -> bool: """ Fire an event. Resume all tasks that were in deferred state. Send an event to all assets associated to the trigger. 
""" - # Resume deferred tasks - for task_instance in session.scalars( + task_instances = session.scalars( select(TaskInstance).where( - TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED + or_( + TaskInstance.trigger_id == trigger_id, + TaskInstance.next_trigger_id == trigger_id, + # We need to do this as once we run the next_method, trigger_id is removed from TaskInstance + ), + TaskInstance.state.in_([TaskInstanceState.DEFERRED, TaskInstanceState.SUCCESS]), + # TODO: SUCCESS might become COMPLETED + ) + ).all() + + log.info("task_instances: %d", len(task_instances)) + + if task_instances: + log.info("Handle event for trigger %s", trigger_id) + + # Resume deferred tasks + for task_instance in task_instances: + handle_event_submit( + event, + trigger_id=trigger_id, + task_instance=task_instance, + is_last_event=is_last_event, + session=session, + ) + else: + log.debug( + "No more task instances found for trigger %s! Stop processing events for trigger %s", + trigger_id, + trigger_id, ) - ): - handle_event_submit(event, task_instance=task_instance, session=session) # Send an event to assets trigger = session.scalars(select(cls).where(cls.id == trigger_id)).one_or_none() - if trigger is None: - # Already deleted for some reason - return - for asset in trigger.assets: - AssetManager.register_asset_change( - asset=asset.to_serialized(), - extra={"from_trigger": True, "payload": event.payload}, - session=session, - ) - if trigger.callback: - trigger.callback.handle_event(event, session) + + log.info( + "We should register asset changes for trigger_id %s for event %s", + trigger_id, + event, + ) + + if not trigger: + # TODO: check why Trigger disappears after first handled event, we need to FIX this + log.warning("Trigger %s was not found.", trigger_id) + else: + for asset in trigger.assets: + AssetManager.register_asset_change( + asset=asset.to_public(), + extra={"from_trigger": True, "payload": event.payload}, + session=session, + ) 
+ if trigger.callback: + trigger.callback.handle_event(event, session) + + return True if task_instances else False @classmethod @provide_session - def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None: + def submit_failure(cls, trigger_id, trigger: dict, exc=None, session: Session = NEW_SESSION) -> bool: """ When a trigger has failed unexpectedly, mark everything that depended on it as failed. @@ -300,6 +357,41 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> the runtime code understands as immediate-fail, and pack the error into next_kwargs. """ + if trigger: + unfinished_tis = session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where( + TaskInstance.next_trigger_id == trigger_id, + ~TaskInstance.state.in_(State.finished_dr_states), + ) + .execution_options(populate_existing=True) + ) + + log.debug("unfinished_tis: %d", unfinished_tis) + + if unfinished_tis == 0: + task_instances = list( + session.scalars( + select(TaskInstance).where( + TaskInstance.next_trigger_id == trigger_id, + # TaskInstance.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]), + ) + ) + ) + + log.debug("task_instances: %d", len(task_instances)) + + for task_instance in task_instances: + task_instance.next_trigger_id = None + task_instance.context_carrier = { + **(task_instance.context_carrier or {}), + **{"trigger": trigger}, + } + task_instance.set_state(TaskInstanceState.UP_FOR_RETRY) + return True + return False + for task_instance in session.scalars( select(TaskInstance).where( TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED @@ -321,6 +413,8 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> task_instance.state = TaskInstanceState.SCHEDULED task_instance.scheduled_dttm = timezone.utcnow() + return False + @classmethod @provide_session def ids_for_triggerer( @@ -452,7 +546,14 @@ def get_sorted_triggers( 
@singledispatch -def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, session: Session) -> None: +def handle_event_submit( + event: TriggerEvent, + *, + trigger_id: int, + task_instance: TaskInstance, + is_last_event: bool = True, + session: Session, +) -> None: """ Handle the submit event for a given task instance. @@ -492,6 +593,14 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses # Set the state of the task instance to scheduled task_instance.state = TaskInstanceState.SCHEDULED task_instance.scheduled_dttm = timezone.utcnow() + + if is_last_event: + task_instance.next_trigger_id = None + else: + log.info("trigger %s is not last event to be processed...", trigger_id) + task_instance.try_number = 0 + task_instance.next_trigger_id = trigger_id + session.flush() diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 9bc0608611b5a..10f7ed2cc4cc0 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -115,7 +115,7 @@ class MappedClassProtocol(Protocol): "3.0.3": "fe199e1abd77", "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", - "3.2.0": "1d6611b6ab7c", + "3.2.0": "658517c60c7f", } # Prefix used to identify tables holding data moved during migration. 
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 503a3f4834c36..57456da7585cd 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -353,7 +353,13 @@ class TestTriggerRunner: def test_run_inline_trigger_canceled(self, session) -> None: trigger_runner = TriggerRunner() trigger_runner.triggers = { - 1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0} + 1: { + "task": MagicMock(spec=asyncio.Task), + "is_watcher": False, + "name": "mock_name", + "events": 0, + "trigger": None, + } } mock_trigger = MagicMock(spec=BaseTrigger) mock_trigger.timeout_after = None @@ -367,7 +373,13 @@ def test_run_inline_trigger_canceled(self, session) -> None: def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None: trigger_runner = TriggerRunner() trigger_runner.triggers = { - 1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0} + 1: { + "task": MagicMock(spec=asyncio.Task), + "is_watcher": False, + "name": "mock_name", + "events": 0, + "trigger": None, + } } mock_trigger = MagicMock(spec=BaseTrigger) mock_trigger.run.side_effect = asyncio.CancelledError() diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 3fa09106ade7f..a3e5b52d941a5 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2476,6 +2476,7 @@ def test_refresh_from_db(self, create_task_instance): "trigger_id": None, "next_kwargs": None, "next_method": None, + "next_trigger_id": None, "updated_at": None, "task_display_name": "Test Refresh from DB Task", "dag_version_id": mock.ANY,