Skip to content

Commit e3f4f4d

Browse files
Conformance: CLI signal arguments are missing during SDK query replay (#81)
1 parent 982b944 commit e3f4f4d

2 files changed

Lines changed: 209 additions & 2 deletions

File tree

src/durable_workflow/worker.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import time
2626
import traceback
2727
import uuid
28-
from collections.abc import Awaitable, Callable, Iterable
28+
from collections.abc import Awaitable, Callable, Iterable, Mapping
2929
from types import FunctionType
3030
from typing import Any
3131

@@ -111,6 +111,136 @@ def _is_final_query_task_rejection(error: BaseException) -> bool:
111111
)
112112

113113

114+
def _signal_arguments_envelope_from_export(
115+
signal: Mapping[str, Any],
116+
*,
117+
default_codec: str | None,
118+
) -> dict[str, Any] | None:
119+
raw_arguments = signal.get("arguments")
120+
if raw_arguments is None:
121+
return None
122+
123+
codec = signal.get("payload_codec")
124+
if not isinstance(codec, str) or codec == "":
125+
codec = default_codec or serializer.AVRO_CODEC
126+
127+
if isinstance(raw_arguments, str):
128+
if raw_arguments == "":
129+
return None
130+
return {"codec": codec, "blob": raw_arguments}
131+
132+
if not isinstance(raw_arguments, Mapping):
133+
return None
134+
135+
envelope = dict(raw_arguments)
136+
if "blob" not in envelope and "external_storage" not in envelope:
137+
return None
138+
envelope.setdefault("codec", codec)
139+
return envelope
140+
141+
142+
def _query_history_with_export_signal_arguments(
143+
history: Any,
144+
history_export: Any,
145+
*,
146+
default_codec: str | None,
147+
) -> Any:
148+
# Query-task history can carry compact SignalReceived rows; the full
149+
# signal payload bytes are still present in the accompanying export.
150+
if not isinstance(history, list) or not isinstance(history_export, Mapping):
151+
return history
152+
153+
raw_signals = history_export.get("signals")
154+
if not isinstance(raw_signals, list):
155+
return history
156+
157+
export_payloads = history_export.get("payloads")
158+
export_codec = (
159+
export_payloads.get("codec")
160+
if isinstance(export_payloads, Mapping)
161+
else None
162+
)
163+
signal_default_codec = default_codec
164+
if signal_default_codec is None and isinstance(export_codec, str) and export_codec:
165+
signal_default_codec = export_codec
166+
167+
signals_by_id: dict[str, Mapping[str, Any]] = {}
168+
signals_by_command_id: dict[str, Mapping[str, Any]] = {}
169+
signals_by_name: dict[str, list[Mapping[str, Any]]] = {}
170+
for raw_signal in raw_signals:
171+
if not isinstance(raw_signal, Mapping):
172+
continue
173+
envelope = _signal_arguments_envelope_from_export(raw_signal, default_codec=signal_default_codec)
174+
if envelope is None:
175+
continue
176+
signal_id = raw_signal.get("id")
177+
if isinstance(signal_id, str) and signal_id:
178+
signals_by_id[signal_id] = raw_signal
179+
command_id = raw_signal.get("command_id")
180+
if isinstance(command_id, str) and command_id:
181+
signals_by_command_id[command_id] = raw_signal
182+
name = raw_signal.get("name")
183+
if isinstance(name, str) and name:
184+
signals_by_name.setdefault(name, []).append(raw_signal)
185+
186+
if not signals_by_id and not signals_by_command_id and not signals_by_name:
187+
return history
188+
189+
name_offsets: dict[str, int] = {}
190+
enriched: list[Any] = []
191+
changed = False
192+
for raw_event in history:
193+
if not isinstance(raw_event, Mapping):
194+
enriched.append(raw_event)
195+
continue
196+
event_type = raw_event.get("event_type") or raw_event.get("type")
197+
if event_type != "SignalReceived":
198+
enriched.append(raw_event)
199+
continue
200+
raw_payload = raw_event.get("payload")
201+
if not isinstance(raw_payload, Mapping):
202+
enriched.append(dict(raw_event))
203+
continue
204+
if any(raw_payload.get(key) is not None for key in ("value", "input", "arguments")):
205+
enriched.append(dict(raw_event))
206+
continue
207+
208+
signal: Mapping[str, Any] | None = None
209+
signal_id = raw_payload.get("signal_id")
210+
if isinstance(signal_id, str) and signal_id:
211+
signal = signals_by_id.get(signal_id)
212+
if signal is None:
213+
command_id = raw_payload.get("workflow_command_id") or raw_event.get("workflow_command_id")
214+
if isinstance(command_id, str) and command_id:
215+
signal = signals_by_command_id.get(command_id)
216+
if signal is None:
217+
signal_name = raw_payload.get("signal_name")
218+
if isinstance(signal_name, str) and signal_name:
219+
candidates = signals_by_name.get(signal_name, [])
220+
offset = name_offsets.get(signal_name, 0)
221+
if offset < len(candidates):
222+
signal = candidates[offset]
223+
name_offsets[signal_name] = offset + 1
224+
if signal is None:
225+
enriched.append(dict(raw_event))
226+
continue
227+
228+
envelope = _signal_arguments_envelope_from_export(signal, default_codec=signal_default_codec)
229+
if envelope is None:
230+
enriched.append(dict(raw_event))
231+
continue
232+
233+
payload = dict(raw_payload)
234+
payload["arguments"] = envelope
235+
payload.setdefault("payload_codec", envelope.get("codec"))
236+
event = dict(raw_event)
237+
event["payload"] = payload
238+
enriched.append(event)
239+
changed = True
240+
241+
return enriched if changed else history
242+
243+
114244
def _callable_fingerprint_payload(value: object) -> str:
115245
if isinstance(value, staticmethod | classmethod):
116246
value = value.__func__
@@ -948,7 +1078,11 @@ async def _run_query_task_core(self, task: dict[str, Any], *, client: Client | N
9481078
return "failed"
9491079

9501080
result_codec = _command_payload_codec(codec)
951-
history = task.get("history_events", [])
1081+
history = _query_history_with_export_signal_arguments(
1082+
task.get("history_events", []),
1083+
task.get("history_export"),
1084+
default_codec=codec,
1085+
)
9521086

9531087
cls = self.workflows.get(wf_type)
9541088
if cls is None:

tests/test_worker.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,23 @@ def run(self, ctx): # type: ignore[no-untyped-def]
7979
return self.status
8080

8181

82+
@workflow.defn(name="counter-query-wf")
83+
class CounterQueryWorkflow:
84+
def __init__(self) -> None:
85+
self.count = 0
86+
87+
@workflow.signal("increment")
88+
def increment(self, amount: int) -> None:
89+
self.count += amount
90+
91+
@workflow.query("current")
92+
def current(self) -> int:
93+
return self.count
94+
95+
def run(self, ctx): # type: ignore[no-untyped-def]
96+
yield ctx.wait_condition(lambda: False)
97+
98+
8299
@workflow.defn(name="async-query-wf")
83100
class AsyncQueryWorkflow:
84101
@workflow.query("current")
@@ -685,6 +702,62 @@ async def test_query_task_executes_registered_query(self, mock_client: AsyncMock
685702
)
686703
mock_client.fail_query_task.assert_not_called()
687704

705+
@pytest.mark.asyncio
706+
async def test_query_task_replays_signal_arguments_from_history_export(
707+
self, mock_client: AsyncMock
708+
) -> None:
709+
worker = Worker(mock_client, task_queue="q1", workflows=[CounterQueryWorkflow], activities=[])
710+
signal_arguments = serializer.encode([3], codec="json")
711+
task = {
712+
"query_task_id": "qt-signal-export",
713+
"query_task_attempt": 1,
714+
"workflow_type": "counter-query-wf",
715+
"workflow_id": "wf-counter",
716+
"run_id": "run-counter",
717+
"query_name": "current",
718+
"history_events": [
719+
{
720+
"event_type": "SignalReceived",
721+
"workflow_command_id": "cmd-increment",
722+
"payload": {
723+
"signal_id": "sig-increment",
724+
"workflow_command_id": "cmd-increment",
725+
"signal_name": "increment",
726+
},
727+
},
728+
],
729+
"history_export": {
730+
"payloads": {"codec": "json"},
731+
"signals": [
732+
{
733+
"id": "sig-increment",
734+
"command_id": "cmd-increment",
735+
"name": "increment",
736+
"payload_codec": "json",
737+
"arguments": signal_arguments,
738+
},
739+
],
740+
},
741+
"workflow_arguments": serializer.envelope([], codec="json"),
742+
"query_arguments": serializer.envelope([], codec="json"),
743+
"payload_codec": "json",
744+
}
745+
746+
outcome = await worker._run_query_task(task)
747+
748+
assert outcome == "completed"
749+
mock_client.complete_query_task.assert_awaited_once_with(
750+
query_task_id="qt-signal-export",
751+
lease_owner=worker.worker_id,
752+
query_task_attempt=1,
753+
result=3,
754+
codec="json",
755+
workflow_id="wf-counter",
756+
run_id="run-counter",
757+
query_name="current",
758+
)
759+
mock_client.fail_query_task.assert_not_called()
760+
688761
@pytest.mark.asyncio
689762
async def test_query_task_awaits_async_query_result(self, mock_client: AsyncMock) -> None:
690763
worker = Worker(mock_client, task_queue="q1", workflows=[AsyncQueryWorkflow], activities=[])

0 commit comments

Comments
 (0)