Skip to content

Commit 10dde32

Browse files
Conformance: signals/queries finish signal stays waiting on sdk-python 0.4.42 (#86)
1 parent 2e709e9 commit 10dde32

3 files changed

Lines changed: 177 additions & 58 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88

99
### Fixed
10+
- Signal and update receivers recorded while a condition wait is open now
11+
replay at that specific wait, so later signal-driven waits are not satisfied
12+
or consumed too early when no activity or timer result separates them.
1013
- Signal and update receivers recorded after an activity result now replay after
1114
the workflow consumes that activity result, so receiver-mutated state is not
1215
overwritten by deterministic post-activity setup before a `wait_condition`.

src/durable_workflow/workflow.py

Lines changed: 100 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,15 @@ class _ReplayState:
13881388
instance: Any
13891389

13901390

1391+
@dataclass
1392+
class _PendingReceiver:
1393+
result_index: int
1394+
kind: str
1395+
name: str
1396+
args: list[Any]
1397+
condition_wait_id: str | None = None
1398+
1399+
13911400
def _decode_history_result(
13921401
payload: dict[str, Any],
13931402
fallback_codec: str | None,
@@ -1943,10 +1952,14 @@ def _state(commands: list[Command]) -> _ReplayState:
19431952
return _ReplayState(outcome=ReplayOutcome(commands=commands), instance=instance)
19441953

19451954
resolved_results: list[Any] = []
1955+
# External receivers normally apply by resolved-result cursor. Receivers
1956+
# observed while a condition wait is open are pinned to that wait so
1957+
# sequential signal-driven waits do not collapse to the same cursor.
1958+
#
19461959
# (resolved_result_index_before_apply, receiver_kind, name, decoded_args) —
19471960
# external receivers apply before the generator consumes the resolved_result
19481961
# at the stored index, preserving history interleaving with activities.
1949-
pending_receivers: list[tuple[int, str, str, list[Any]]] = []
1962+
pending_receivers: list[_PendingReceiver] = []
19501963
# Ordered ``ConditionWaitOpened`` payloads, used by ``WaitCondition`` yields
19511964
# to match against their corresponding opened wait in history
19521965
# (Nth yield ↔ Nth opened).
@@ -1955,6 +1968,17 @@ def _state(commands: list[Command]) -> _ReplayState:
19551968
# in history, future server-recorded) or 'timed_out' (from a matching
19561969
# condition_timeout TimerFired event).
19571970
wait_resolutions: dict[str, str] = {}
1971+
open_condition_wait_ids: list[str] = []
1972+
1973+
def _current_condition_wait_id() -> str | None:
1974+
return open_condition_wait_ids[-1] if open_condition_wait_ids else None
1975+
1976+
def _close_condition_wait(wait_id: Any) -> None:
1977+
if not isinstance(wait_id, str) or not wait_id:
1978+
return
1979+
with contextlib.suppress(ValueError):
1980+
open_condition_wait_ids.remove(wait_id)
1981+
19581982
for ev in events:
19591983
etype = ev.get("event_type")
19601984
payload = ev.get("payload") or {}
@@ -1975,6 +1999,7 @@ def _state(commands: list[Command]) -> _ReplayState:
19751999
wait_id = payload.get("condition_wait_id")
19762000
if isinstance(wait_id, str) and wait_id:
19772001
wait_resolutions[wait_id] = "timed_out"
2002+
_close_condition_wait(wait_id)
19782003
continue
19792004
if timer_kind == "signal_timeout":
19802005
continue
@@ -1983,14 +2008,17 @@ def _state(commands: list[Command]) -> _ReplayState:
19832008
wait_id = payload.get("condition_wait_id")
19842009
if isinstance(wait_id, str) and wait_id:
19852010
wait_opened.append(dict(payload))
2011+
open_condition_wait_ids.append(wait_id)
19862012
elif etype == "ConditionWaitSatisfied":
19872013
wait_id = payload.get("condition_wait_id")
19882014
if isinstance(wait_id, str) and wait_id:
19892015
wait_resolutions[wait_id] = "satisfied"
2016+
_close_condition_wait(wait_id)
19902017
elif etype == "ConditionWaitTimedOut":
19912018
wait_id = payload.get("condition_wait_id")
19922019
if isinstance(wait_id, str) and wait_id:
19932020
wait_resolutions[wait_id] = "timed_out"
2021+
_close_condition_wait(wait_id)
19942022
elif etype in ("SideEffectRecorded", "ChildRunCompleted"):
19952023
resolved_results.append(
19962024
_decode_history_result(
@@ -2011,73 +2039,86 @@ def _state(commands: list[Command]) -> _ReplayState:
20112039
elif etype == "SignalReceived":
20122040
signal_name = payload.get("signal_name")
20132041
if isinstance(signal_name, str) and signal_name:
2014-
pending_receivers.append(
2015-
(
2016-
len(resolved_results),
2017-
"signal",
2018-
signal_name,
2019-
_decode_receiver_args(
2020-
ev,
2021-
receiver_kind="signal",
2022-
receiver_name=signal_name,
2023-
workflow_id=workflow_id,
2024-
run_id=run_id,
2025-
payload_codec=payload_codec,
2026-
external_storage=external_storage,
2027-
external_storage_cache=external_storage_cache,
2028-
),
2029-
)
2030-
)
2042+
pending_receivers.append(_PendingReceiver(
2043+
result_index=len(resolved_results),
2044+
kind="signal",
2045+
name=signal_name,
2046+
args=_decode_receiver_args(
2047+
ev,
2048+
receiver_kind="signal",
2049+
receiver_name=signal_name,
2050+
workflow_id=workflow_id,
2051+
run_id=run_id,
2052+
payload_codec=payload_codec,
2053+
external_storage=external_storage,
2054+
external_storage_cache=external_storage_cache,
2055+
),
2056+
condition_wait_id=_current_condition_wait_id(),
2057+
))
20312058
elif etype == "UpdateApplied":
20322059
update_name = payload.get("update_name")
20332060
if isinstance(update_name, str) and update_name:
2034-
pending_receivers.append(
2035-
(
2036-
len(resolved_results),
2037-
"update",
2038-
update_name,
2039-
_decode_receiver_args(
2040-
ev,
2041-
receiver_kind="update",
2042-
receiver_name=update_name,
2043-
workflow_id=workflow_id,
2044-
run_id=run_id,
2045-
payload_codec=payload_codec,
2046-
external_storage=external_storage,
2047-
external_storage_cache=external_storage_cache,
2048-
),
2049-
)
2050-
)
2061+
pending_receivers.append(_PendingReceiver(
2062+
result_index=len(resolved_results),
2063+
kind="update",
2064+
name=update_name,
2065+
args=_decode_receiver_args(
2066+
ev,
2067+
receiver_kind="update",
2068+
receiver_name=update_name,
2069+
workflow_id=workflow_id,
2070+
run_id=run_id,
2071+
payload_codec=payload_codec,
2072+
external_storage=external_storage,
2073+
external_storage_cache=external_storage_cache,
2074+
),
2075+
condition_wait_id=_current_condition_wait_id(),
2076+
))
20512077

20522078
signal_registry: dict[str, str] = getattr(workflow_cls, "__workflow_signals__", {}) or {}
20532079
update_registry: dict[str, str] = getattr(workflow_cls, "__workflow_updates__", {}) or {}
20542080

2081+
def _apply_receiver(receiver: _PendingReceiver) -> None:
2082+
if receiver.kind == "signal":
2083+
method_name = signal_registry.get(receiver.name)
2084+
if method_name is None:
2085+
return
2086+
else:
2087+
method_name = update_registry.get(receiver.name)
2088+
if method_name is None:
2089+
raise TypeError(f"unknown update {receiver.name!r} in workflow history")
2090+
handler = getattr(instance, method_name, None)
2091+
if handler is None:
2092+
if receiver.kind == "signal":
2093+
return
2094+
raise TypeError(f"update handler {receiver.name!r} is not available")
2095+
ctx.logger._set_replaying(True)
2096+
handler(*receiver.args)
2097+
2098+
def _receiver_due(receiver: _PendingReceiver, *, before_consuming_result: bool) -> bool:
2099+
if receiver.condition_wait_id is not None:
2100+
return False
2101+
return (
2102+
receiver.result_index < result_cursor
2103+
if before_consuming_result
2104+
else receiver.result_index <= result_cursor
2105+
)
2106+
20552107
def _apply_due_receivers(*, before_consuming_result: bool = False) -> None:
20562108
while pending_receivers:
2057-
receiver_index = pending_receivers[0][0]
2058-
due = (
2059-
receiver_index < result_cursor
2060-
if before_consuming_result
2061-
else receiver_index <= result_cursor
2062-
)
2063-
if not due:
2109+
receiver = pending_receivers[0]
2110+
if not _receiver_due(receiver, before_consuming_result=before_consuming_result):
20642111
break
2065-
_, kind, name, args = pending_receivers.pop(0)
2066-
if kind == "signal":
2067-
method_name = signal_registry.get(name)
2068-
if method_name is None:
2069-
continue
2070-
else:
2071-
method_name = update_registry.get(name)
2072-
if method_name is None:
2073-
raise TypeError(f"unknown update {name!r} in workflow history")
2074-
handler = getattr(instance, method_name, None)
2075-
if handler is None:
2076-
if kind == "signal":
2077-
continue
2078-
raise TypeError(f"update handler {name!r} is not available")
2079-
ctx.logger._set_replaying(True)
2080-
handler(*args)
2112+
_apply_receiver(pending_receivers.pop(0))
2113+
2114+
def _apply_condition_wait_receivers(condition_wait_id: str | None) -> None:
2115+
if condition_wait_id is None:
2116+
return
2117+
while pending_receivers:
2118+
receiver = pending_receivers[0]
2119+
if receiver.condition_wait_id != condition_wait_id:
2120+
break
2121+
_apply_receiver(pending_receivers.pop(0))
20812122

20822123
result_cursor = 0
20832124
gen = instance.run(ctx, *start_input)
@@ -2174,6 +2215,7 @@ def _terminal_state(value: Any, *, include_pending: bool) -> _ReplayState:
21742215
opened_id = opened.get("condition_wait_id")
21752216
if isinstance(opened_id, str):
21762217
resolution = wait_resolutions.get(opened_id)
2218+
_apply_condition_wait_receivers(opened_id)
21772219
opened_key = opened.get("condition_key")
21782220
if isinstance(opened_key, str) and opened_key != (cmd.condition_key or ""):
21792221
return _state([FailWorkflow(

tests/test_wait_condition.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,48 @@ def run(self, ctx: WorkflowContext): # type: ignore[no-untyped-def]
9393
}
9494

9595

96+
@workflow.defn(name="two-stage-signal-wait")
97+
class TwoStageSignalWait:
98+
def __init__(self) -> None:
99+
self.stage = "booting"
100+
self.name: str | None = None
101+
self.finished = False
102+
self.events: list[str] = []
103+
104+
@workflow.signal("advance")
105+
def advance(self, name: str) -> None:
106+
self.name = name
107+
self.events.append(f"signal:{name}")
108+
109+
@workflow.signal("finish")
110+
def finish(self) -> None:
111+
self.finished = True
112+
self.events.append("signal:finish")
113+
114+
@workflow.query("state")
115+
def state(self) -> dict[str, object]:
116+
return {
117+
"stage": self.stage,
118+
"name": self.name,
119+
"finished": self.finished,
120+
"events": list(self.events),
121+
}
122+
123+
def run(self, ctx: WorkflowContext): # type: ignore[no-untyped-def]
124+
self.stage = "waiting-for-advance"
125+
self.events.append("started")
126+
yield ctx.wait_condition(lambda: self.name is not None, key="advance")
127+
128+
self.stage = "waiting-for-finish"
129+
self.finished = False
130+
self.events.append(f"advanced:{self.name}")
131+
yield ctx.wait_condition(lambda: self.finished, key="finish")
132+
133+
self.stage = "completed"
134+
self.events.append("finish")
135+
return self.state()
136+
137+
96138
class TestCtxWaitCondition:
97139
def test_wait_condition_returns_dataclass_with_predicate_and_key(self) -> None:
98140
ctx = WorkflowContext(run_id="x")
@@ -247,6 +289,38 @@ def test_signal_after_activity_result_applies_after_result_is_consumed(self) ->
247289
"approved_by": "alice",
248290
}
249291

292+
def test_signals_scope_to_the_condition_wait_open_when_they_arrived(self) -> None:
293+
history = [
294+
{
295+
"event_type": "ConditionWaitOpened",
296+
"payload": {"condition_wait_id": "wait-advance", "condition_key": "advance"},
297+
},
298+
_signal_received_event("advance", ["Ada"]),
299+
{
300+
"event_type": "ConditionWaitSatisfied",
301+
"payload": {"condition_wait_id": "wait-advance", "condition_key": "advance"},
302+
},
303+
{
304+
"event_type": "ConditionWaitOpened",
305+
"payload": {"condition_wait_id": "wait-finish", "condition_key": "finish"},
306+
},
307+
_signal_received_event("finish", []),
308+
]
309+
310+
expected = {
311+
"stage": "completed",
312+
"name": "Ada",
313+
"finished": True,
314+
"events": ["started", "signal:Ada", "advanced:Ada", "signal:finish", "finish"],
315+
}
316+
317+
outcome = replay(TwoStageSignalWait, history, [])
318+
319+
assert len(outcome.commands) == 1
320+
assert isinstance(outcome.commands[0], CompleteWorkflow)
321+
assert outcome.commands[0].result == expected
322+
assert query_state(TwoStageSignalWait, history, [], "state") == expected
323+
250324
def test_open_with_no_resolution_and_predicate_false_re_emits_wait_condition(self) -> None:
251325
history = [
252326
{

0 commit comments

Comments
 (0)