Skip to content

Commit 8c5f7ec

Browse files
Respect run payload codec for workflow commands
1 parent 23168ea commit 8c5f7ec

6 files changed

Lines changed: 72 additions & 22 deletions

File tree

src/durable_workflow/worker.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
log = logging.getLogger("durable_workflow.worker")
2424

2525

26+
def _command_payload_codec(codec: object) -> str:
27+
return codec if isinstance(codec, str) and codec in serializer.SUPPORTED_CODECS else serializer.AVRO_CODEC
28+
29+
2630
def _activity_name(fn: Callable[..., Any]) -> str:
2731
return getattr(fn, "__activity_name__", fn.__name__)
2832

@@ -172,6 +176,7 @@ async def _run_workflow_task(self, task: dict[str, Any]) -> list[dict[str, Any]]
172176

173177
start_input: list[Any] = []
174178
codec = task.get("payload_codec")
179+
command_codec = _command_payload_codec(codec)
175180
raw_args = task.get("arguments")
176181
try:
177182
decoded = serializer.decode_envelope(raw_args, codec=codec)
@@ -262,7 +267,10 @@ async def _run_workflow_task(self, task: dict[str, Any]) -> list[dict[str, Any]]
262267
log.warning("failed to report replay failure: %s", fe)
263268
return None
264269

265-
commands = [c.to_server_command(self.task_queue) for c in outcome.commands]
270+
commands = [
271+
c.to_server_command(self.task_queue, payload_codec=command_codec)
272+
for c in outcome.commands
273+
]
266274
log.info(
267275
"completing workflow task %s with %d command(s): %s",
268276
task_id,

src/durable_workflow/workflow.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ class ScheduleActivity:
3636
arguments: list[Any]
3737
queue: str | None = None
3838

39-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
39+
def to_server_command(
40+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
41+
) -> dict[str, Any]:
4042
return {
4143
"type": "schedule_activity",
4244
"activity_type": self.activity_type,
43-
"arguments": serializer.envelope(self.arguments),
45+
"arguments": serializer.envelope(self.arguments, codec=payload_codec),
4446
"queue": self.queue or task_queue,
4547
}
4648

@@ -49,7 +51,9 @@ def to_server_command(self, task_queue: str) -> dict[str, Any]:
4951
class StartTimer:
5052
delay_seconds: int
5153

52-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
54+
def to_server_command(
55+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
56+
) -> dict[str, Any]:
5357
return {
5458
"type": "start_timer",
5559
"delay_seconds": self.delay_seconds,
@@ -60,10 +64,12 @@ def to_server_command(self, task_queue: str) -> dict[str, Any]:
6064
class CompleteWorkflow:
6165
result: Any
6266

63-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
67+
def to_server_command(
68+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
69+
) -> dict[str, Any]:
6470
return {
6571
"type": "complete_workflow",
66-
"result": serializer.envelope(self.result),
72+
"result": serializer.envelope(self.result, codec=payload_codec),
6773
}
6874

6975

@@ -73,7 +79,9 @@ class FailWorkflow:
7379
exception_type: str | None = None
7480
non_retryable: bool = False
7581

76-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
82+
def to_server_command(
83+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
84+
) -> dict[str, Any]:
7785
cmd: dict[str, Any] = {
7886
"type": "fail_workflow",
7987
"message": self.message,
@@ -91,11 +99,13 @@ class ContinueAsNew:
9199
arguments: list[Any] = field(default_factory=list)
92100
task_queue: str | None = None
93101

94-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
102+
def to_server_command(
103+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
104+
) -> dict[str, Any]:
95105
cmd: dict[str, Any] = {"type": "continue_as_new"}
96106
if self.workflow_type is not None:
97107
cmd["workflow_type"] = self.workflow_type
98-
cmd["arguments"] = serializer.envelope(self.arguments)
108+
cmd["arguments"] = serializer.envelope(self.arguments, codec=payload_codec)
99109
cmd["queue"] = self.task_queue or task_queue
100110
return cmd
101111

@@ -104,10 +114,12 @@ def to_server_command(self, task_queue: str) -> dict[str, Any]:
104114
class RecordSideEffect:
105115
result: Any
106116

107-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
117+
def to_server_command(
118+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
119+
) -> dict[str, Any]:
108120
return {
109121
"type": "record_side_effect",
110-
"result": serializer.envelope(self.result),
122+
"result": serializer.encode(self.result, codec=payload_codec),
111123
}
112124

113125

@@ -118,11 +130,13 @@ class StartChildWorkflow:
118130
task_queue: str | None = None
119131
parent_close_policy: str | None = None
120132

121-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
133+
def to_server_command(
134+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
135+
) -> dict[str, Any]:
122136
cmd: dict[str, Any] = {
123137
"type": "start_child_workflow",
124138
"workflow_type": self.workflow_type,
125-
"arguments": serializer.envelope(self.arguments),
139+
"arguments": serializer.envelope(self.arguments, codec=payload_codec),
126140
}
127141
if self.task_queue is not None:
128142
cmd["queue"] = self.task_queue
@@ -140,7 +154,9 @@ class RecordVersionMarker:
140154
min_supported: int
141155
max_supported: int
142156

143-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
157+
def to_server_command(
158+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
159+
) -> dict[str, Any]:
144160
return {
145161
"type": "record_version_marker",
146162
"change_id": self.change_id,
@@ -154,7 +170,9 @@ def to_server_command(self, task_queue: str) -> dict[str, Any]:
154170
class UpsertSearchAttributes:
155171
attributes: dict[str, Any]
156172

157-
def to_server_command(self, task_queue: str) -> dict[str, Any]:
173+
def to_server_command(
174+
self, task_queue: str, *, payload_codec: str = serializer.AVRO_CODEC
175+
) -> dict[str, Any]:
158176
return {
159177
"type": "upsert_search_attributes",
160178
"attributes": self.attributes,

tests/integration/test_polyglot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def test_python_workflow_calls_php_activity(server_url: str, server_token:
8989
outcome = replay(PolyglotPythonWorkflow, history, start_input, run_id=wf_task.get("run_id", ""))
9090
assert len(outcome.commands) == 1
9191
cmd = outcome.commands[0]
92-
server_cmd = cmd.to_server_command(task_queue)
92+
server_cmd = cmd.to_server_command(task_queue, payload_codec=codec)
9393
assert server_cmd["type"] == "schedule_activity"
9494
assert server_cmd["activity_type"] == "tests.polyglot.php-activity"
9595

@@ -166,7 +166,8 @@ async def test_python_workflow_calls_php_activity(server_url: str, server_token:
166166
outcome2 = replay(PolyglotPythonWorkflow, history2, start_input2, run_id=wf_task2.get("run_id", ""))
167167
assert len(outcome2.commands) == 1
168168
cmd2 = outcome2.commands[0]
169-
server_cmd2 = cmd2.to_server_command(task_queue)
169+
codec2 = wf_task2.get("payload_codec")
170+
server_cmd2 = cmd2.to_server_command(task_queue, payload_codec=codec2)
170171

171172
# Debug: inspect command before asserting
172173
print("\n=== Replay outcome ===")

tests/integration/test_smoke.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def test_greeter_workflow_end_to_end(server_url: str, server_token: str) -
8787
outcome = replay(SmokeGreeterWorkflow, history, start_input, run_id=wf_task.get("run_id", ""))
8888
assert len(outcome.commands) == 1
8989
cmd = outcome.commands[0]
90-
server_cmd = cmd.to_server_command(task_queue)
90+
server_cmd = cmd.to_server_command(task_queue, payload_codec=codec)
9191
assert server_cmd["type"] == "schedule_activity"
9292

9393
# 5. Complete workflow task with ScheduleActivity command
@@ -137,7 +137,8 @@ async def test_greeter_workflow_end_to_end(server_url: str, server_token: str) -
137137
outcome2 = replay(SmokeGreeterWorkflow, history2, start_input2, run_id=wf_task2.get("run_id", ""))
138138
assert len(outcome2.commands) == 1
139139
cmd2 = outcome2.commands[0]
140-
server_cmd2 = cmd2.to_server_command(task_queue)
140+
codec2 = wf_task2.get("payload_codec")
141+
server_cmd2 = cmd2.to_server_command(task_queue, payload_codec=codec2)
141142
assert server_cmd2["type"] == "complete_workflow"
142143

143144
# 11. Complete the final workflow task

tests/test_replay.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ def test_server_command_shape(self) -> None:
102102
assert server_cmd["type"] == "schedule_activity"
103103
assert server_cmd["activity_type"] == "greet"
104104
assert server_cmd["queue"] == "default-queue"
105+
assert server_cmd["arguments"]["codec"] == "avro"
106+
107+
def test_server_command_uses_payload_codec(self) -> None:
108+
outcome = replay(OneActivity, [], ["world"])
109+
cmd = outcome.commands[0]
110+
server_cmd = cmd.to_server_command("default-queue", payload_codec="json")
111+
assert server_cmd["arguments"]["codec"] == "json"
112+
assert serializer.decode(server_cmd["arguments"]["blob"], codec="json") == ["world"]
105113

106114

107115
class TestTwoActivities:
@@ -188,6 +196,12 @@ def test_server_command(self) -> None:
188196
assert server_cmd["result"]["codec"] == "avro"
189197
assert serializer.decode(server_cmd["result"]["blob"], codec="avro") == {"key": "val"}
190198

199+
def test_server_command_uses_payload_codec(self) -> None:
200+
cmd = CompleteWorkflow(result={"key": "val"})
201+
server_cmd = cmd.to_server_command("q", payload_codec="json")
202+
assert server_cmd["result"]["codec"] == "json"
203+
assert serializer.decode(server_cmd["result"]["blob"], codec="json") == {"key": "val"}
204+
191205

192206
@workflow.defn(name="continue-as-new-wf")
193207
class ContinueAsNewWorkflow:
@@ -289,8 +303,12 @@ def test_server_command_shape(self) -> None:
289303
cmd = RecordSideEffect(result={"key": "val"})
290304
sc = cmd.to_server_command("q")
291305
assert sc["type"] == "record_side_effect"
292-
assert sc["result"]["codec"] == "avro"
293-
assert serializer.decode(sc["result"]["blob"], codec="avro") == {"key": "val"}
306+
assert serializer.decode(sc["result"], codec="avro") == {"key": "val"}
307+
308+
def test_server_command_uses_payload_codec(self) -> None:
309+
cmd = RecordSideEffect(result={"key": "val"})
310+
sc = cmd.to_server_command("q", payload_codec="json")
311+
assert serializer.decode(sc["result"], codec="json") == {"key": "val"}
294312

295313

296314
class TestWorkflowContext:

tests/test_worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from durable_workflow import activity, workflow
9+
from durable_workflow import activity, serializer, workflow
1010
from durable_workflow.client import (
1111
CONTROL_PLANE_REQUEST_CONTRACT_SCHEMA,
1212
CONTROL_PLANE_REQUEST_CONTRACT_VERSION,
@@ -165,6 +165,8 @@ async def test_schedule_activity_on_first_replay(self, mock_client: AsyncMock) -
165165
assert len(commands) == 1
166166
assert commands[0]["type"] == "schedule_activity"
167167
assert commands[0]["activity_type"] == "test-act"
168+
assert commands[0]["arguments"]["codec"] == "json"
169+
assert serializer.decode(commands[0]["arguments"]["blob"], codec="json") == ["hello"]
168170

169171
@pytest.mark.asyncio
170172
async def test_complete_on_resolved_activity(self, mock_client: AsyncMock) -> None:
@@ -183,6 +185,8 @@ async def test_complete_on_resolved_activity(self, mock_client: AsyncMock) -> No
183185
mock_client.complete_workflow_task.assert_called_once()
184186
commands = mock_client.complete_workflow_task.call_args.kwargs["commands"]
185187
assert commands[0]["type"] == "complete_workflow"
188+
assert commands[0]["result"]["codec"] == "json"
189+
assert serializer.decode(commands[0]["result"]["blob"], codec="json") == "done"
186190

187191
@pytest.mark.asyncio
188192
async def test_unknown_workflow_type_fails_task(self, mock_client: AsyncMock) -> None:

0 commit comments

Comments
 (0)