Skip to content

Commit 9e6bf9d

Browse files
Conformance: signals/queries finish signal stays waiting on sdk-python 0.4.42 (#87)
1 parent 4ca3d18 commit 9e6bf9d

4 files changed

Lines changed: 299 additions & 47 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88

99
### Fixed
10+
- Repeated condition-wait openings for the same logical wait now replay through
11+
every matching signal before deciding whether the wait is still pending, so
12+
long-running signal/query workflows do not get stuck on the first signal.
1013
- Signal and update receivers recorded while a condition wait is open now
1114
replay at that specific wait, so later signal-driven waits are not satisfied
1215
or consumed too early when no activity or timer result separates them.

src/durable_workflow/workflow.py

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,49 @@ def _apply_condition_wait_receivers(condition_wait_id: str | None) -> None:
21202120
break
21212121
_apply_receiver(pending_receivers.pop(0))
21222122

2123+
def _condition_wait_mismatch(opened: Mapping[str, Any], cmd: WaitCondition) -> FailWorkflow | None:
2124+
opened_key = opened.get("condition_key")
2125+
if isinstance(opened_key, str) and opened_key != (cmd.condition_key or ""):
2126+
return FailWorkflow(
2127+
message=(
2128+
"wait_condition key changed during replay: "
2129+
f"history has {opened_key!r}, workflow yielded "
2130+
f"{cmd.condition_key!r}"
2131+
),
2132+
exception_type="NonDeterministicWaitCondition",
2133+
)
2134+
2135+
opened_fingerprint = opened.get("condition_definition_fingerprint")
2136+
if (
2137+
isinstance(opened_fingerprint, str)
2138+
and cmd.condition_definition_fingerprint != opened_fingerprint
2139+
):
2140+
return FailWorkflow(
2141+
message=(
2142+
"wait_condition predicate fingerprint changed during replay: "
2143+
f"history has {opened_fingerprint!r}, workflow yielded "
2144+
f"{cmd.condition_definition_fingerprint!r}"
2145+
),
2146+
exception_type="NonDeterministicWaitCondition",
2147+
)
2148+
2149+
return None
2150+
2151+
def _same_logical_condition_wait(opened: Mapping[str, Any], cmd: WaitCondition) -> bool:
2152+
if _condition_wait_mismatch(opened, cmd) is not None:
2153+
return False
2154+
2155+
opened_key = opened.get("condition_key")
2156+
if isinstance(opened_key, str) and opened_key and opened_key == (cmd.condition_key or ""):
2157+
return True
2158+
2159+
opened_fingerprint = opened.get("condition_definition_fingerprint")
2160+
return (
2161+
isinstance(opened_fingerprint, str)
2162+
and bool(opened_fingerprint)
2163+
and opened_fingerprint == cmd.condition_definition_fingerprint
2164+
)
2165+
21232166
result_cursor = 0
21242167
gen = instance.run(ctx, *start_input)
21252168
if not hasattr(gen, "__next__"):
@@ -2208,56 +2251,52 @@ def _terminal_state(value: Any, *, include_pending: bool) -> _ReplayState:
22082251
next_value = _version_marker_result(cmd, cmd.version)
22092252
continue
22102253
if isinstance(cmd, WaitCondition):
2211-
resolution: str | None = None
2212-
opened: dict[str, Any] | None = None
2213-
if wait_yield_count < len(wait_opened):
2214-
opened = wait_opened[wait_yield_count]
2215-
opened_id = opened.get("condition_wait_id")
2216-
if isinstance(opened_id, str):
2217-
resolution = wait_resolutions.get(opened_id)
2218-
_apply_condition_wait_receivers(opened_id)
2219-
opened_key = opened.get("condition_key")
2220-
if isinstance(opened_key, str) and opened_key != (cmd.condition_key or ""):
2254+
while True:
2255+
resolution: str | None = None
2256+
opened: dict[str, Any] | None = None
2257+
if wait_yield_count < len(wait_opened):
2258+
opened = wait_opened[wait_yield_count]
2259+
mismatch = _condition_wait_mismatch(opened, cmd)
2260+
if mismatch is not None:
2261+
return _state([mismatch])
2262+
opened_id = opened.get("condition_wait_id")
2263+
if isinstance(opened_id, str):
2264+
resolution = wait_resolutions.get(opened_id)
2265+
_apply_condition_wait_receivers(opened_id)
2266+
if resolution == "timed_out":
2267+
next_value = False
2268+
wait_yield_count += 1
2269+
break
2270+
try:
2271+
satisfied = bool(cmd.predicate())
2272+
except Exception as exc:
22212273
return _state([FailWorkflow(
2222-
message=(
2223-
"wait_condition key changed during replay: "
2224-
f"history has {opened_key!r}, workflow yielded "
2225-
f"{cmd.condition_key!r}"
2226-
),
2227-
exception_type="NonDeterministicWaitCondition",
2274+
message=f"wait_condition predicate raised: {exc}",
2275+
exception_type=type(exc).__name__,
22282276
)])
2229-
opened_fingerprint = opened.get("condition_definition_fingerprint")
2277+
if satisfied or resolution == "satisfied":
2278+
next_value = True
2279+
wait_yield_count += 1
2280+
break
2281+
2282+
next_wait_index = wait_yield_count + 1
2283+
# A single logical wait can be re-opened in history after
2284+
# non-satisfying signals. Consume repeated physical opens
2285+
# with the same key/fingerprint before declaring the
2286+
# logical wait still pending.
22302287
if (
2231-
isinstance(opened_fingerprint, str)
2232-
and cmd.condition_definition_fingerprint != opened_fingerprint
2288+
opened is not None
2289+
and next_wait_index < len(wait_opened)
2290+
and _same_logical_condition_wait(wait_opened[next_wait_index], cmd)
22332291
):
2234-
return _state([FailWorkflow(
2235-
message=(
2236-
"wait_condition predicate fingerprint changed during replay: "
2237-
f"history has {opened_fingerprint!r}, workflow yielded "
2238-
f"{cmd.condition_definition_fingerprint!r}"
2239-
),
2240-
exception_type="NonDeterministicWaitCondition",
2241-
)])
2242-
if resolution == "timed_out":
2243-
next_value = False
2244-
wait_yield_count += 1
2245-
continue
2246-
try:
2247-
satisfied = bool(cmd.predicate())
2248-
except Exception as exc:
2249-
return _state([FailWorkflow(
2250-
message=f"wait_condition predicate raised: {exc}",
2251-
exception_type=type(exc).__name__,
2252-
)])
2253-
if satisfied or resolution == "satisfied":
2254-
next_value = True
2255-
wait_yield_count += 1
2256-
continue
2257-
ctx.logger._set_replaying(False)
2258-
pending.append(cmd)
2259-
wait_yield_count += 1
2260-
return _state(pending)
2292+
wait_yield_count = next_wait_index
2293+
continue
2294+
2295+
ctx.logger._set_replaying(False)
2296+
pending.append(cmd)
2297+
wait_yield_count = next_wait_index
2298+
return _state(pending)
2299+
continue
22612300
if isinstance(cmd, (ScheduleActivity, StartTimer, StartChildWorkflow)):
22622301
if result_cursor < len(resolved_results):
22632302
val = resolved_results[result_cursor]

tests/test_wait_condition.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,36 @@ def run(self, ctx: WorkflowContext): # type: ignore[no-untyped-def]
135135
return self.state()
136136

137137

138+
@workflow.defn(name="signal-counter-until-finished")
139+
class SignalCounterUntilFinished:
140+
def __init__(self) -> None:
141+
self.count = 0
142+
self.done = False
143+
self.events: list[dict[str, object]] = []
144+
145+
@workflow.signal("increment")
146+
def increment(self, amount: int) -> None:
147+
self.count += amount
148+
self.events.append({"signal": "increment", "amount": amount, "count": self.count})
149+
150+
@workflow.signal("finish")
151+
def finish(self) -> None:
152+
self.done = True
153+
self.events.append({"signal": "finish", "count": self.count})
154+
155+
@workflow.query("current")
156+
def current(self) -> dict[str, object]:
157+
return {
158+
"count": self.count,
159+
"done": self.done,
160+
"events": list(self.events),
161+
}
162+
163+
def run(self, ctx: WorkflowContext): # type: ignore[no-untyped-def]
164+
yield ctx.wait_condition(lambda: self.done, key="done")
165+
return self.current()
166+
167+
138168
class TestCtxWaitCondition:
139169
def test_wait_condition_returns_dataclass_with_predicate_and_key(self) -> None:
140170
ctx = WorkflowContext(run_id="x")
@@ -321,6 +351,99 @@ def test_signals_scope_to_the_condition_wait_open_when_they_arrived(self) -> Non
321351
assert outcome.commands[0].result == expected
322352
assert query_state(TwoStageSignalWait, history, [], "state") == expected
323353

354+
def test_repeated_physical_waits_for_one_logical_condition_replay_all_signals(self) -> None:
355+
history = [
356+
{
357+
"event_type": "ConditionWaitOpened",
358+
"payload": {"condition_wait_id": "wait-count-3", "condition_key": "done"},
359+
},
360+
_signal_received_event("increment", [3]),
361+
{
362+
"event_type": "ConditionWaitOpened",
363+
"payload": {"condition_wait_id": "wait-count-8", "condition_key": "done"},
364+
},
365+
_signal_received_event("increment", [5]),
366+
{
367+
"event_type": "ConditionWaitOpened",
368+
"payload": {"condition_wait_id": "wait-finish", "condition_key": "done"},
369+
},
370+
_signal_received_event("finish", []),
371+
]
372+
373+
expected = {
374+
"count": 8,
375+
"done": True,
376+
"events": [
377+
{"signal": "increment", "amount": 3, "count": 3},
378+
{"signal": "increment", "amount": 5, "count": 8},
379+
{"signal": "finish", "count": 8},
380+
],
381+
}
382+
383+
outcome = replay(SignalCounterUntilFinished, history, [])
384+
385+
assert len(outcome.commands) == 1
386+
assert isinstance(outcome.commands[0], CompleteWorkflow)
387+
assert outcome.commands[0].result == expected
388+
assert query_state(SignalCounterUntilFinished, history, [], "current") == expected
389+
390+
def test_repeated_physical_waits_keep_query_state_current_before_finish(self) -> None:
391+
history = [
392+
{
393+
"event_type": "ConditionWaitOpened",
394+
"payload": {"condition_wait_id": "wait-count-3", "condition_key": "done"},
395+
},
396+
_signal_received_event("increment", [3]),
397+
{
398+
"event_type": "ConditionWaitOpened",
399+
"payload": {"condition_wait_id": "wait-count-8", "condition_key": "done"},
400+
},
401+
_signal_received_event("increment", [5]),
402+
]
403+
404+
outcome = replay(SignalCounterUntilFinished, history, [])
405+
406+
assert len(outcome.commands) == 1
407+
assert isinstance(outcome.commands[0], WaitCondition)
408+
assert query_state(SignalCounterUntilFinished, history, [], "current") == {
409+
"count": 8,
410+
"done": False,
411+
"events": [
412+
{"signal": "increment", "amount": 3, "count": 3},
413+
{"signal": "increment", "amount": 5, "count": 8},
414+
],
415+
}
416+
417+
def test_repeated_wait_after_activity_can_be_satisfied_by_later_signal(self) -> None:
418+
history = [
419+
{
420+
"event_type": "ActivityCompleted",
421+
"payload": {"result": '"loaded"'},
422+
},
423+
{
424+
"event_type": "ConditionWaitOpened",
425+
"payload": {"condition_wait_id": "wait-before-query", "condition_key": "approval"},
426+
},
427+
{
428+
"event_type": "ConditionWaitOpened",
429+
"payload": {"condition_wait_id": "wait-after-query", "condition_key": "approval"},
430+
},
431+
_signal_received_event("approve", ["alice"]),
432+
]
433+
434+
outcome = replay(ActivityThenWaitForApproval, history, [])
435+
436+
assert len(outcome.commands) == 1
437+
assert isinstance(outcome.commands[0], CompleteWorkflow)
438+
assert outcome.commands[0].result == {
439+
"activity_result": "loaded",
440+
"approved_by": "alice",
441+
}
442+
assert query_state(ActivityThenWaitForApproval, history, [], "state") == {
443+
"activity_result": "loaded",
444+
"approved_by": "alice",
445+
}
446+
324447
def test_open_with_no_resolution_and_predicate_false_re_emits_wait_condition(self) -> None:
325448
history = [
326449
{

tests/test_worker.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def current(self) -> int:
9393
return self.count
9494

9595
def run(self, ctx): # type: ignore[no-untyped-def]
96-
yield ctx.wait_condition(lambda: False)
96+
yield ctx.wait_condition(lambda: False, key="done")
9797

9898

9999
@workflow.defn(name="async-query-wf")
@@ -758,6 +758,93 @@ async def test_query_task_replays_signal_arguments_from_history_export(
758758
)
759759
mock_client.fail_query_task.assert_not_called()
760760

761+
@pytest.mark.asyncio
762+
async def test_query_task_replays_repeated_condition_wait_signal_arguments(
763+
self, mock_client: AsyncMock
764+
) -> None:
765+
worker = Worker(mock_client, task_queue="q1", workflows=[CounterQueryWorkflow], activities=[])
766+
first_signal_arguments = serializer.encode([3], codec="json")
767+
second_signal_arguments = serializer.encode([5], codec="json")
768+
task = {
769+
"query_task_id": "qt-repeated-wait-signals",
770+
"query_task_attempt": 1,
771+
"workflow_type": "counter-query-wf",
772+
"workflow_id": "wf-counter",
773+
"run_id": "run-counter",
774+
"query_name": "current",
775+
"history_events": [
776+
{
777+
"event_type": "ConditionWaitOpened",
778+
"payload": {
779+
"condition_wait_id": "wait-count-3",
780+
"condition_key": "done",
781+
},
782+
},
783+
{
784+
"event_type": "SignalReceived",
785+
"workflow_command_id": "cmd-increment-3",
786+
"payload": {
787+
"signal_id": "sig-increment-3",
788+
"workflow_command_id": "cmd-increment-3",
789+
"signal_name": "increment",
790+
},
791+
},
792+
{
793+
"event_type": "ConditionWaitOpened",
794+
"payload": {
795+
"condition_wait_id": "wait-count-8",
796+
"condition_key": "done",
797+
},
798+
},
799+
{
800+
"event_type": "SignalReceived",
801+
"workflow_command_id": "cmd-increment-5",
802+
"payload": {
803+
"signal_id": "sig-increment-5",
804+
"workflow_command_id": "cmd-increment-5",
805+
"signal_name": "increment",
806+
},
807+
},
808+
],
809+
"history_export": {
810+
"payloads": {"codec": "json"},
811+
"signals": [
812+
{
813+
"id": "sig-increment-3",
814+
"command_id": "cmd-increment-3",
815+
"name": "increment",
816+
"payload_codec": "json",
817+
"arguments": first_signal_arguments,
818+
},
819+
{
820+
"id": "sig-increment-5",
821+
"command_id": "cmd-increment-5",
822+
"name": "increment",
823+
"payload_codec": "json",
824+
"arguments": second_signal_arguments,
825+
},
826+
],
827+
},
828+
"workflow_arguments": serializer.envelope([], codec="json"),
829+
"query_arguments": serializer.envelope([], codec="json"),
830+
"payload_codec": "json",
831+
}
832+
833+
outcome = await worker._run_query_task(task)
834+
835+
assert outcome == "completed"
836+
mock_client.complete_query_task.assert_awaited_once_with(
837+
query_task_id="qt-repeated-wait-signals",
838+
lease_owner=worker.worker_id,
839+
query_task_attempt=1,
840+
result=8,
841+
codec="json",
842+
workflow_id="wf-counter",
843+
run_id="run-counter",
844+
query_name="current",
845+
)
846+
mock_client.fail_query_task.assert_not_called()
847+
761848
@pytest.mark.asyncio
762849
async def test_query_task_awaits_async_query_result(self, mock_client: AsyncMock) -> None:
763850
worker = Worker(mock_client, task_queue="q1", workflows=[AsyncQueryWorkflow], activities=[])

0 commit comments

Comments
 (0)