Skip to content

Commit 19ec7c5

Browse files
TD-P001: Python workers do not prove unsupported payload codecs fail closed (#27)
1 parent a52f132 commit 19ec7c5

2 files changed

Lines changed: 101 additions & 6 deletions

File tree

src/durable_workflow/worker.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def _command_payload_codec(codec: object) -> str:
7070
return codec if isinstance(codec, str) and codec in serializer.SUPPORTED_CODECS else serializer.AVRO_CODEC
7171

7272

73+
def _validate_payload_codec(codec: object) -> str | None:
    """Check that *codec* is one this SDK supports, failing closed otherwise.

    Args:
        codec: The codec value to check; may be ``None``.

    Returns:
        ``None`` when *codec* is ``None``; otherwise *codec* itself, which is
        then known to be a string listed in ``serializer.SUPPORTED_CODECS``.

    Raises:
        ValueError: If *codec* is neither ``None`` nor a supported codec name.
    """
    if codec is None:
        return None
    # Reject non-strings before the membership test so only real codec names pass.
    if not isinstance(codec, str) or codec not in serializer.SUPPORTED_CODECS:
        raise ValueError(
            f"Unsupported payload codec {codec!r}; this SDK supports {serializer.SUPPORTED_CODECS!r}."
        )
    return codec
81+
82+
7383
def _activity_name(fn: Callable[..., Any]) -> str:
7484
return getattr(fn, "__activity_name__", fn.__name__)
7585

@@ -414,9 +424,9 @@ async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str,
414424

415425
start_input: list[Any] = []
416426
codec = task.get("payload_codec")
417-
command_codec = _command_payload_codec(codec)
418427
raw_args = task.get("arguments")
419428
try:
429+
codec = _validate_payload_codec(codec)
420430
decoded = serializer.decode_envelope(raw_args, codec=codec)
421431
if decoded is not None:
422432
start_input = decoded if isinstance(decoded, list) else [decoded]
@@ -456,6 +466,7 @@ async def _run_workflow_task_core(self, task: dict[str, Any]) -> list[dict[str,
456466
return None
457467

458468
run_id: str = task.get("run_id", "")
469+
command_codec = _command_payload_codec(codec)
459470

460471
cls = self.workflows.get(wf_type)
461472
if cls is None:
@@ -583,9 +594,9 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
583594
activity_type: str = task.get("activity_type", "")
584595
attempt_number: int = task.get("attempt_number", 1)
585596
raw_args = task.get("arguments")
586-
inbound_codec = task.get("payload_codec") or serializer.JSON_CODEC
587-
result_codec = inbound_codec if inbound_codec in serializer.SUPPORTED_CODECS else serializer.AVRO_CODEC
597+
inbound_codec = task.get("payload_codec")
588598
try:
599+
inbound_codec = _validate_payload_codec(inbound_codec) or serializer.JSON_CODEC
589600
args = serializer.decode_envelope(raw_args, codec=inbound_codec) or []
590601
except AvroNotInstalledError as e:
591602
log.exception("activity %s arguments Avro decode failed (avro dependency unavailable)", task_id)
@@ -625,6 +636,7 @@ async def _run_activity_task(self, task: dict[str, Any]) -> str:
625636
return "decode_error"
626637
if not isinstance(args, list):
627638
args = [args]
639+
result_codec = inbound_codec
628640

629641
fn = self.activities.get(activity_type)
630642
if fn is None:
@@ -783,6 +795,19 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
783795
wf_type: str = task.get("workflow_type", "")
784796
query_name: str = task.get("query_name", "")
785797
codec = task.get("payload_codec")
798+
try:
799+
codec = _validate_payload_codec(codec)
800+
except ValueError as e:
801+
await self._fail_query_task(
802+
query_task_id,
803+
attempt,
804+
f"cannot decode query payload with codec {codec!r}: {e}.",
805+
reason="query_payload_decode_failed",
806+
failure_type=type(e).__name__,
807+
stack_trace=traceback.format_exc(),
808+
)
809+
return "failed"
810+
786811
result_codec = _command_payload_codec(codec)
787812
history = task.get("history_events", [])
788813

tests/test_worker.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -878,9 +878,7 @@ async def test_workflow_with_envelope_arguments(self, mock_client: AsyncMock) ->
878878

879879

880880
class TestCodecDecodeFailures:
881-
"""TD-P012 / #370 regression: codec decode failures at the task boundary
882-
must turn into a deterministic fail_{workflow,activity}_task call so the
883-
lease does not sit until timeout."""
881+
"""Codec decode failures at the task boundary must fail tasks deterministically."""
884882

885883
@pytest.mark.asyncio
886884
async def test_activity_json_decode_failure_fails_task(self, mock_client: AsyncMock) -> None:
@@ -919,6 +917,29 @@ async def test_activity_avro_decode_failure_fails_task(self, mock_client: AsyncM
919917
assert call_kwargs["non_retryable"] is True
920918
mock_client.complete_activity_task.assert_not_called()
921919

920+
@pytest.mark.asyncio
async def test_activity_unsupported_payload_codec_fails_before_handler(
    self, mock_client: AsyncMock
) -> None:
    """An activity task whose ``payload_codec`` is unsupported is failed, never completed."""
    unsupported_task = {
        "task_id": "at-unsupported-codec",
        "activity_attempt_id": "aa-unsupported-codec",
        "activity_type": "test-act",
        "arguments": {"codec": "json", "blob": '["hello"]'},
        "payload_codec": "zstd",
    }
    worker = Worker(mock_client, task_queue="q1", workflows=[], activities=[echo_activity])

    result = await worker._run_activity_task(unsupported_task)

    assert result == "decode_error"
    # Exactly one failure report and no completion call.
    mock_client.fail_activity_task.assert_called_once()
    mock_client.complete_activity_task.assert_not_called()
    failure = mock_client.fail_activity_task.call_args.kwargs
    assert "Unsupported payload codec 'zstd'" in failure["message"]
    assert failure["failure_type"] == "ValueError"
    assert failure["non_retryable"] is True
942+
922943
@pytest.mark.asyncio
923944
async def test_activity_avro_missing_dependency_fails_task(
924945
self, mock_client: AsyncMock, monkeypatch: pytest.MonkeyPatch
@@ -968,6 +989,55 @@ async def test_workflow_json_decode_failure_fails_task(self, mock_client: AsyncM
968989
assert "json" in call_kwargs["message"]
969990
mock_client.complete_workflow_task.assert_not_called()
970991

992+
@pytest.mark.asyncio
async def test_workflow_unsupported_payload_codec_fails_before_replay(
    self, mock_client: AsyncMock
) -> None:
    """A workflow task whose ``payload_codec`` is unsupported is failed, never completed."""
    unsupported_task = {
        "task_id": "t-unsupported-codec",
        "workflow_type": "test-wf",
        "workflow_task_attempt": 1,
        "history_events": [],
        "arguments": {"codec": "json", "blob": '["hello"]'},
        "payload_codec": "zstd",
    }
    worker = Worker(mock_client, task_queue="q1", workflows=[TestWorkflow], activities=[])

    result = await worker._run_workflow_task(unsupported_task)

    assert result is None
    # Exactly one failure report and no completion call.
    mock_client.fail_workflow_task.assert_called_once()
    mock_client.complete_workflow_task.assert_not_called()
    failure = mock_client.fail_workflow_task.call_args.kwargs
    assert "Unsupported payload codec 'zstd'" in failure["message"]
    assert failure["failure_type"] == "ValueError"
1014+
1015+
@pytest.mark.asyncio
async def test_query_unsupported_payload_codec_fails_before_query_handler(
    self, mock_client: AsyncMock
) -> None:
    """A query task whose ``payload_codec`` is unsupported is failed, never completed."""
    empty_json_envelope = serializer.envelope([], codec="json")
    unsupported_task = {
        "query_task_id": "qt-unsupported-codec",
        "query_task_attempt": 1,
        "workflow_type": "query-wf",
        "query_name": "status",
        "history_events": [],
        "workflow_arguments": empty_json_envelope,
        # Built by a second call, as in the original, not by reusing the value above.
        "query_arguments": serializer.envelope([], codec="json"),
        "payload_codec": "zstd",
    }
    worker = Worker(mock_client, task_queue="q1", workflows=[QueryWorkflow], activities=[])

    result = await worker._run_query_task(unsupported_task)

    assert result == "failed"
    # Exactly one failure report and no completion call.
    mock_client.fail_query_task.assert_called_once()
    mock_client.complete_query_task.assert_not_called()
    failure = mock_client.fail_query_task.call_args.kwargs
    assert failure["reason"] == "query_payload_decode_failed"
    assert "Unsupported payload codec 'zstd'" in failure["message"]
    assert failure["failure_type"] == "ValueError"
1040+
9711041
@pytest.mark.asyncio
9721042
async def test_workflow_avro_missing_dependency_fails_task(
9731043
self, mock_client: AsyncMock, monkeypatch: pytest.MonkeyPatch

0 commit comments

Comments
 (0)