Skip to content

Commit fb62d30

Browse files
committed
fix mcp
1 parent e2c3886 commit fb62d30

4 files changed

Lines changed: 247 additions & 5 deletions

File tree

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Pydantic AI provider for Llama Stack."""
22

3+
from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel
34
from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider
45

5-
__all__ = ["LlamaStackProvider"]
6+
__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"]
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Custom OpenAI Responses model that works around Llama Stack streaming quirks.
2+
3+
Llama Stack's Responses API emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP
4+
tool calls *before* the corresponding ``ResponseOutputItemAddedEvent``. pydantic_ai's
5+
default handler creates an orphan ``ToolCallPartDelta`` for the unannounced item_id,
6+
which later causes an IndexError in ``part_end_event``.
7+
8+
Additionally, MCP tool calls arrive as ``McpCall`` items (not ``ResponseFunctionToolCall``),
9+
and pydantic_ai registers them with a ``-call`` vendor_part_id suffix. The buffered
10+
deltas must be replayed with the matching suffix so pydantic_ai can append the
11+
streamed ``tool_args`` content to the correct part.
12+
13+
This module provides ``LlamaStackResponsesModel`` which wraps the event stream to
14+
buffer those early delta events and replay them correctly once the item is announced.
15+
"""
16+
17+
from __future__ import annotations as _annotations
18+
19+
from collections import defaultdict
20+
from collections.abc import AsyncIterator
21+
from contextlib import asynccontextmanager
22+
from typing import Any, cast
23+
24+
from openai import AsyncStream
25+
from openai.types import responses
26+
from pydantic_ai import UnexpectedModelBehavior
27+
from pydantic_ai._run_context import RunContext
28+
from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime
29+
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse
30+
from pydantic_ai.models import (
31+
ModelRequestParameters,
32+
ModelSettings,
33+
StreamedResponse,
34+
check_allow_model_requests,
35+
)
36+
from pydantic_ai.models.openai import (
37+
OpenAIResponsesModel,
38+
OpenAIResponsesModelSettings,
39+
OpenAIResponsesStreamedResponse,
40+
_map_api_errors,
41+
)
42+
43+
from log import get_logger
44+
45+
logger = get_logger(__name__)
46+
47+
48+
class _FilteredResponseStream:
49+
"""Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack.
50+
51+
Llama Stack emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP tool calls
52+
*before* the ``ResponseOutputItemAddedEvent`` that announces them. This wrapper
53+
buffers those early deltas and replays them once the announcement arrives.
54+
55+
For ``McpCall`` items specifically, pydantic_ai registers the part with a
56+
``-call`` vendor_part_id suffix. Buffered deltas are therefore replayed as a
57+
single combined event with the suffixed ``item_id`` so they match the part, plus
58+
a closing ``}`` to complete the outer JSON object that pydantic_ai opens.
59+
"""
60+
61+
def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None:
62+
"""Wrap an existing stream with reordering logic.
63+
64+
Args:
65+
source: The raw OpenAI AsyncStream to reorder.
66+
"""
67+
self._source = source
68+
self._announced_item_ids: set[str] = set()
69+
self._buffered_deltas: dict[
70+
str, list[responses.ResponseFunctionCallArgumentsDeltaEvent]
71+
] = defaultdict(list)
72+
73+
async def close(self) -> None:
74+
"""Close the underlying stream."""
75+
await self._source.close()
76+
77+
def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]:
78+
"""Return async iterator that reorders events."""
79+
return self._filtered_iter()
80+
81+
async def _filtered_iter(
82+
self,
83+
) -> AsyncIterator[responses.ResponseStreamEvent]:
84+
"""Yield events, buffering early argument deltas until their item is announced."""
85+
async for event in self._source:
86+
if isinstance(event, responses.ResponseOutputItemAddedEvent):
87+
if isinstance(event.item, responses.ResponseFunctionToolCall) and event.item.id:
88+
item_id = event.item.id
89+
self._announced_item_ids.add(item_id)
90+
yield event
91+
for delta in self._replay_buffered_deltas(item_id):
92+
yield delta
93+
continue
94+
95+
if isinstance(event.item, responses.response_output_item.McpCall):
96+
item_id = event.item.id
97+
self._announced_item_ids.add(item_id)
98+
yield event
99+
for delta in self._replay_mcp_buffered_deltas(item_id):
100+
yield delta
101+
continue
102+
103+
elif isinstance(event, responses.ResponseFunctionCallArgumentsDeltaEvent):
104+
if event.item_id not in self._announced_item_ids:
105+
logger.debug(
106+
"Buffering early argument delta for unannounced item_id=%s",
107+
event.item_id,
108+
)
109+
self._buffered_deltas[event.item_id].append(event)
110+
continue
111+
112+
yield event
113+
114+
def _replay_buffered_deltas(
115+
self, item_id: str
116+
) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]:
117+
"""Return buffered deltas for a ``ResponseFunctionToolCall`` announcement.
118+
119+
Args:
120+
item_id: The announced item ID.
121+
122+
Returns:
123+
List of buffered delta events to yield, unchanged.
124+
"""
125+
buffered = self._buffered_deltas.pop(item_id, [])
126+
if buffered:
127+
logger.debug(
128+
"Replaying %d buffered argument deltas for item_id=%s",
129+
len(buffered),
130+
item_id,
131+
)
132+
return buffered
133+
134+
def _replay_mcp_buffered_deltas(
135+
self, item_id: str
136+
) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]:
137+
"""Return buffered deltas for an ``McpCall`` announcement.
138+
139+
pydantic_ai registers ``McpCall`` parts with ``vendor_part_id=f'{id}-call'``
140+
and seeds the args string with everything up to ``"tool_args":``. The
141+
buffered deltas contain the actual ``tool_args`` content. We combine them
142+
into a single delta with the suffixed ``item_id`` and append a closing ``}``
143+
to complete the outer JSON object that pydantic_ai opened.
144+
145+
Args:
146+
item_id: The announced McpCall item ID.
147+
148+
Returns:
149+
List containing one synthetic delta event, or empty if nothing buffered.
150+
"""
151+
buffered = self._buffered_deltas.pop(item_id, [])
152+
if not buffered:
153+
return []
154+
155+
combined_args = "".join(d.delta for d in buffered) + "}"
156+
logger.debug(
157+
"Replaying %d buffered MCP argument deltas as single event "
158+
"for item_id=%s-call",
159+
len(buffered),
160+
item_id,
161+
)
162+
return [
163+
responses.ResponseFunctionCallArgumentsDeltaEvent(
164+
delta=combined_args,
165+
item_id=f"{item_id}-call",
166+
output_index=buffered[0].output_index,
167+
sequence_number=buffered[-1].sequence_number + 1,
168+
type="response.function_call_arguments.delta",
169+
)
170+
]
171+
172+
173+
class LlamaStackResponsesModel(OpenAIResponsesModel):
174+
"""OpenAI Responses model with Llama Stack streaming compatibility fixes.
175+
176+
Overrides the streaming response processing to buffer and replay
177+
``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits
178+
before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item.
179+
"""
180+
181+
@asynccontextmanager
182+
async def request_stream(
183+
self,
184+
messages: list[ModelMessage],
185+
model_settings: ModelSettings | None,
186+
model_request_parameters: ModelRequestParameters,
187+
run_context: RunContext[Any] | None = None,
188+
) -> AsyncIterator[StreamedResponse]:
189+
"""Request a streaming response, filtering Llama Stack-specific event quirks.
190+
191+
Args:
192+
messages: Model messages for the request.
193+
model_settings: Model-specific settings.
194+
model_request_parameters: Request parameters for the model.
195+
run_context: Optional run context from the agent.
196+
197+
Yields:
198+
A StreamedResponse with the filtered event stream.
199+
"""
200+
check_allow_model_requests()
201+
model_settings, model_request_parameters = self.prepare_request(
202+
model_settings,
203+
model_request_parameters,
204+
)
205+
model_settings_cast = cast(
206+
OpenAIResponsesModelSettings, model_settings or {}
207+
)
208+
response = await self._responses_create(
209+
messages, True, model_settings_cast, model_request_parameters
210+
)
211+
212+
filtered_stream = _FilteredResponseStream(response)
213+
214+
async with response:
215+
peekable: PeekableAsyncStream[
216+
responses.ResponseStreamEvent, _FilteredResponseStream
217+
] = PeekableAsyncStream(filtered_stream)
218+
219+
with _map_api_errors(self.model_name):
220+
first_chunk = await peekable.peek()
221+
222+
if isinstance(first_chunk, Unset):
223+
raise UnexpectedModelBehavior(
224+
"Streamed response ended without content or tool calls"
225+
)
226+
227+
assert isinstance(first_chunk, responses.ResponseCreatedEvent)
228+
229+
yield OpenAIResponsesStreamedResponse(
230+
model_request_parameters=model_request_parameters,
231+
_model_name=first_chunk.response.model,
232+
_model_settings=model_settings_cast,
233+
_response=peekable, # type: ignore[arg-type]
234+
_provider_name=self._provider.name,
235+
_provider_url=self._provider.base_url,
236+
_provider_timestamp=number_to_datetime(
237+
first_chunk.response.created_at
238+
)
239+
if first_chunk.response.created_at
240+
else None,
241+
)

src/utils/agents/tool_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def summarize_mcp_tool_result(
479479
Tool result summary in LCS turn-summary format.
480480
"""
481481
content = cast(dict[str, Any], part.content)
482-
if "tools" in content or "error" in content:
482+
if "tools" in content:
483483
return summarize_mcp_list_tools_result(part, tool_round)
484484
return summarize_mcp_call_result(part, tool_round)
485485

src/utils/pydantic_ai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
88
from llama_stack_client import AsyncLlamaStackClient
99
from pydantic_ai import Agent
10-
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
10+
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
1111

1212
from models.common.responses.responses_api_params import ResponsesApiParams
13-
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider
13+
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider, LlamaStackResponsesModel
1414

1515
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
1616
{
@@ -92,7 +92,7 @@ def build_agent(
9292
provider = _llama_stack_provider_from_client(client)
9393
settings = _model_settings_from_responses_params(responses_params)
9494

95-
model = OpenAIResponsesModel(
95+
model = LlamaStackResponsesModel(
9696
responses_params.model,
9797
provider=provider,
9898
settings=settings,

0 commit comments

Comments
 (0)