|
| 1 | +"""Custom OpenAI Responses model that works around Llama Stack streaming quirks. |
| 2 | +
|
| 3 | +Llama Stack's Responses API emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP |
| 4 | +tool calls *before* the corresponding ``ResponseOutputItemAddedEvent``. pydantic_ai's |
| 5 | +default handler creates an orphan ``ToolCallPartDelta`` for the unannounced item_id, |
| 6 | +which later causes an IndexError in ``part_end_event``. |
| 7 | +
|
| 8 | +Additionally, MCP tool calls arrive as ``McpCall`` items (not ``ResponseFunctionToolCall``), |
| 9 | +and pydantic_ai registers them with a ``-call`` vendor_part_id suffix. The buffered |
| 10 | +deltas must be replayed with the matching suffix so pydantic_ai can append the |
| 11 | +streamed ``tool_args`` content to the correct part. |
| 12 | +
|
| 13 | +This module provides ``LlamaStackResponsesModel`` which wraps the event stream to |
| 14 | +buffer those early delta events and replay them correctly once the item is announced. |
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations as _annotations |
| 18 | + |
| 19 | +from collections import defaultdict |
| 20 | +from collections.abc import AsyncIterator |
| 21 | +from contextlib import asynccontextmanager |
| 22 | +from typing import Any, cast |
| 23 | + |
| 24 | +from openai import AsyncStream |
| 25 | +from openai.types import responses |
| 26 | +from pydantic_ai import UnexpectedModelBehavior |
| 27 | +from pydantic_ai._run_context import RunContext |
| 28 | +from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime |
| 29 | +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse |
| 30 | +from pydantic_ai.models import ( |
| 31 | + ModelRequestParameters, |
| 32 | + ModelSettings, |
| 33 | + StreamedResponse, |
| 34 | + check_allow_model_requests, |
| 35 | +) |
| 36 | +from pydantic_ai.models.openai import ( |
| 37 | + OpenAIResponsesModel, |
| 38 | + OpenAIResponsesModelSettings, |
| 39 | + OpenAIResponsesStreamedResponse, |
| 40 | + _map_api_errors, |
| 41 | +) |
| 42 | + |
| 43 | +from log import get_logger |
| 44 | + |
| 45 | +logger = get_logger(__name__) |
| 46 | + |
| 47 | + |
| 48 | +class _FilteredResponseStream: |
| 49 | + """Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack. |
| 50 | +
|
| 51 | + Llama Stack emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP tool calls |
| 52 | + *before* the ``ResponseOutputItemAddedEvent`` that announces them. This wrapper |
| 53 | + buffers those early deltas and replays them once the announcement arrives. |
| 54 | +
|
| 55 | + For ``McpCall`` items specifically, pydantic_ai registers the part with a |
| 56 | + ``-call`` vendor_part_id suffix. Buffered deltas are therefore replayed as a |
| 57 | + single combined event with the suffixed ``item_id`` so they match the part, plus |
| 58 | + a closing ``}`` to complete the outer JSON object that pydantic_ai opens. |
| 59 | + """ |
| 60 | + |
| 61 | + def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None: |
| 62 | + """Wrap an existing stream with reordering logic. |
| 63 | +
|
| 64 | + Args: |
| 65 | + source: The raw OpenAI AsyncStream to reorder. |
| 66 | + """ |
| 67 | + self._source = source |
| 68 | + self._announced_item_ids: set[str] = set() |
| 69 | + self._buffered_deltas: dict[ |
| 70 | + str, list[responses.ResponseFunctionCallArgumentsDeltaEvent] |
| 71 | + ] = defaultdict(list) |
| 72 | + |
| 73 | + async def close(self) -> None: |
| 74 | + """Close the underlying stream.""" |
| 75 | + await self._source.close() |
| 76 | + |
| 77 | + def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]: |
| 78 | + """Return async iterator that reorders events.""" |
| 79 | + return self._filtered_iter() |
| 80 | + |
| 81 | + async def _filtered_iter( |
| 82 | + self, |
| 83 | + ) -> AsyncIterator[responses.ResponseStreamEvent]: |
| 84 | + """Yield events, buffering early argument deltas until their item is announced.""" |
| 85 | + async for event in self._source: |
| 86 | + if isinstance(event, responses.ResponseOutputItemAddedEvent): |
| 87 | + if isinstance(event.item, responses.ResponseFunctionToolCall) and event.item.id: |
| 88 | + item_id = event.item.id |
| 89 | + self._announced_item_ids.add(item_id) |
| 90 | + yield event |
| 91 | + for delta in self._replay_buffered_deltas(item_id): |
| 92 | + yield delta |
| 93 | + continue |
| 94 | + |
| 95 | + if isinstance(event.item, responses.response_output_item.McpCall): |
| 96 | + item_id = event.item.id |
| 97 | + self._announced_item_ids.add(item_id) |
| 98 | + yield event |
| 99 | + for delta in self._replay_mcp_buffered_deltas(item_id): |
| 100 | + yield delta |
| 101 | + continue |
| 102 | + |
| 103 | + elif isinstance(event, responses.ResponseFunctionCallArgumentsDeltaEvent): |
| 104 | + if event.item_id not in self._announced_item_ids: |
| 105 | + logger.debug( |
| 106 | + "Buffering early argument delta for unannounced item_id=%s", |
| 107 | + event.item_id, |
| 108 | + ) |
| 109 | + self._buffered_deltas[event.item_id].append(event) |
| 110 | + continue |
| 111 | + |
| 112 | + yield event |
| 113 | + |
| 114 | + def _replay_buffered_deltas( |
| 115 | + self, item_id: str |
| 116 | + ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: |
| 117 | + """Return buffered deltas for a ``ResponseFunctionToolCall`` announcement. |
| 118 | +
|
| 119 | + Args: |
| 120 | + item_id: The announced item ID. |
| 121 | +
|
| 122 | + Returns: |
| 123 | + List of buffered delta events to yield, unchanged. |
| 124 | + """ |
| 125 | + buffered = self._buffered_deltas.pop(item_id, []) |
| 126 | + if buffered: |
| 127 | + logger.debug( |
| 128 | + "Replaying %d buffered argument deltas for item_id=%s", |
| 129 | + len(buffered), |
| 130 | + item_id, |
| 131 | + ) |
| 132 | + return buffered |
| 133 | + |
| 134 | + def _replay_mcp_buffered_deltas( |
| 135 | + self, item_id: str |
| 136 | + ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: |
| 137 | + """Return buffered deltas for an ``McpCall`` announcement. |
| 138 | +
|
| 139 | + pydantic_ai registers ``McpCall`` parts with ``vendor_part_id=f'{id}-call'`` |
| 140 | + and seeds the args string with everything up to ``"tool_args":``. The |
| 141 | + buffered deltas contain the actual ``tool_args`` content. We combine them |
| 142 | + into a single delta with the suffixed ``item_id`` and append a closing ``}`` |
| 143 | + to complete the outer JSON object that pydantic_ai opened. |
| 144 | +
|
| 145 | + Args: |
| 146 | + item_id: The announced McpCall item ID. |
| 147 | +
|
| 148 | + Returns: |
| 149 | + List containing one synthetic delta event, or empty if nothing buffered. |
| 150 | + """ |
| 151 | + buffered = self._buffered_deltas.pop(item_id, []) |
| 152 | + if not buffered: |
| 153 | + return [] |
| 154 | + |
| 155 | + combined_args = "".join(d.delta for d in buffered) + "}" |
| 156 | + logger.debug( |
| 157 | + "Replaying %d buffered MCP argument deltas as single event " |
| 158 | + "for item_id=%s-call", |
| 159 | + len(buffered), |
| 160 | + item_id, |
| 161 | + ) |
| 162 | + return [ |
| 163 | + responses.ResponseFunctionCallArgumentsDeltaEvent( |
| 164 | + delta=combined_args, |
| 165 | + item_id=f"{item_id}-call", |
| 166 | + output_index=buffered[0].output_index, |
| 167 | + sequence_number=buffered[-1].sequence_number + 1, |
| 168 | + type="response.function_call_arguments.delta", |
| 169 | + ) |
| 170 | + ] |
| 171 | + |
| 172 | + |
| 173 | +class LlamaStackResponsesModel(OpenAIResponsesModel): |
| 174 | + """OpenAI Responses model with Llama Stack streaming compatibility fixes. |
| 175 | +
|
| 176 | + Overrides the streaming response processing to buffer and replay |
| 177 | + ``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits |
| 178 | + before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item. |
| 179 | + """ |
| 180 | + |
| 181 | + @asynccontextmanager |
| 182 | + async def request_stream( |
| 183 | + self, |
| 184 | + messages: list[ModelMessage], |
| 185 | + model_settings: ModelSettings | None, |
| 186 | + model_request_parameters: ModelRequestParameters, |
| 187 | + run_context: RunContext[Any] | None = None, |
| 188 | + ) -> AsyncIterator[StreamedResponse]: |
| 189 | + """Request a streaming response, filtering Llama Stack-specific event quirks. |
| 190 | +
|
| 191 | + Args: |
| 192 | + messages: Model messages for the request. |
| 193 | + model_settings: Model-specific settings. |
| 194 | + model_request_parameters: Request parameters for the model. |
| 195 | + run_context: Optional run context from the agent. |
| 196 | +
|
| 197 | + Yields: |
| 198 | + A StreamedResponse with the filtered event stream. |
| 199 | + """ |
| 200 | + check_allow_model_requests() |
| 201 | + model_settings, model_request_parameters = self.prepare_request( |
| 202 | + model_settings, |
| 203 | + model_request_parameters, |
| 204 | + ) |
| 205 | + model_settings_cast = cast( |
| 206 | + OpenAIResponsesModelSettings, model_settings or {} |
| 207 | + ) |
| 208 | + response = await self._responses_create( |
| 209 | + messages, True, model_settings_cast, model_request_parameters |
| 210 | + ) |
| 211 | + |
| 212 | + filtered_stream = _FilteredResponseStream(response) |
| 213 | + |
| 214 | + async with response: |
| 215 | + peekable: PeekableAsyncStream[ |
| 216 | + responses.ResponseStreamEvent, _FilteredResponseStream |
| 217 | + ] = PeekableAsyncStream(filtered_stream) |
| 218 | + |
| 219 | + with _map_api_errors(self.model_name): |
| 220 | + first_chunk = await peekable.peek() |
| 221 | + |
| 222 | + if isinstance(first_chunk, Unset): |
| 223 | + raise UnexpectedModelBehavior( |
| 224 | + "Streamed response ended without content or tool calls" |
| 225 | + ) |
| 226 | + |
| 227 | + assert isinstance(first_chunk, responses.ResponseCreatedEvent) |
| 228 | + |
| 229 | + yield OpenAIResponsesStreamedResponse( |
| 230 | + model_request_parameters=model_request_parameters, |
| 231 | + _model_name=first_chunk.response.model, |
| 232 | + _model_settings=model_settings_cast, |
| 233 | + _response=peekable, # type: ignore[arg-type] |
| 234 | + _provider_name=self._provider.name, |
| 235 | + _provider_url=self._provider.base_url, |
| 236 | + _provider_timestamp=number_to_datetime( |
| 237 | + first_chunk.response.created_at |
| 238 | + ) |
| 239 | + if first_chunk.response.created_at |
| 240 | + else None, |
| 241 | + ) |
0 commit comments