Skip to content

Commit 4276a41

Browse files
Batch SDK payload codec visits
Issue: zorporation/durable-workflow#450 Loop-ID: build-01
1 parent a324796 commit 4276a41

5 files changed

Lines changed: 304 additions & 26 deletions

File tree

src/durable_workflow/_avro.py

Lines changed: 42 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,19 @@ def encode(value: Any) -> str:
6565
The generic wrapper accepts the same value shapes as ``json.dumps``; adapt
6666
domain objects to JSON-native data before encoding.
6767
"""
68+
return encode_many([value])[0]
69+
70+
71+
def encode_many(values: list[Any]) -> list[str]:
72+
"""Encode several Avro generic-wrapper payloads through one codec visit.
73+
74+
The Avro runtime does not provide a useful vectorized API for independent
75+
payload blobs, but batching here still avoids repeated import/schema/writer
76+
setup at high fan-out command boundaries.
77+
"""
78+
if not values:
79+
return []
80+
6881
try:
6982
import avro.io
7083
except ImportError as exc:
@@ -74,17 +87,8 @@ def encode(value: Any) -> str:
7487
) from exc
7588

7689
schema = _load_avro_schema()
77-
buf = io.BytesIO()
78-
encoder = avro.io.BinaryEncoder(buf)
7990
writer = avro.io.DatumWriter(schema)
80-
writer.write(
81-
{
82-
"json": json.dumps(value, separators=(",", ":"), ensure_ascii=False),
83-
"version": WRAPPER_VERSION,
84-
},
85-
encoder,
86-
)
87-
return base64.b64encode(_PREFIX_GENERIC_WRAPPER + buf.getvalue()).decode("ascii")
91+
return [_encode_with_writer(value, writer, avro.io.BinaryEncoder) for value in values]
8892

8993

9094
def decode(blob: str) -> Any:
@@ -94,6 +98,14 @@ def decode(blob: str) -> Any:
9498
schemas (prefix ``0x01``) raise :class:`ValueError` because the SDK
9599
has no schema registry.
96100
"""
101+
return decode_many([blob])[0]
102+
103+
104+
def decode_many(blobs: list[str]) -> list[Any]:
105+
"""Decode several Avro payload blobs through one codec visit."""
106+
if not blobs:
107+
return []
108+
97109
try:
98110
import avro.io
99111
except ImportError as exc:
@@ -102,6 +114,25 @@ def decode(blob: str) -> Any:
102114
"codec. Reinstall durable-workflow with its runtime dependencies."
103115
) from exc
104116

117+
schema = _load_avro_schema()
118+
reader = avro.io.DatumReader(schema)
119+
return [_decode_with_reader(blob, reader, avro.io.BinaryDecoder) for blob in blobs]
120+
121+
122+
def _encode_with_writer(value: Any, writer: Any, encoder_cls: Any) -> str:
    """Encode one value as a base64 generic-wrapper blob.

    The value is dumped to compact JSON (no ASCII escaping), placed in a
    ``{"json", "version"}`` record, and written with *writer* through a
    fresh ``encoder_cls`` instance backed by an in-memory buffer. The bytes
    written are prefixed with ``_PREFIX_GENERIC_WRAPPER`` and returned as
    base64-encoded ASCII text.

    *writer* and *encoder_cls* are supplied by the caller so one writer can
    be shared across many values.
    """
    stream = io.BytesIO()
    binary_encoder = encoder_cls(stream)
    record = {
        "json": json.dumps(value, separators=(",", ":"), ensure_ascii=False),
        "version": WRAPPER_VERSION,
    }
    writer.write(record, binary_encoder)
    framed = _PREFIX_GENERIC_WRAPPER + stream.getvalue()
    return base64.b64encode(framed).decode("ascii")
133+
134+
135+
def _decode_with_reader(blob: str, reader: Any, decoder_cls: Any) -> Any:
105136
try:
106137
raw = base64.b64decode(blob, validate=True)
107138
except (ValueError, TypeError) as exc:
@@ -124,9 +155,7 @@ def decode(blob: str) -> Any:
124155
f"These bytes were not produced by a Durable Workflow Avro serializer."
125156
)
126157

127-
schema = _load_avro_schema()
128-
reader = avro.io.DatumReader(schema)
129-
decoder = avro.io.BinaryDecoder(io.BytesIO(raw[1:]))
158+
decoder = decoder_cls(io.BytesIO(raw[1:]))
130159
try:
131160
record = reader.read(decoder)
132161
except Exception as exc:

src/durable_workflow/client.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -644,16 +644,20 @@ async def describe_workflow(self, workflow_id: str) -> WorkflowExecution:
644644
``output`` envelopes when present.
645645
"""
646646
data = await self._request("GET", f"/workflows/{workflow_id}", context=workflow_id)
647-
input_val = None
648-
output_val = None
647+
input_val = data.get("input")
648+
output_val = data.get("output")
649+
envelope_jobs: list[tuple[str, Any]] = []
649650
if data.get("input_envelope"):
650-
input_val = serializer.decode_envelope(data["input_envelope"])
651-
elif data.get("input") is not None:
652-
input_val = data["input"]
651+
envelope_jobs.append(("input", data["input_envelope"]))
653652
if data.get("output_envelope"):
654-
output_val = serializer.decode_envelope(data["output_envelope"])
655-
elif data.get("output") is not None:
656-
output_val = data["output"]
653+
envelope_jobs.append(("output", data["output_envelope"]))
654+
if envelope_jobs:
655+
decoded = serializer.decode_envelopes([envelope for _, envelope in envelope_jobs])
656+
for (field, _), value in zip(envelope_jobs, decoded, strict=True):
657+
if field == "input":
658+
input_val = value
659+
else:
660+
output_val = value
657661
return WorkflowExecution(
658662
workflow_id=data.get("workflow_id", workflow_id),
659663
run_id=data.get("run_id"),

src/durable_workflow/serializer.py

Lines changed: 70 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -228,15 +228,26 @@ def encode_many(
228228
call sites.
229229
"""
230230
contexts = _warning_contexts_for_values(values, warning_context)
231-
return [
232-
encode(
233-
value,
231+
if codec == JSON_CODEC:
232+
blobs = [
233+
json.dumps(value, separators=(",", ":"), ensure_ascii=False)
234+
for value in values
235+
]
236+
elif codec == AVRO_CODEC:
237+
blobs = _avro.encode_many(list(values))
238+
else:
239+
raise ValueError(
240+
f"Unsupported payload codec {codec!r}; this SDK supports {SUPPORTED_CODECS!r}."
241+
)
242+
243+
for index, blob in enumerate(blobs):
244+
warn_if_payload_near_limit(
245+
blob,
234246
codec=codec,
235247
size_warning=size_warning,
236248
warning_context=contexts[index],
237249
)
238-
for index, value in enumerate(values)
239-
]
250+
return blobs
240251

241252

242253
def envelope(
@@ -377,6 +388,60 @@ def decode_envelope(value: Any, codec: str | None = None) -> Any:
377388
return decode(value, codec=codec)
378389

379390

391+
def decode_envelopes(values: Sequence[Any], codec: str | None = None) -> list[Any]:
    """Decode a mix of raw blobs and ``{codec, blob}`` envelopes, preserving order.

    Each item is classified as follows:

    * a dict containing both ``"codec"`` and ``"blob"`` is decoded with the
      envelope's own codec;
    * ``None`` or ``""`` yields ``None`` without being decoded;
    * anything else is treated as a raw blob and decoded with *codec*.

    Items are grouped by codec (in order of first appearance) and each group
    is passed to ``decode_many`` in a single call. The returned list has one
    entry per input item, at the same position.
    """
    by_codec: dict[str | None, list[tuple[int, Any]]] = {}
    for position, item in enumerate(values):
        if isinstance(item, dict) and "codec" in item and "blob" in item:
            blob, item_codec = item["blob"], item["codec"]
        elif item is None or item == "":
            # Empty payloads stay as the pre-filled ``None`` in the result.
            continue
        else:
            blob, item_codec = item, codec
        by_codec.setdefault(item_codec, []).append((position, blob))

    results: list[Any] = [None] * len(values)
    for item_codec, entries in by_codec.items():
        payloads = [blob for _, blob in entries]
        decoded_group = decode_many(payloads, codec=item_codec)
        # strict=True: a length mismatch from decode_many is an error, not
        # something to truncate silently.
        for (position, _), decoded_value in zip(entries, decoded_group, strict=True):
            results[position] = decoded_value

    return results
417+
418+
419+
def decode_many(blobs: Sequence[str | None], codec: str | None = None) -> list[Any]:
    """Decode a sequence of payload blobs that share one codec.

    * ``codec`` of ``None`` or ``JSON_CODEC``: every blob is passed to
      ``decode`` individually.
    * ``codec`` of ``AVRO_CODEC``: ``None`` and ``""`` blobs map to ``None``;
      all remaining blobs are sent to ``_avro.decode_many`` in one call and
      the results are placed back at their original positions.

    Raises:
        ValueError: if *codec* is any other value.
    """
    if codec is None or codec == JSON_CODEC:
        return [decode(blob, codec=codec) for blob in blobs]

    if codec == AVRO_CODEC:
        results: list[Any] = [None] * len(blobs)
        present = [
            (position, blob)
            for position, blob in enumerate(blobs)
            if not (blob is None or blob == "")
        ]
        if present:
            avro_values = _avro.decode_many([blob for _, blob in present])
            for (position, _), decoded_value in zip(present, avro_values, strict=True):
                results[position] = decoded_value
        return results

    raise ValueError(
        f"Cannot decode payload with codec {codec!r}: this SDK supports "
        f"{SUPPORTED_CODECS!r}. Ensure the workflow was started with a "
        f"compatible codec or an explicit {{'codec': '<codec>', 'blob': '...'}} envelope."
    )
443+
444+
380445
def decode(blob: str | None, codec: str | None = None) -> Any:
381446
"""Decode a payload blob into a Python value.
382447

tests/test_client.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,34 @@ async def test_envelope_fields(self, client: Client) -> None:
217217
assert desc.output == {"greeting": "Hello, Ada!"}
218218
assert desc.payload_codec == "json"
219219

220+
@pytest.mark.asyncio
async def test_envelope_fields_decode_in_one_batch(self, client: Client) -> None:
    """describe_workflow passes both envelopes to decode_envelopes in one call."""
    input_envelope = {"codec": "json", "blob": '["Ada"]'}
    output_envelope = {"codec": "json", "blob": '{"greeting":"Hello, Ada!"}'}
    body = {
        "workflow_id": "wf-1",
        "run_id": "run-1",
        "workflow_type": "greeter",
        "status": "completed",
        "input_envelope": input_envelope,
        "output_envelope": output_envelope,
    }
    resp = _mock_response(200, body)

    # Stub the HTTP layer and replace the batch decoder so the call can be inspected.
    http_patch = patch.object(
        client._http, "request", new_callable=AsyncMock, return_value=resp
    )
    decoder_patch = patch.object(
        serializer,
        "decode_envelopes",
        return_value=[["Ada"], {"greeting": "Hello, Ada!"}],
    )
    with http_patch, decoder_patch as decode_envelopes:
        desc = await client.describe_workflow("wf-1")

    # Exactly one decoder call, input envelope first, then output.
    decode_envelopes.assert_called_once_with([input_envelope, output_envelope])
    assert desc.input == ["Ada"]
    assert desc.output == {"greeting": "Hello, Ada!"}
247+
220248
@pytest.mark.asyncio
221249
async def test_not_found(self, client: Client) -> None:
222250
resp = _mock_response(404, {"reason": "workflow_not_found", "message": "not found"})

0 commit comments

Comments
 (0)