Skip to content

Commit 346dcc3

Browse files
fix: call task.uncancel() after catching CancelledError in shield loops (Python 3.11+) (#1523)
* fix: call task.uncancel() after catching CancelledError in shield loops On Python 3.11+, asyncio.Task tracks a cancellation counter via Task.cancelling()/Task.uncancel(). When CancelledError is caught inside a while-True/asyncio.shield loop without calling uncancel(), the counter stays elevated and Python re-throws CancelledError at every subsequent await. This causes: 1. Duplicate commands (e.g. RequestCancelExternalWorkflow) sent to the Temporal server 2. Spurious ERROR-level 'exception in shielded future' log lines from temporalio.worker._workflow_instance The fix adds task.uncancel() (guarded by hasattr for Python <=3.10 compatibility) after each CancelledError catch in all 6 affected shield loops: - run_activity() in _outbound_schedule_activity - run_child() in _outbound_start_child_workflow - start-wait loop in _outbound_start_child_workflow - operation_handle_fn() in _outbound_start_nexus_operation - start-wait loop in _outbound_start_nexus_operation - _signal_external_workflow Fixes #1504 * test: add coverage for task.uncancel() in shield loops Add three integration tests to verify the fix for the elevated cancellation counter issue on Python 3.11+ (cpython#93453): - test_workflow_uncancel_shield_activity: Verifies that cancelling a shielded activity does not produce duplicate cancel commands or spurious 'exception in shielded future' error logs. - test_workflow_uncancel_shield_child_workflow: Verifies that cancelling a shielded child workflow task produces exactly one RequestCancelExternalWorkflowExecution event in history (not duplicates from the elevated cancellation counter). - test_workflow_uncancel_shield_signal_external: Verifies that signalling an external workflow completes without spurious error logs from the shield loop. Addresses review comment on PR #1523. * test: add coverage for nexus operation shield loops * style: apply ruff formatting fixes * fix: use sys.version_info guard for task.uncancel() to satisfy type checkers Replace hasattr(t, 'uncancel') with sys.version_info >= (3, 11) guard, which is the established pattern in this file for version-specific APIs. This satisfies pyright/mypy on Python 3.10 where asyncio.Task lacks uncancel(). --------- Co-authored-by: tconley1428 <tconley1428@gmail.com>
1 parent 7ea54e6 commit 346dcc3

3 files changed

Lines changed: 328 additions & 47 deletions

File tree

temporalio/worker/_workflow_instance.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,6 +1924,13 @@ async def run_activity() -> Any:
19241924
raise
19251925
# Send a cancel request to the activity
19261926
handle._apply_cancel_command(self._add_command())
1927+
# Clear the cancellation counter on Python 3.11+ so the
1928+
# next await does not immediately re-raise CancelledError
1929+
if (
1930+
sys.version_info >= (3, 11)
1931+
and (t := asyncio.current_task()) is not None
1932+
):
1933+
t.uncancel() # type: ignore[union-attr]
19271934

19281935
# Create the handle and set as pending
19291936
handle = _ActivityHandle(self, input, run_activity())
@@ -2008,6 +2015,13 @@ async def run_child() -> Any:
20082015
return await asyncio.shield(handle._result_fut)
20092016
except asyncio.CancelledError:
20102017
apply_child_cancel_error()
2018+
# Clear the cancellation counter on Python 3.11+ so the
2019+
# next await does not immediately re-raise CancelledError
2020+
if (
2021+
sys.version_info >= (3, 11)
2022+
and (t := asyncio.current_task()) is not None
2023+
):
2024+
t.uncancel() # type: ignore[union-attr]
20112025

20122026
# Create the handle and set as pending
20132027
handle = _ChildWorkflowHandle(
@@ -2025,6 +2039,13 @@ async def run_child() -> Any:
20252039
return handle
20262040
except asyncio.CancelledError:
20272041
apply_child_cancel_error()
2042+
# Clear the cancellation counter on Python 3.11+ so the
2043+
# next await does not immediately re-raise CancelledError
2044+
if (
2045+
sys.version_info >= (3, 11)
2046+
and (t := asyncio.current_task()) is not None
2047+
):
2048+
t.uncancel() # type: ignore[union-attr]
20282049
if self._cancel_requested:
20292050
raise
20302051

@@ -2053,6 +2074,13 @@ async def operation_handle_fn() -> OutputT:
20532074
except asyncio.CancelledError:
20542075
cancel_command = self._add_command()
20552076
handle._apply_cancel_command(cancel_command)
2077+
# Clear the cancellation counter on Python 3.11+ so the
2078+
# next await does not immediately re-raise CancelledError
2079+
if (
2080+
sys.version_info >= (3, 11)
2081+
and (t := asyncio.current_task()) is not None
2082+
):
2083+
t.uncancel() # type: ignore[union-attr]
20562084

20572085
handle = _NexusOperationHandle(
20582086
self, self._next_seq("nexus_operation"), input, operation_handle_fn()
@@ -2067,6 +2095,13 @@ async def operation_handle_fn() -> OutputT:
20672095
except asyncio.CancelledError:
20682096
cancel_command = self._add_command()
20692097
handle._apply_cancel_command(cancel_command)
2098+
# Clear the cancellation counter on Python 3.11+ so the
2099+
# next await does not immediately re-raise CancelledError
2100+
if (
2101+
sys.version_info >= (3, 11)
2102+
and (t := asyncio.current_task()) is not None
2103+
):
2104+
t.uncancel() # type: ignore[union-attr]
20702105
if self._cancel_requested:
20712106
raise
20722107

@@ -2599,6 +2634,13 @@ async def _signal_external_workflow(
25992634
except asyncio.CancelledError:
26002635
cancel_command = self._add_command()
26012636
cancel_command.cancel_signal_workflow.seq = seq
2637+
# Clear the cancellation counter on Python 3.11+ so the
2638+
# next await does not immediately re-raise CancelledError
2639+
if (
2640+
sys.version_info >= (3, 11)
2641+
and (t := asyncio.current_task()) is not None
2642+
):
2643+
t.uncancel() # type: ignore[union-attr]
26022644

26032645
def _stack_trace(self) -> str:
26042646
stacks = []

tests/nexus/test_workflow_caller_cancellation_types.py

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import logging
23
import uuid
34
from dataclasses import dataclass, field
45
from datetime import datetime, timezone
@@ -9,6 +10,7 @@
910
import pytest
1011

1112
import temporalio.nexus._operation_handlers
13+
import temporalio.worker._workflow_instance
1214
from temporalio import exceptions, nexus, workflow
1315
from temporalio.api.enums.v1 import EventType
1416
from temporalio.client import (
@@ -20,7 +22,7 @@
2022
from temporalio.common import WorkflowIDConflictPolicy
2123
from temporalio.testing import WorkflowEnvironment
2224
from temporalio.worker import Worker
23-
from tests.helpers import assert_eventually
25+
from tests.helpers import LogCapturer, assert_eventually
2426
from tests.helpers.nexus import make_nexus_endpoint_name
2527

2628

@@ -268,54 +270,72 @@ async def test_cancellation_type(
268270

269271
client = env.client
270272

271-
async with Worker(
272-
client,
273-
task_queue=str(uuid.uuid4()),
274-
workflows=[CallerWorkflow, HandlerWorkflow],
275-
nexus_service_handlers=[ServiceHandler()],
276-
) as worker:
277-
await env.create_nexus_endpoint(
278-
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
279-
)
273+
log_capturer = LogCapturer()
274+
with log_capturer.logs_captured(
275+
temporalio.worker._workflow_instance.logger, level=logging.WARNING
276+
):
277+
async with Worker(
278+
client,
279+
task_queue=str(uuid.uuid4()),
280+
workflows=[CallerWorkflow, HandlerWorkflow],
281+
nexus_service_handlers=[ServiceHandler()],
282+
) as worker:
283+
await env.create_nexus_endpoint(
284+
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
285+
)
280286

281-
# Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
282-
# token
283-
with_start_workflow = WithStartWorkflowOperation(
284-
CallerWorkflow.run,
285-
Input(
286-
endpoint=make_nexus_endpoint_name(worker.task_queue),
287-
cancellation_type=cancellation_type,
288-
),
289-
id=test_context.caller_workflow_id,
290-
task_queue=worker.task_queue,
291-
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
292-
)
287+
# Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
288+
# token
289+
with_start_workflow = WithStartWorkflowOperation(
290+
CallerWorkflow.run,
291+
Input(
292+
endpoint=make_nexus_endpoint_name(worker.task_queue),
293+
cancellation_type=cancellation_type,
294+
),
295+
id=test_context.caller_workflow_id,
296+
task_queue=worker.task_queue,
297+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
298+
)
293299

294-
operation_token = await client.execute_update_with_start_workflow(
295-
CallerWorkflow.get_operation_token,
296-
start_workflow_operation=with_start_workflow,
297-
)
298-
handler_wf = (
299-
nexus.WorkflowHandle[None]
300-
.from_token(operation_token)
301-
._to_client_workflow_handle(client)
302-
)
303-
caller_wf = await with_start_workflow.workflow_handle()
304-
305-
if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
306-
await check_behavior_for_abandon(caller_wf, handler_wf)
307-
elif cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL:
308-
await check_behavior_for_try_cancel(caller_wf, handler_wf)
309-
elif (
310-
cancellation_type == workflow.NexusOperationCancellationType.WAIT_REQUESTED
311-
):
312-
await check_behavior_for_wait_cancellation_requested(caller_wf, handler_wf)
313-
elif (
314-
cancellation_type == workflow.NexusOperationCancellationType.WAIT_COMPLETED
315-
):
316-
await check_behavior_for_wait_cancellation_completed(caller_wf, handler_wf)
317-
else:
318-
pytest.fail(f"Invalid cancellation type: {cancellation_type}")
300+
operation_token = await client.execute_update_with_start_workflow(
301+
CallerWorkflow.get_operation_token,
302+
start_workflow_operation=with_start_workflow,
303+
)
304+
handler_wf = (
305+
nexus.WorkflowHandle[None]
306+
.from_token(operation_token)
307+
._to_client_workflow_handle(client)
308+
)
309+
caller_wf = await with_start_workflow.workflow_handle()
310+
311+
if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
312+
await check_behavior_for_abandon(caller_wf, handler_wf)
313+
elif (
314+
cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL
315+
):
316+
await check_behavior_for_try_cancel(caller_wf, handler_wf)
317+
elif (
318+
cancellation_type
319+
== workflow.NexusOperationCancellationType.WAIT_REQUESTED
320+
):
321+
await check_behavior_for_wait_cancellation_requested(
322+
caller_wf, handler_wf
323+
)
324+
elif (
325+
cancellation_type
326+
== workflow.NexusOperationCancellationType.WAIT_COMPLETED
327+
):
328+
await check_behavior_for_wait_cancellation_completed(
329+
caller_wf, handler_wf
330+
)
331+
else:
332+
pytest.fail(f"Invalid cancellation type: {cancellation_type}")
333+
334+
# Verify no spurious "exception in shielded future" error logs
335+
shielded_err = log_capturer.find_log("exception in shielded future")
336+
assert shielded_err is None, (
337+
f"Unexpected 'exception in shielded future' log: {shielded_err}"
338+
)
319339

320340

321341
async def check_behavior_for_abandon(

0 commit comments

Comments
 (0)