Skip to content

Commit 4bfc377

Browse files
committed
Streaming payload models and serializers
1 parent 397acf1 commit 4bfc377

3 files changed

Lines changed: 321 additions & 0 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Streaming payload models and event type exports."""
2+
3+
from models.common.streaming.stream_payloads import (
4+
EndEventData,
5+
EndStreamPayload,
6+
ErrorEventData,
7+
ErrorStreamPayload,
8+
InterruptedEventData,
9+
InterruptedStreamPayload,
10+
TokenChunkData,
11+
TokenStreamPayload,
12+
ToolCallStreamPayload,
13+
ToolResultStreamPayload,
14+
TurnCompleteStreamPayload,
15+
StartEventData,
16+
StartStreamPayload,
17+
StreamEventPayload,
18+
StreamPayloadBase,
19+
)
20+
21+
__all__ = [
22+
"StreamPayloadBase",
23+
"ErrorEventData",
24+
"StartEventData",
25+
"InterruptedEventData",
26+
"EndEventData",
27+
"ErrorStreamPayload",
28+
"StartStreamPayload",
29+
"InterruptedStreamPayload",
30+
"EndStreamPayload",
31+
"TokenChunkData",
32+
"TokenStreamPayload",
33+
"TurnCompleteStreamPayload",
34+
"ToolCallStreamPayload",
35+
"ToolResultStreamPayload",
36+
"StreamEventPayload",
37+
]
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Typed JSON bodies for SSE streaming events."""
2+
3+
from typing import Annotated, Literal, Optional, TypeAlias
4+
5+
from pydantic import BaseModel, ConfigDict, Field
6+
7+
from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary
8+
9+
10+
class StreamPayloadBase(BaseModel):
11+
"""Base for streaming SSE JSON payloads."""
12+
13+
model_config = ConfigDict(extra="forbid")
14+
15+
16+
class ErrorEventData(StreamPayloadBase):
17+
"""Payload for event: "error"."""
18+
19+
status_code: int
20+
response: str
21+
cause: str
22+
23+
24+
class StartEventData(StreamPayloadBase):
25+
"""Payload for event: "start"."""
26+
27+
conversation_id: str
28+
request_id: str
29+
30+
31+
class InterruptedEventData(StreamPayloadBase):
32+
"""Payload for event: "interrupted"."""
33+
34+
request_id: str
35+
36+
37+
class EndEventData(StreamPayloadBase):
38+
"""Nested data for event: "end"."""
39+
40+
referenced_documents: list[ReferencedDocument]
41+
truncated: Optional[bool]
42+
input_tokens: int
43+
output_tokens: int
44+
45+
46+
class ErrorStreamPayload(StreamPayloadBase):
47+
"""SSE error event body (event + typed data)."""
48+
49+
event: Literal["error"] = "error"
50+
data: ErrorEventData
51+
52+
53+
class StartStreamPayload(StreamPayloadBase):
54+
"""SSE stream start body."""
55+
56+
event: Literal["start"] = "start"
57+
data: StartEventData
58+
59+
60+
class InterruptedStreamPayload(StreamPayloadBase):
61+
"""SSE interrupted stream body."""
62+
63+
event: Literal["interrupted"] = "interrupted"
64+
data: InterruptedEventData
65+
66+
67+
class EndStreamPayload(StreamPayloadBase):
68+
"""SSE end-of-stream body (includes available_quotas beside data)."""
69+
70+
event: Literal["end"] = "end"
71+
data: EndEventData
72+
available_quotas: dict[str, int]
73+
74+
75+
class TokenChunkData(StreamPayloadBase):
76+
"""Structured data for token and turn-complete stream lines."""
77+
78+
id: int
79+
token: str
80+
81+
82+
class TokenStreamPayload(StreamPayloadBase):
83+
"""SSE token delta (event: "token")."""
84+
85+
event: Literal["token"] = "token"
86+
data: TokenChunkData
87+
88+
89+
class TurnCompleteStreamPayload(StreamPayloadBase):
90+
"""SSE turn completion (same data shape as token)."""
91+
92+
event: Literal["turn_complete"] = "turn_complete"
93+
data: TokenChunkData
94+
95+
96+
class ToolCallStreamPayload(StreamPayloadBase):
97+
"""SSE tool call summary."""
98+
99+
event: Literal["tool_call"] = "tool_call"
100+
data: ToolCallSummary
101+
102+
103+
class ToolResultStreamPayload(StreamPayloadBase):
104+
"""SSE tool result summary."""
105+
106+
event: Literal["tool_result"] = "tool_result"
107+
data: ToolResultSummary
108+
109+
110+
StreamEventPayload: TypeAlias = Annotated[
111+
TokenStreamPayload
112+
| TurnCompleteStreamPayload
113+
| ToolCallStreamPayload
114+
| ToolResultStreamPayload
115+
| EndStreamPayload
116+
| ErrorStreamPayload
117+
| InterruptedStreamPayload
118+
| StartStreamPayload,
119+
Field(discriminator="event"),
120+
]
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Shared streaming event formatting utilities."""
2+
3+
import json
4+
from functools import singledispatch
5+
from typing import Optional
6+
7+
from constants import MEDIA_TYPE_JSON
8+
from log import get_logger
9+
from models.api.responses.error import (
10+
AbstractErrorResponse,
11+
)
12+
from models.common.streaming import (
13+
EndEventData,
14+
EndStreamPayload,
15+
ErrorEventData,
16+
ErrorStreamPayload,
17+
InterruptedEventData,
18+
InterruptedStreamPayload,
19+
TokenStreamPayload,
20+
ToolCallStreamPayload,
21+
ToolResultStreamPayload,
22+
TurnCompleteStreamPayload,
23+
StartEventData,
24+
StartStreamPayload,
25+
StreamEventPayload,
26+
StreamPayloadBase,
27+
)
28+
from models.common.turn_summary import ReferencedDocument
29+
from utils.token_counter import TokenCounter
30+
31+
logger = get_logger(__name__)
32+
33+
34+
def format_stream_data(data: StreamPayloadBase) -> str:
35+
"""Format a Pydantic payload as an SSE ``data:`` line."""
36+
return f"data: {json.dumps(data.model_dump(mode='json'))}\n\n"
37+
38+
39+
def serialize_http_error_event(
40+
error: AbstractErrorResponse,
41+
media_type: Optional[str] = None,
42+
) -> str:
43+
"""Serialize an API error to an SSE or plain-text client response."""
44+
logger.error("Error while obtaining answer for user question")
45+
resolved_media_type = media_type or MEDIA_TYPE_JSON
46+
payload = ErrorStreamPayload(
47+
data=ErrorEventData(
48+
status_code=error.status_code,
49+
response=error.detail.response,
50+
cause=error.detail.cause,
51+
),
52+
)
53+
return serialize_event(payload, resolved_media_type)
54+
55+
56+
def serialize_start_event(
57+
conversation_id: str,
58+
request_id: str,
59+
media_type: str = MEDIA_TYPE_JSON,
60+
) -> str:
61+
"""Serialize the stream start payload to an SSE line."""
62+
payload = StartStreamPayload(
63+
data=StartEventData(
64+
conversation_id=conversation_id,
65+
request_id=request_id,
66+
),
67+
)
68+
return serialize_event(payload, media_type)
69+
70+
71+
def serialize_interrupted_event(
72+
request_id: str, media_type: str = MEDIA_TYPE_JSON
73+
) -> str:
74+
"""Serialize an interrupted-stream payload to an SSE line."""
75+
payload = InterruptedStreamPayload(
76+
data=InterruptedEventData(request_id=request_id),
77+
)
78+
return serialize_event(payload, media_type)
79+
80+
81+
def serialize_end_event(
82+
token_usage: TokenCounter,
83+
available_quotas: dict[str, int],
84+
referenced_documents: list[ReferencedDocument],
85+
media_type: Optional[str] = None,
86+
) -> str:
87+
"""Serialize the stream end payload for JSON SSE or plain-text clients."""
88+
resolved_media_type = media_type or MEDIA_TYPE_JSON
89+
payload = EndStreamPayload(
90+
data=EndEventData(
91+
referenced_documents=referenced_documents,
92+
truncated=None,
93+
input_tokens=token_usage.input_tokens,
94+
output_tokens=token_usage.output_tokens,
95+
),
96+
available_quotas=available_quotas,
97+
)
98+
return serialize_event(payload, resolved_media_type)
99+
100+
101+
def serialize_event(
102+
payload: StreamEventPayload,
103+
media_type: str = MEDIA_TYPE_JSON,
104+
) -> str:
105+
"""Serialize an LLM stream payload (token, tool, turn complete) for the client."""
106+
if media_type == MEDIA_TYPE_JSON:
107+
return format_stream_data(payload)
108+
return serialize_event_text(payload)
109+
110+
111+
@singledispatch
112+
def serialize_event_text(_payload: StreamPayloadBase) -> str:
113+
"""Serialize stream payload to plain text for text media type."""
114+
return ""
115+
116+
117+
@serialize_event_text.register
118+
def _(payload: TokenStreamPayload) -> str:
119+
"""Serialize token stream payload to plain text."""
120+
return payload.data.token
121+
122+
123+
@serialize_event_text.register
124+
def _(_payload: TurnCompleteStreamPayload) -> str:
125+
"""Serialize turn complete stream payload to plain text."""
126+
return ""
127+
128+
129+
@serialize_event_text.register
130+
def _(payload: ToolCallStreamPayload) -> str:
131+
"""Serialize tool call stream payload to plain text."""
132+
return f"[Tool Call: {payload.data.name}]\n"
133+
134+
135+
@serialize_event_text.register
136+
def _(_payload: ToolResultStreamPayload) -> str:
137+
"""Serialize tool result stream payload to plain text."""
138+
return "[Tool Result]\n"
139+
140+
141+
@serialize_event_text.register
142+
def _(payload: EndStreamPayload) -> str:
143+
"""Serialize end stream payload to plain text."""
144+
ref_docs_string = "\n".join(
145+
f"{doc.doc_title}: {doc.doc_url}"
146+
for doc in payload.data.referenced_documents
147+
if doc.doc_url and doc.doc_title
148+
)
149+
return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else ""
150+
151+
152+
@serialize_event_text.register
153+
def _(payload: ErrorStreamPayload) -> str:
154+
"""Serialize error stream payload to plain text."""
155+
cause_part = payload.data.cause if payload.data.cause is not None else ""
156+
return (
157+
f"Status: {payload.data.status_code} - {payload.data.response} - {cause_part}"
158+
)
159+
160+
161+
@serialize_event_text.register
162+
def _(_payload: StartStreamPayload) -> str:
163+
"""Serialize start stream payload to plain text."""
164+
return ""

0 commit comments

Comments
 (0)