Skip to content

Commit 83c495e

Browse files
committed
Streaming payload models and serializers
1 parent 397acf1 commit 83c495e

2 files changed

Lines changed: 191 additions & 0 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Streaming payload models and event type exports."""
2+
3+
from models.common.streaming.stream_payloads import (
4+
EndEventData,
5+
EndStreamPayload,
6+
ErrorEventData,
7+
ErrorStreamPayload,
8+
InterruptedEventData,
9+
InterruptedStreamPayload,
10+
StartEventData,
11+
StartStreamPayload,
12+
StreamEventPayload,
13+
StreamPayloadBase,
14+
TokenChunkData,
15+
TokenStreamPayload,
16+
ToolCallStreamPayload,
17+
ToolResultStreamPayload,
18+
TurnCompleteStreamPayload,
19+
)
20+
21+
__all__ = [
22+
"StreamPayloadBase",
23+
"ErrorEventData",
24+
"StartEventData",
25+
"InterruptedEventData",
26+
"EndEventData",
27+
"ErrorStreamPayload",
28+
"StartStreamPayload",
29+
"InterruptedStreamPayload",
30+
"EndStreamPayload",
31+
"TokenChunkData",
32+
"TokenStreamPayload",
33+
"TurnCompleteStreamPayload",
34+
"ToolCallStreamPayload",
35+
"ToolResultStreamPayload",
36+
"StreamEventPayload",
37+
]
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Typed JSON bodies for SSE streaming events."""
2+
3+
import json
4+
from typing import Annotated, Literal, Optional, TypeAlias
5+
6+
from pydantic import BaseModel, ConfigDict, Field
7+
8+
from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary
9+
10+
11+
class StreamPayloadBase(BaseModel):
12+
"""Base for streaming SSE JSON payloads."""
13+
14+
model_config = ConfigDict(extra="forbid")
15+
16+
def serialize_json(self) -> str:
17+
"""Format this payload as an SSE ``data:`` line."""
18+
return f"data: {json.dumps(self.model_dump(mode='json'))}\n\n"
19+
20+
def serialize_text(self) -> str:
21+
"""Format this payload as plain text for text media type clients."""
22+
return ""
23+
24+
25+
class ErrorEventData(BaseModel):
26+
"""Payload for event: "error"."""
27+
28+
status_code: int
29+
response: str
30+
cause: str
31+
32+
33+
class StartEventData(BaseModel):
34+
"""Payload for event: "start"."""
35+
36+
conversation_id: str
37+
request_id: str
38+
39+
40+
class InterruptedEventData(BaseModel):
41+
"""Payload for event: "interrupted"."""
42+
43+
request_id: str
44+
45+
46+
class EndEventData(BaseModel):
47+
"""Nested data for event: "end"."""
48+
49+
referenced_documents: list[ReferencedDocument]
50+
truncated: Optional[bool]
51+
input_tokens: int
52+
output_tokens: int
53+
54+
55+
class ErrorStreamPayload(StreamPayloadBase):
56+
"""SSE error event body (event + typed data)."""
57+
58+
event: Literal["error"] = "error"
59+
data: ErrorEventData
60+
61+
def serialize_text(self) -> str:
62+
"""Serialize error stream payload to plain text."""
63+
return f"Status: {self.data.status_code} - {self.data.response} - {self.data.cause}"
64+
65+
66+
class StartStreamPayload(StreamPayloadBase):
67+
"""SSE stream start body."""
68+
69+
event: Literal["start"] = "start"
70+
data: StartEventData
71+
72+
73+
class InterruptedStreamPayload(StreamPayloadBase):
74+
"""SSE interrupted stream body."""
75+
76+
event: Literal["interrupted"] = "interrupted"
77+
data: InterruptedEventData
78+
79+
80+
class EndStreamPayload(StreamPayloadBase):
81+
"""SSE end-of-stream body (includes available_quotas beside data)."""
82+
83+
event: Literal["end"] = "end"
84+
data: EndEventData
85+
available_quotas: dict[str, int]
86+
87+
def serialize_text(self) -> str:
88+
"""Serialize end stream payload to plain text."""
89+
ref_docs_string = "\n".join(
90+
f"{doc.doc_title}: {doc.doc_url}"
91+
for doc in self.data.referenced_documents
92+
if doc.doc_url and doc.doc_title
93+
)
94+
return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else ""
95+
96+
97+
class TokenChunkData(BaseModel):
98+
"""Structured data for token and turn-complete stream lines."""
99+
100+
id: int
101+
token: str
102+
103+
104+
class TokenStreamPayload(StreamPayloadBase):
105+
"""SSE token delta (event: "token")."""
106+
107+
event: Literal["token"] = "token"
108+
data: TokenChunkData
109+
110+
def serialize_text(self) -> str:
111+
"""Serialize token stream payload to plain text."""
112+
return self.data.token
113+
114+
115+
class TurnCompleteStreamPayload(StreamPayloadBase):
116+
"""SSE turn completion (same data shape as token)."""
117+
118+
event: Literal["turn_complete"] = "turn_complete"
119+
data: TokenChunkData
120+
121+
122+
class ToolCallStreamPayload(StreamPayloadBase):
123+
"""SSE tool call summary."""
124+
125+
event: Literal["tool_call"] = "tool_call"
126+
data: ToolCallSummary
127+
128+
def serialize_text(self) -> str:
129+
"""Serialize tool call stream payload to plain text."""
130+
return f"[Tool Call: {self.data.name}]\n"
131+
132+
133+
class ToolResultStreamPayload(StreamPayloadBase):
134+
"""SSE tool result summary."""
135+
136+
event: Literal["tool_result"] = "tool_result"
137+
data: ToolResultSummary
138+
139+
def serialize_text(self) -> str:
140+
"""Serialize tool result stream payload to plain text."""
141+
return "[Tool Result]\n"
142+
143+
144+
StreamEventPayload: TypeAlias = Annotated[
145+
TokenStreamPayload
146+
| TurnCompleteStreamPayload
147+
| ToolCallStreamPayload
148+
| ToolResultStreamPayload
149+
| EndStreamPayload
150+
| ErrorStreamPayload
151+
| InterruptedStreamPayload
152+
| StartStreamPayload,
153+
Field(discriminator="event"),
154+
]

0 commit comments

Comments
 (0)