Skip to content

Commit 06b6aa0

Browse files
fix(responses-durable): normalize agent_reference before persisting durable task input
On hosted, the platform injects `agent_reference` as an AgentReference model (a Mapping but not json.dumps-serializable). It leaked through _split_runtime_refs into the persisted durable-task input, so create_and_start -> _resolve_input_storage raised `TypeError: Object of type AgentReference is not JSON serializable` and the durable background start silently fell back to a non-durable asyncio.create_task — meaning NO durable task was created and crash recovery never happened on hosted. _split_runtime_refs now normalizes a model-typed agent_reference to a plain dict (consumers all accept AgentReference | dict and read it as a mapping; the dict also survives cross-process recovery). Absent agent_reference stays the {} sentinel. This was invisible to the conformance suite because local/conformance requests carry no agent_reference (-> {} sentinel -> serializable). Adds TestSplitRuntimeRefsSerializable asserting the persisted durable input is JSON-serializable when agent_reference is a model. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1fcfa5b commit 06b6aa0

2 files changed

Lines changed: 181 additions & 29 deletions

File tree

sdk/agentserver/azure-ai-agentserver-responses/azure/ai/agentserver/responses/hosting/_durable_orchestrator.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def _build_server_error_payload(
120120
)
121121

122122

123-
def _split_runtime_refs(ctx_params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
123+
def _split_runtime_refs(
124+
ctx_params: dict[str, Any],
125+
) -> tuple[dict[str, Any], dict[str, Any]]:
124126
"""Split ``ctx_params`` into refs (memory-only) and persisted params.
125127
126128
:param ctx_params: The orchestrator's combined params dict.
@@ -137,6 +139,29 @@ def _split_runtime_refs(ctx_params: dict[str, Any]) -> tuple[dict[str, Any], dic
137139
refs[k] = v
138140
else:
139141
persisted[k] = v
142+
# The hosted gateway injects ``agent_reference`` as an ``AgentReference``
143+
# model. That model is a Mapping but is NOT ``json.dumps``-serializable, so
144+
# if it leaks into the persisted durable-task input the underlying
145+
# ``create_and_start`` -> ``_resolve_input_storage`` size check raises
146+
# ``TypeError`` and the whole durable start silently falls back to a
147+
# non-durable ``asyncio.create_task`` (no crash recovery). Normalize it to a
148+
# plain dict here: the durable input must be JSON-serializable AND survive
149+
# cross-process recovery, and every consumer accepts ``AgentReference | dict``
150+
# (and reads it as a mapping). Absent agent_reference is the ``{}`` sentinel,
151+
# which is already serializable.
152+
agent_reference = persisted.get("agent_reference")
153+
if agent_reference is not None and not isinstance(agent_reference, dict):
154+
if hasattr(agent_reference, "as_dict"):
155+
persisted["agent_reference"] = agent_reference.as_dict()
156+
else:
157+
try:
158+
persisted["agent_reference"] = dict(agent_reference)
159+
except (TypeError, ValueError):
160+
persisted["agent_reference"] = {
161+
"type": getattr(agent_reference, "type", "agent_reference"),
162+
"name": getattr(agent_reference, "name", None),
163+
"version": getattr(agent_reference, "version", None),
164+
}
140165
return refs, persisted
141166

142167

@@ -160,7 +185,9 @@ def _reconstruct_parsed_from_params(params: dict[str, Any]) -> Any:
160185
"missing. Ensure the orchestrator stamps it at fresh-entry."
161186
)
162187
# Late import to avoid circular dependency on hosting/_request_parsing.
163-
from ..models._generated import CreateResponse # pylint: disable=import-outside-toplevel
188+
from ..models._generated import (
189+
CreateResponse,
190+
) # pylint: disable=import-outside-toplevel
164191

165192
if isinstance(payload, dict):
166193
return CreateResponse(payload)
@@ -196,8 +223,14 @@ def _reconstruct_from_params(
196223
:rtype: tuple[ResponseExecution, ResponseContext]
197224
"""
198225
# Late imports to avoid module-level circular dependencies.
199-
from .._response_context import IsolationContext, ResponseContext # pylint: disable=import-outside-toplevel
200-
from ..models.runtime import ResponseExecution, ResponseModeFlags # pylint: disable=import-outside-toplevel
226+
from .._response_context import (
227+
IsolationContext,
228+
ResponseContext,
229+
) # pylint: disable=import-outside-toplevel
230+
from ..models.runtime import (
231+
ResponseExecution,
232+
ResponseModeFlags,
233+
) # pylint: disable=import-outside-toplevel
201234

202235
parsed = _reconstruct_parsed_from_params(params)
203236

@@ -226,7 +259,9 @@ def _reconstruct_from_params(
226259
input_items=record.input_items,
227260
previous_response_id=record.previous_response_id,
228261
conversation_id=record.conversation_id,
229-
history_limit=int(params.get("history_limit", runtime_options.default_fetch_history_count)),
262+
history_limit=int(
263+
params.get("history_limit", runtime_options.default_fetch_history_count)
264+
),
230265
# Client headers / query params are not preserved across recovery
231266
# — they were specific to the original HTTP request and are not
232267
# meaningful for the recovered handler.
@@ -509,7 +544,9 @@ def _ref(key: str) -> Any:
509544
# next-lifetime recovery can dispatch correctly without needing to
510545
# reconstruct the routing decisions from input params.
511546
if _RESP_DISPOSITION not in responses_ns:
512-
responses_ns[_RESP_DISPOSITION] = params.get("disposition", DISPOSITION_REINVOKE)
547+
responses_ns[_RESP_DISPOSITION] = params.get(
548+
"disposition", DISPOSITION_REINVOKE
549+
)
513550
# Force-flush so the disposition is durable BEFORE the body
514551
# could be killed — without an explicit flush the recovered
515552
# task would default to ``re-invoke`` and skip the mark-failed
@@ -581,8 +618,12 @@ def _ref(key: str) -> Any:
581618
runtime_state=self._runtime_state,
582619
runtime_options=self._options,
583620
)
584-
assert record is not None, "_reconstruct_from_params guarantees non-None record"
585-
assert self._runtime_state is not None, "runtime_state always wired at orchestrator init"
621+
assert (
622+
record is not None
623+
), "_reconstruct_from_params guarantees non-None record"
624+
assert (
625+
self._runtime_state is not None
626+
), "runtime_state always wired at orchestrator init"
586627
await self._runtime_state.add(record)
587628

588629
# After the reconstruction block, context and record are both
@@ -646,7 +687,8 @@ def _ref(key: str) -> Any:
646687
return
647688
except Exception: # pylint: disable=broad-exception-caught
648689
logger.debug(
649-
"persisted_response pre-fetch failed for %s " "(recovery, transient — not dropping)",
690+
"persisted_response pre-fetch failed for %s "
691+
"(recovery, transient — not dropping)",
650692
context.response_id,
651693
exc_info=True,
652694
)
@@ -772,7 +814,11 @@ async def _bridge() -> None:
772814
# mid-handler with grace exhausted) silently loses the
773815
# response because the one-shot ephemeral record is deleted
774816
# on cancel.
775-
if ctx.shutdown.is_set() and record is not None and record.status in {"queued", "in_progress"}:
817+
if (
818+
ctx.shutdown.is_set()
819+
and record is not None
820+
and record.status in {"queued", "in_progress"}
821+
):
776822
logger.info(
777823
"Response %s handler returned during shutdown without "
778824
"terminal; calling ctx.exit_for_recovery() so task stays "
@@ -950,11 +996,16 @@ async def _persist_crash_failed(
950996
# happened after terminal persistence, and overwriting would corrupt
951997
# the result.
952998
try:
953-
existing = await self._provider.get_response(response_id, isolation=isolation)
999+
existing = await self._provider.get_response(
1000+
response_id, isolation=isolation
1001+
)
9541002
existing_status = getattr(existing, "status", None) or (
9551003
existing.get("status") if isinstance(existing, dict) else None
9561004
)
957-
if isinstance(existing_status, str) and existing_status in _TERMINAL_STATUSES:
1005+
if (
1006+
isinstance(existing_status, str)
1007+
and existing_status in _TERMINAL_STATUSES
1008+
):
9581009
logger.info(
9591010
"_persist_crash_failed: response %s already terminal "
9601011
"(status=%s) — skipping overwrite (race avoidance)",
@@ -977,7 +1028,9 @@ async def _persist_crash_failed(
9771028
)
9781029

9791030
try:
980-
await self._provider.update_response(ResponseObject(failed_response), isolation=isolation)
1031+
await self._provider.update_response(
1032+
ResponseObject(failed_response), isolation=isolation
1033+
)
9811034
except KeyError:
9821035
# Response was never persisted at response.created — try
9831036
# create instead so the failed terminal still lands.

sdk/agentserver/azure-ai-agentserver-responses/tests/unit/test_durable_orchestrator.py

Lines changed: 115 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from azure.ai.agentserver.responses.hosting._durable_orchestrator import (
1414
DurableResponseOrchestrator,
1515
_is_recovered_entry,
16+
_split_runtime_refs,
1617
)
1718

1819

@@ -162,7 +163,9 @@ async def test_calls_run_background_non_stream(self) -> None:
162163
ctx.entry_mode = "fresh"
163164
ctx.retry_attempt = 0
164165
ctx.is_steered_turn = False # Spec 016 FR-020: was_steered renamed
165-
ctx.pending_input_count = 0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
166+
ctx.pending_input_count = (
167+
0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
168+
)
166169
ctx.metadata = _FakeTaskMetadata()
167170
ctx._cancellation_signal = asyncio.Event()
168171
ctx.shutdown = asyncio.Event()
@@ -195,7 +198,9 @@ async def test_calls_run_background_non_stream(self) -> None:
195198
assert kwargs["model"] == "gpt-4o"
196199

197200
@pytest.mark.asyncio
198-
async def test_recovery_and_steering_fields_flattened_on_response_context(self) -> None:
201+
async def test_recovery_and_steering_fields_flattened_on_response_context(
202+
self,
203+
) -> None:
199204
"""(Spec 024 Phase 5 — Proposal #10/#13) Recovery + steering
200205
classifiers land directly on ``ResponseContext`` flat fields.
201206
The pre-Phase-5 ``DurabilityContext`` indirection is deleted —
@@ -210,7 +215,10 @@ async def test_recovery_and_steering_fields_flattened_on_response_context(self)
210215
options=MagicMock(steerable_conversations=False),
211216
)
212217

213-
from azure.ai.agentserver.responses._response_context import IsolationContext, ResponseContext
218+
from azure.ai.agentserver.responses._response_context import (
219+
IsolationContext,
220+
ResponseContext,
221+
)
214222
from azure.ai.agentserver.responses.models.runtime import ResponseModeFlags
215223

216224
real_context = ResponseContext(
@@ -250,9 +258,13 @@ async def test_recovery_and_steering_fields_flattened_on_response_context(self)
250258
assert real_context.pending_input_count == 2
251259
assert not hasattr(real_context, "durability")
252260
# The metadata facade was swapped in to back the task metadata.
253-
from azure.ai.agentserver.responses._durability_context import _DeveloperMetadataFacade
261+
from azure.ai.agentserver.responses._durability_context import (
262+
_DeveloperMetadataFacade,
263+
)
254264

255-
assert isinstance(real_context.conversation_chain_metadata, _DeveloperMetadataFacade)
265+
assert isinstance(
266+
real_context.conversation_chain_metadata, _DeveloperMetadataFacade
267+
)
256268

257269
@pytest.mark.asyncio
258270
async def test_steerable_returns_none_for_implicit_suspend(self) -> None:
@@ -270,7 +282,9 @@ async def test_steerable_returns_none_for_implicit_suspend(self) -> None:
270282
ctx.entry_mode = "fresh"
271283
ctx.retry_attempt = 0
272284
ctx.is_steered_turn = False # Spec 016 FR-020: was_steered renamed
273-
ctx.pending_input_count = 0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
285+
ctx.pending_input_count = (
286+
0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
287+
)
274288
ctx.metadata = _FakeTaskMetadata()
275289
ctx._cancellation_signal = asyncio.Event()
276290
ctx.shutdown = asyncio.Event()
@@ -310,7 +324,9 @@ async def test_non_steerable_returns_none_too(self) -> None:
310324
ctx.entry_mode = "fresh"
311325
ctx.retry_attempt = 0
312326
ctx.is_steered_turn = False # Spec 016 FR-020: was_steered renamed
313-
ctx.pending_input_count = 0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
327+
ctx.pending_input_count = (
328+
0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
329+
)
314330
ctx.metadata = _FakeTaskMetadata()
315331
ctx._cancellation_signal = asyncio.Event()
316332
ctx.shutdown = asyncio.Event()
@@ -350,7 +366,9 @@ async def test_cancel_bridge_propagates(self) -> None:
350366
ctx.entry_mode = "fresh"
351367
ctx.retry_attempt = 0
352368
ctx.is_steered_turn = False # Spec 016 FR-020: was_steered renamed
353-
ctx.pending_input_count = 0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
369+
ctx.pending_input_count = (
370+
0 # Spec 016 FR-019: pending_inputs Sequence renamed to live int count
371+
)
354372
ctx.metadata = _FakeTaskMetadata()
355373
ctx._cancellation_signal = asyncio.Event()
356374
ctx.shutdown = asyncio.Event()
@@ -441,8 +459,12 @@ def test_pick_primitive_matrix(
441459
)
442460

443461
# Both primitives must exist (precondition for the matrix).
444-
assert hasattr(orch, "_one_shot_task_fn"), f"{case_id}: orchestrator must register a one-shot primitive."
445-
assert hasattr(orch, "_multi_turn_task_fn"), f"{case_id}: orchestrator must register a multi-turn primitive."
462+
assert hasattr(
463+
orch, "_one_shot_task_fn"
464+
), f"{case_id}: orchestrator must register a one-shot primitive."
465+
assert hasattr(
466+
orch, "_multi_turn_task_fn"
467+
), f"{case_id}: orchestrator must register a multi-turn primitive."
446468

447469
ctx_params = {
448470
"response_id": "resp_test",
@@ -472,22 +494,31 @@ def test_orchestrator_registers_both_primitives_on_construction(self) -> None:
472494
deployment that mis-imports the core wheel fails fast at
473495
server startup instead of per-request.
474496
"""
475-
opts = MagicMock(steerable_conversations=False, max_pending=10, default_fetch_history_count=100)
497+
opts = MagicMock(
498+
steerable_conversations=False,
499+
max_pending=10,
500+
default_fetch_history_count=100,
501+
)
476502
orch = DurableResponseOrchestrator(
477503
create_fn=AsyncMock(),
478504
provider=MagicMock(),
479505
options=opts,
480506
)
481507

482508
# Both registrations are present.
483-
assert hasattr(orch, "_one_shot_task_fn"), "Construction must register the one-shot primitive."
484-
assert hasattr(orch, "_multi_turn_task_fn"), "Construction must register the multi-turn primitive."
509+
assert hasattr(
510+
orch, "_one_shot_task_fn"
511+
), "Construction must register the one-shot primitive."
512+
assert hasattr(
513+
orch, "_multi_turn_task_fn"
514+
), "Construction must register the multi-turn primitive."
485515

486516
# Names are distinct and well-formed.
487517
one_shot_name = orch._one_shot_task_fn._opts.name
488518
multi_turn_name = orch._multi_turn_task_fn._opts.name
489519
assert one_shot_name != multi_turn_name, (
490-
f"Primitives must have distinct registration names " f"(both got {one_shot_name!r})."
520+
f"Primitives must have distinct registration names "
521+
f"(both got {one_shot_name!r})."
491522
)
492523
assert (
493524
"one_shot" in one_shot_name or "oneshot" in one_shot_name
@@ -499,13 +530,18 @@ def test_orchestrator_registers_both_primitives_on_construction(self) -> None:
499530
# The multi-turn primitive's steerable flag MUST match the
500531
# deployment's steerable_conversations option (per SOT §6.6).
501532
assert orch._multi_turn_task_fn._opts.steerable is False, (
502-
"Multi-turn primitive's steerable flag must match " "options.steerable_conversations."
533+
"Multi-turn primitive's steerable flag must match "
534+
"options.steerable_conversations."
503535
)
504536

505537
def test_orchestrator_multi_turn_steerable_flag_propagated(self) -> None:
506538
"""With ``steerable_conversations=True``, the multi-turn primitive
507539
is registered with ``steerable=True``."""
508-
opts = MagicMock(steerable_conversations=True, max_pending=10, default_fetch_history_count=100)
540+
opts = MagicMock(
541+
steerable_conversations=True,
542+
max_pending=10,
543+
default_fetch_history_count=100,
544+
)
509545
orch = DurableResponseOrchestrator(
510546
create_fn=AsyncMock(),
511547
provider=MagicMock(),
@@ -514,3 +550,66 @@ def test_orchestrator_multi_turn_steerable_flag_propagated(self) -> None:
514550
assert (
515551
orch._multi_turn_task_fn._opts.steerable is True
516552
), "Steerable flag must propagate from options to multi-turn primitive."
553+
554+
555+
class TestSplitRuntimeRefsSerializable:
556+
"""The persisted durable-task input MUST be JSON-serializable.
557+
558+
Regression for the hosted bug where the gateway-injected
559+
``agent_reference`` (an ``AgentReference`` model — a Mapping but not
560+
``json.dumps``-serializable) leaked into the persisted params, making
561+
``create_and_start`` raise ``TypeError`` and silently degrade the durable
562+
background run to a non-durable ``asyncio.create_task`` (no crash recovery).
563+
"""
564+
565+
def test_persisted_params_json_serializable_with_agent_reference_model(
566+
self,
567+
) -> None:
568+
import json
569+
570+
from azure.ai.agentserver.responses.models import AgentReference
571+
572+
ctx_params = {
573+
"response_id": "caresp_abc",
574+
"agent_name": "durable-responses-agent-demo",
575+
"session_id": "sess_1",
576+
"agent_reference": AgentReference(
577+
name="durable-responses-agent-demo", version="29"
578+
),
579+
# a runtime-only object ref that must be stripped, never persisted
580+
"_record_ref": object(),
581+
}
582+
583+
refs, persisted = _split_runtime_refs(ctx_params)
584+
585+
# refs hold the non-serializable object reference; not persisted
586+
assert "_record_ref" in refs
587+
assert "_record_ref" not in persisted
588+
589+
# agent_reference survives in the persisted input (needed across
590+
# cross-process recovery) but normalized to a plain dict
591+
assert isinstance(persisted["agent_reference"], dict)
592+
assert (
593+
persisted["agent_reference"].get("name") == "durable-responses-agent-demo"
594+
)
595+
assert persisted["agent_reference"].get("version") == "29"
596+
597+
# the whole persisted input must JSON-serialize (this is what the
598+
# core durable-task size check does and what previously raised)
599+
json.dumps(persisted) # must not raise
600+
601+
def test_empty_agent_reference_sentinel_passthrough(self) -> None:
602+
import json
603+
604+
# absent agent_reference is the ``{}`` sentinel — already serializable
605+
_, persisted = _split_runtime_refs({"response_id": "r", "agent_reference": {}})
606+
assert persisted["agent_reference"] == {}
607+
json.dumps(persisted)
608+
609+
def test_dict_agent_reference_unchanged(self) -> None:
610+
import json
611+
612+
ar = {"type": "agent_reference", "name": "x", "version": "1"}
613+
_, persisted = _split_runtime_refs({"response_id": "r", "agent_reference": ar})
614+
assert persisted["agent_reference"] == ar
615+
json.dumps(persisted)

0 commit comments

Comments
 (0)