Skip to content

Commit c9e9675

Browse files
authored
Merge pull request #1812 from asimurka/streaming_payload_models_and_serializers
LCORE-2311: Streaming models and serializers
2 parents 52b4770 + 54942e5 commit c9e9675

5 files changed

Lines changed: 385 additions & 0 deletions

File tree

src/models/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from models.common.query import Attachment, SolrVectorSearchRequest
1818
from models.common.transcripts import Transcript, TranscriptMetadata
1919
from models.common.turn_summary import (
20+
MCPListToolsSummary,
2021
RAGChunk,
2122
RAGContext,
2223
ReferencedDocument,
2324
ToolCallSummary,
25+
ToolInfoSummary,
2426
ToolResultSummary,
2527
TurnSummary,
2628
)
@@ -48,4 +50,6 @@
4850
"Transcript",
4951
"TranscriptMetadata",
5052
"TurnSummary",
53+
"ToolInfoSummary",
54+
"MCPListToolsSummary",
5155
]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Streaming payload models and event type exports."""
2+
3+
from models.common.agents.stream_payloads import (
4+
EndEventData,
5+
EndStreamPayload,
6+
ErrorEventData,
7+
ErrorStreamPayload,
8+
InterruptedEventData,
9+
InterruptedStreamPayload,
10+
StartEventData,
11+
StartStreamPayload,
12+
StreamEventPayload,
13+
StreamPayloadBase,
14+
TokenChunkData,
15+
TokenStreamPayload,
16+
ToolCallStreamPayload,
17+
ToolResultStreamPayload,
18+
TurnCompleteStreamPayload,
19+
)
20+
from models.common.agents.turn_accumulator import AgentTurnAccumulator
21+
22+
# Public names re-exported by this package, in their established order.
__all__ = [
    # Shared base class for every streaming payload.
    "StreamPayloadBase",
    # Nested ``data`` models for the lifecycle events.
    "ErrorEventData",
    "StartEventData",
    "InterruptedEventData",
    "EndEventData",
    # Lifecycle event payloads.
    "ErrorStreamPayload",
    "StartStreamPayload",
    "InterruptedStreamPayload",
    "EndStreamPayload",
    # Text payloads (token deltas and the completed turn).
    "TokenChunkData",
    "TokenStreamPayload",
    "TurnCompleteStreamPayload",
    # Tool activity payloads.
    "ToolCallStreamPayload",
    "ToolResultStreamPayload",
    # Discriminated union over all payload classes above.
    "StreamEventPayload",
    # Per-turn mutable state holder.
    "AgentTurnAccumulator",
]
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""Typed JSON bodies for SSE streaming events."""
2+
3+
import json
4+
from typing import Annotated, Literal, Optional, Self, TypeAlias
5+
6+
from pydantic import BaseModel, ConfigDict, Field
7+
8+
from models.api.responses.error import AbstractErrorResponse
9+
from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary
10+
11+
12+
class StreamPayloadBase(BaseModel):
    """Common parent of the streaming SSE JSON payload models.

    The model configuration forbids extra fields. Subclasses inherit the
    JSON framing from :meth:`serialize_json` and may override
    :meth:`serialize_text` to provide a plain-text rendering.
    """

    model_config = ConfigDict(extra="forbid")

    def serialize_json(self) -> str:
        """Render this payload as a single SSE ``data:`` frame.

        Returns:
            The JSON-mode dump of the model, prefixed with ``data: `` and
            terminated by two newlines.
        """
        body = json.dumps(self.model_dump(mode="json"))
        return f"data: {body}\n\n"

    def serialize_text(self) -> str:
        """Render this payload as plain text for text media type clients.

        Returns:
            An empty string; the base class contributes no text.
        """
        return ""
24+
25+
26+
class ErrorEventData(BaseModel):
    """Nested ``data`` object carried by the ``"error"`` event.

    Attributes:
        status_code: HTTP status code for the error.
        response: Short summary of the error.
        cause: Detailed explanation of the error cause.
    """

    status_code: int
    response: str
    cause: str
32+
33+
34+
class StartEventData(BaseModel):
    """Nested ``data`` object carried by the ``"start"`` event.

    Attributes:
        conversation_id: Conversation identifier for the stream.
        request_id: Request identifier for the stream.
    """

    conversation_id: str
    request_id: str
39+
40+
41+
class InterruptedEventData(BaseModel):
    """Nested ``data`` object carried by the ``"interrupted"`` event.

    Attributes:
        request_id: Request identifier for the interrupted stream.
    """

    request_id: str
45+
46+
47+
class EndEventData(BaseModel):
    """Nested ``data`` object carried by the ``"end"`` event.

    Attributes:
        referenced_documents: Documents referenced during the turn.
        truncated: Optional boolean flag; it has no default, so callers must
            supply it explicitly (``None`` is accepted).
        input_tokens: Input token count for the turn.
        output_tokens: Output token count for the turn.
    """

    referenced_documents: list[ReferencedDocument]
    truncated: Optional[bool]
    input_tokens: int
    output_tokens: int
54+
55+
56+
class ErrorStreamPayload(StreamPayloadBase):
    """Streaming payload for an error (``event`` is always ``"error"``).

    Attributes:
        event: Discriminator literal, fixed to ``"error"``.
        data: Status code, summary and cause of the error.
    """

    event: Literal["error"] = "error"
    data: ErrorEventData

    @classmethod
    def create(cls, *, status_code: int, response: str, cause: str) -> Self:
        """Build an error payload from its individual HTTP error fields.

        Args:
            status_code: HTTP status code for the error.
            response: Short summary of the error.
            cause: Detailed explanation of the error cause.

        Returns:
            New payload wrapping the given values in ``ErrorEventData``.
        """
        details = ErrorEventData(
            status_code=status_code,
            response=response,
            cause=cause,
        )
        return cls(data=details)

    @classmethod
    def from_error_response(cls, error_response: AbstractErrorResponse) -> Self:
        """Build an error payload from a structured API error response.

        Args:
            error_response: Structured error response model; its
                ``status_code``, ``detail.response`` and ``detail.cause``
                are copied into the payload.

        Returns:
            New payload produced through :meth:`create`.
        """
        detail = error_response.detail
        return cls.create(
            status_code=error_response.status_code,
            response=detail.response,
            cause=detail.cause,
        )

    def serialize_text(self) -> str:
        """Render the error as one plain-text line.

        Returns:
            ``Status: <status_code> - <response> - <cause>``.
        """
        err = self.data
        return f"Status: {err.status_code} - {err.response} - {err.cause}"
97+
98+
99+
class StartStreamPayload(StreamPayloadBase):
    """Streaming payload that opens a stream (``event`` is ``"start"``).

    Attributes:
        event: Discriminator literal, fixed to ``"start"``.
        data: Conversation and request identifiers for the stream.
    """

    event: Literal["start"] = "start"
    data: StartEventData

    @classmethod
    def create(cls, *, conversation_id: str, request_id: str) -> Self:
        """Build a stream-start payload.

        Args:
            conversation_id: Conversation identifier for the stream.
            request_id: Request identifier for the stream.

        Returns:
            New payload wrapping both identifiers in ``StartEventData``.
        """
        start_data = StartEventData(
            conversation_id=conversation_id,
            request_id=request_id,
        )
        return cls(data=start_data)
119+
120+
121+
class InterruptedStreamPayload(StreamPayloadBase):
    """Streaming payload for an interrupted stream.

    Attributes:
        event: Discriminator literal, fixed to ``"interrupted"``.
        data: Identifier of the request whose stream was interrupted.
    """

    event: Literal["interrupted"] = "interrupted"
    data: InterruptedEventData

    @classmethod
    def create(cls, *, request_id: str) -> Self:
        """Build an interrupted-stream payload.

        Args:
            request_id: Request identifier for the interrupted stream.

        Returns:
            New payload wrapping the identifier in ``InterruptedEventData``.
        """
        interrupted_data = InterruptedEventData(request_id=request_id)
        return cls(data=interrupted_data)
138+
139+
140+
class EndStreamPayload(StreamPayloadBase):
    """Streaming payload that closes a stream (``event`` is ``"end"``).

    Unlike the other payloads, this one carries ``available_quotas`` as a
    top-level sibling of ``data`` rather than inside it.

    Attributes:
        event: Discriminator literal, fixed to ``"end"``.
        data: Referenced documents and token counts for the turn.
        available_quotas: Remaining quota limits by quota name.
    """

    event: Literal["end"] = "end"
    data: EndEventData
    available_quotas: dict[str, int]

    @classmethod
    def create(
        cls,
        *,
        referenced_documents: list[ReferencedDocument],
        input_tokens: int,
        output_tokens: int,
        available_quotas: dict[str, int],
    ) -> Self:
        """Build an end-of-stream payload.

        Args:
            referenced_documents: Documents referenced during the turn.
            input_tokens: Input token count for the turn.
            output_tokens: Output token count for the turn.
            available_quotas: Remaining quota limits by quota name.

        Returns:
            New payload; the nested ``truncated`` field is always ``None``.
        """
        end_data = EndEventData(
            referenced_documents=referenced_documents,
            truncated=None,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
        return cls(data=end_data, available_quotas=available_quotas)

    def serialize_text(self) -> str:
        """Render the referenced documents as a plain-text footer.

        Only documents with both a truthy ``doc_url`` and a truthy
        ``doc_title`` are listed, one ``<title>: <url>`` entry per line.

        Returns:
            A ``---`` separator followed by the listing, or an empty string
            when no document qualifies.
        """
        entries = [
            f"{document.doc_title}: {document.doc_url}"
            for document in self.data.referenced_documents
            if document.doc_url and document.doc_title
        ]
        if not entries:
            return ""
        return "\n\n---\n\n" + "\n".join(entries)
185+
186+
187+
class TokenChunkData(BaseModel):
    """Nested ``data`` object shared by the token and turn-complete events.

    Attributes:
        id: Chunk identifier of the stream line.
        token: Text carried by the stream line.
    """

    id: int
    token: str
192+
193+
194+
class TokenStreamPayload(StreamPayloadBase):
    """Streaming payload for one token delta (``event`` is ``"token"``).

    Attributes:
        event: Discriminator literal, fixed to ``"token"``.
        data: Chunk identifier and token text of the delta.
    """

    event: Literal["token"] = "token"
    data: TokenChunkData

    @classmethod
    def create(cls, *, chunk_id: int, token: str) -> Self:
        """Build a token payload.

        Args:
            chunk_id: Monotonic chunk identifier for the token delta.
            token: Token text for the delta.

        Returns:
            New payload whose ``data.id`` is ``chunk_id``.
        """
        chunk = TokenChunkData(id=chunk_id, token=token)
        return cls(data=chunk)

    def serialize_text(self) -> str:
        """Render the payload as plain text.

        Returns:
            The token text, unchanged.
        """
        chunk = self.data
        return chunk.token
216+
217+
218+
class TurnCompleteStreamPayload(StreamPayloadBase):
    """Streaming payload for a completed turn.

    It reuses the ``TokenChunkData`` shape of the token event. It does not
    override ``serialize_text``, so its plain-text rendering is the base
    class's empty string.

    Attributes:
        event: Discriminator literal, fixed to ``"turn_complete"``.
        data: Chunk identifier and text of the completed turn.
    """

    event: Literal["turn_complete"] = "turn_complete"
    data: TokenChunkData

    @classmethod
    def create(cls, *, chunk_id: int, token: str) -> Self:
        """Build a turn-complete payload.

        Args:
            chunk_id: Monotonic chunk identifier for the final text.
            token: Full assistant text for the completed turn.

        Returns:
            New payload whose ``data.id`` is ``chunk_id``.
        """
        final_chunk = TokenChunkData(id=chunk_id, token=token)
        return cls(data=final_chunk)
236+
237+
238+
class ToolCallStreamPayload(StreamPayloadBase):
    """Streaming payload for a tool call (``event`` is ``"tool_call"``).

    Attributes:
        event: Discriminator literal, fixed to ``"tool_call"``.
        data: Summary of the tool call.
    """

    event: Literal["tool_call"] = "tool_call"
    data: ToolCallSummary

    def serialize_text(self) -> str:
        """Render the tool call as a bracketed plain-text marker.

        Returns:
            ``[Tool Call: <name>]`` followed by a newline.
        """
        tool_name = self.data.name
        return f"[Tool Call: {tool_name}]\n"
247+
248+
249+
class ToolResultStreamPayload(StreamPayloadBase):
    """Streaming payload for a tool result (``event`` is ``"tool_result"``).

    Attributes:
        event: Discriminator literal, fixed to ``"tool_result"``.
        data: Summary of the tool result.
    """

    event: Literal["tool_result"] = "tool_result"
    data: ToolResultSummary

    def serialize_text(self) -> str:
        """Render the tool result as a fixed plain-text marker.

        Returns:
            The constant ``[Tool Result]`` line; no field of ``data`` is
            included in the text form.
        """
        marker = "[Tool Result]\n"
        return marker
258+
259+
260+
# Union of all concrete stream payloads. Each member declares a distinct
# ``Literal`` value for ``event``, and that field is used as the pydantic
# discriminator.
StreamEventPayload: TypeAlias = Annotated[
    TokenStreamPayload
    | TurnCompleteStreamPayload
    | ToolCallStreamPayload
    | ToolResultStreamPayload
    | EndStreamPayload
    | ErrorStreamPayload
    | InterruptedStreamPayload
    | StartStreamPayload,
    Field(discriminator="event"),
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Mutable per-turn state for agent response processing."""
2+
3+
from dataclasses import dataclass, field
4+
from typing import Final
5+
6+
from pydantic_ai import AgentRunResult
7+
8+
from models.common.turn_summary import TurnSummary
9+
10+
11+
@dataclass(slots=True)
class AgentTurnAccumulator:  # pylint: disable=too-many-instance-attributes
    """Mutable state collected while processing one interaction turn.

    Attributes:
        vector_store_ids: Vector store IDs used to resolve RAG source labels.
        rag_id_mapping: Maps vector store IDs to user-facing source names.
        turn_summary: Aggregated turn output (text, tools, RAG, token usage).
        run_result: Agent run result (streaming only); ``None`` by default.
        chunk_id: Monotonic SSE chunk index (streaming only); starts at 0.
        text_parts: Buffered text deltas before turn_complete (streaming only).
        tool_round: Current tool-call round for summary labeling; starts at 1.
        round_increment_pending: Whether to bump tool_round on the next step.
        emitted_tool_call_ids: Tool call IDs already sent or recorded.
        emitted_tool_result_ids: Tool result IDs already sent or recorded.
        seen_docs: Referenced-document keys already added (deduplication).
    """

    # Required inputs, supplied by the caller.
    vector_store_ids: Final[list[str]]
    rag_id_mapping: Final[dict[str, str]]
    turn_summary: TurnSummary

    # Streaming-only state.
    run_result: AgentRunResult[str] | None = None
    chunk_id: int = 0
    text_parts: list[str] = field(default_factory=list)

    # Tool-round bookkeeping.
    tool_round: int = 1
    round_increment_pending: bool = False

    # Deduplication sets; each instance gets its own fresh container.
    emitted_tool_call_ids: set[str] = field(default_factory=set)
    emitted_tool_result_ids: set[str] = field(default_factory=set)
    seen_docs: set[tuple[str, str]] = field(default_factory=set)

    def increment_round_if_pending(self) -> None:
        """Advance ``tool_round`` by one when an increment is pending.

        The pending flag is cleared as part of the same call, so calling
        this again without re-setting the flag leaves ``tool_round`` as is.
        """
        if self.round_increment_pending:
            self.round_increment_pending = False
            self.tool_round += 1

0 commit comments

Comments
 (0)