Skip to content

Commit 95fc11e

Browse files
Ignore redelivered message for already-running task (#64052)
Catch TaskAlreadyRunningError from the supervisor and raise Celery Ignore() to prevent the broker redelivery from being recorded as a task failure. Related: #58441
1 parent 059e9a4 commit 95fc11e

4 files changed

Lines changed: 92 additions & 19 deletions

File tree

devel-common/src/tests_common/test_utils/version_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
3838
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
3939
AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
4040
AIRFLOW_V_3_1_7_PLUS = get_base_airflow_version_tuple() >= (3, 1, 7)
41+
AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9)
4142
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
4243

4344
if AIRFLOW_V_3_1_PLUS:

providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
from sqlalchemy import select
4343

4444
from airflow.executors.base_executor import BaseExecutor
45-
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
45+
from airflow.providers.celery.version_compat import (
46+
AIRFLOW_V_3_0_PLUS,
47+
AIRFLOW_V_3_1_9_PLUS,
48+
AIRFLOW_V_3_2_PLUS,
49+
)
4650
from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, conf, timeout
4751
from airflow.utils.log.logging_mixin import LoggingMixin
4852
from airflow.utils.net import get_hostname
@@ -189,6 +193,7 @@ def on_celery_worker_ready(*args, **kwargs):
189193
# and deserialization for us
190194
@app.task(name="execute_workload")
191195
def execute_workload(input: str) -> None:
196+
from celery.exceptions import Ignore
192197
from pydantic import TypeAdapter
193198

194199
from airflow.executors import workloads
@@ -208,22 +213,35 @@ def execute_workload(input: str) -> None:
208213
base_url = f"http://localhost:8080{base_url}"
209214
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
210215

211-
if isinstance(workload, workloads.ExecuteTask):
212-
supervise(
213-
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
214-
ti=workload.ti, # type: ignore[arg-type]
215-
dag_rel_path=workload.dag_rel_path,
216-
bundle_info=workload.bundle_info,
217-
token=workload.token,
218-
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
219-
log_path=workload.log_path,
220-
)
221-
elif isinstance(workload, workloads.ExecuteCallback):
222-
success, error_msg = execute_callback_workload(workload.callback, log)
223-
if not success:
224-
raise RuntimeError(error_msg or "Callback execution failed")
225-
else:
226-
raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}")
216+
try:
217+
if isinstance(workload, workloads.ExecuteTask):
218+
supervise(
219+
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
220+
ti=workload.ti, # type: ignore[arg-type]
221+
dag_rel_path=workload.dag_rel_path,
222+
bundle_info=workload.bundle_info,
223+
token=workload.token,
224+
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
225+
log_path=workload.log_path,
226+
)
227+
elif isinstance(workload, workloads.ExecuteCallback):
228+
success, error_msg = execute_callback_workload(workload.callback, log)
229+
if not success:
230+
raise RuntimeError(error_msg or "Callback execution failed")
231+
else:
232+
raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}")
233+
except Exception as e:
234+
if AIRFLOW_V_3_1_9_PLUS:
235+
from airflow.sdk.exceptions import TaskAlreadyRunningError
236+
237+
if isinstance(e, TaskAlreadyRunningError):
238+
log.info("[%s] Task already running elsewhere, ignoring redelivered message", celery_task_id)
239+
# Raise Ignore() so Celery does not record a FAILURE result for this duplicate
240+
# delivery. Without this, the broker redelivering the message (e.g. after a
241+
# visibility timeout) would cause Celery to mark the task as failed, even though
242+
# the original worker is still executing it successfully.
243+
raise Ignore()
244+
raise
227245

228246

229247
if not AIRFLOW_V_3_0_PLUS:

providers/celery/src/airflow/providers/celery/version_compat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
2727

2828

2929
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
30+
AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9)
3031
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
3132

32-
__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_2_PLUS"]
33+
__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_9_PLUS", "AIRFLOW_V_3_2_PLUS"]

providers/celery/tests/unit/celery/executors/test_celery_executor.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@
4545
from tests_common.test_utils.config import conf_vars
4646
from tests_common.test_utils.dag import sync_dag_to_db
4747
from tests_common.test_utils.taskinstance import create_task_instance
48-
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
48+
from tests_common.test_utils.version_compat import (
49+
AIRFLOW_V_3_0_PLUS,
50+
AIRFLOW_V_3_1_9_PLUS,
51+
AIRFLOW_V_3_1_PLUS,
52+
AIRFLOW_V_3_2_PLUS,
53+
)
4954

5055
if AIRFLOW_V_3_0_PLUS:
5156
from airflow.models.dag_version import DagVersion
@@ -761,3 +766,51 @@ def test_celery_tasks_registered_on_import():
761766
assert "execute_command" in registered_tasks, (
762767
"execute_command must be registered for Airflow 2.x compatibility."
763768
)
769+
770+
771+
@pytest.mark.skipif(not AIRFLOW_V_3_1_9_PLUS, reason="TaskAlreadyRunningError requires Airflow 3.1.9+")
772+
def test_execute_workload_ignores_already_running_task():
773+
"""Test that execute_workload raises Celery Ignore when task is already running."""
774+
import importlib
775+
776+
from celery.exceptions import Ignore
777+
778+
from airflow.sdk.exceptions import TaskAlreadyRunningError
779+
780+
importlib.reload(celery_executor_utils)
781+
execute_workload_unwrapped = celery_executor_utils.execute_workload.__wrapped__
782+
783+
mock_current_task = mock.MagicMock()
784+
mock_current_task.request.id = "test-celery-task-id"
785+
mock_app = mock.MagicMock()
786+
mock_app.current_task = mock_current_task
787+
788+
with (
789+
mock.patch("airflow.sdk.execution_time.supervisor.supervise") as mock_supervise,
790+
mock.patch.object(celery_executor_utils, "app", mock_app),
791+
):
792+
mock_supervise.side_effect = TaskAlreadyRunningError("Task already running")
793+
794+
workload_json = """
795+
{
796+
"type": "ExecuteTask",
797+
"token": "test-token",
798+
"dag_rel_path": "test_dag.py",
799+
"bundle_info": {"name": "test-bundle", "version": null},
800+
"log_path": "test.log",
801+
"ti": {
802+
"id": "019bdec0-d353-7b68-abe0-5ac20fa75ad0",
803+
"dag_version_id": "019bdead-fdcd-78ab-a9f2-aba3b80fded2",
804+
"task_id": "test_task",
805+
"dag_id": "test_dag",
806+
"run_id": "test_run",
807+
"try_number": 1,
808+
"map_index": -1,
809+
"pool_slots": 1,
810+
"queue": "default",
811+
"priority_weight": 1
812+
}
813+
}
814+
"""
815+
with pytest.raises(Ignore):
816+
execute_workload_unwrapped(workload_json)

0 commit comments

Comments (0)