Skip to content

Commit e12f20b

Browse files
committed
Extend SSEDelta to contain information for both reasoning and output
1 parent 32b88e0 commit e12f20b

4 files changed

Lines changed: 70 additions & 39 deletions

File tree

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -404,50 +404,80 @@ async def _iter_sse_lines(
404404
async def _handle_streaming_request(self, query: Query) -> None:
405405
"""Handle streaming response."""
406406
async for response in self._make_http_request(query):
407-
accumulated_content = []
407+
output_chunks = []
408+
reasoning_chunks = []
408409
first_chunk_sent = False
409410

410411
# Process SSE stream - yields batches of chunks
411412
async for chunk_batch in self._iter_sse_lines(response):
412-
accumulated_content.extend(chunk_batch)
413-
414-
# Determine which chunks to send: all or just first
415-
chunks_to_send = (
416-
chunk_batch
417-
if self.http_config.stream_all_chunks
418-
else chunk_batch[:1]
419-
if not first_chunk_sent
420-
else []
421-
)
422-
423-
# Send chunks
424-
for content in chunks_to_send:
425-
await self._response_socket.send(
426-
StreamChunk(
427-
id=query.id,
428-
response_chunk=content,
429-
is_complete=False,
430-
metadata={
431-
"first_chunk": not first_chunk_sent,
432-
"final_chunk": False,
433-
},
434-
)
413+
output_delta = []
414+
reasoning_delta = []
415+
for delta in chunk_batch:
416+
if delta.content:
417+
output_delta.append(delta.content)
418+
elif delta.reasoning:
419+
reasoning_delta.append(delta.reasoning)
420+
else:
421+
logger.debug("empty SSE delta")
422+
continue
423+
424+
for delta_batch, accumulator in (
425+
(reasoning_delta, reasoning_chunks),
426+
(output_delta, output_chunks),
427+
):
428+
if not delta_batch:
429+
continue
430+
accumulator.extend(delta_batch)
431+
432+
# Determine which chunks to send: all or just first
433+
chunks_to_send = (
434+
delta_batch
435+
if self.http_config.stream_all_chunks
436+
else delta_batch[:1]
437+
if not first_chunk_sent
438+
else []
435439
)
436-
first_chunk_sent = True
437-
if self.http_config.record_worker_events:
438-
EventRecorder.record_event(
439-
SampleEvent.ZMQ_RESPONSE_SENT,
440-
time.monotonic_ns(),
441-
sample_uuid=query.id,
442-
assert_active=True,
440+
441+
# Send chunks
442+
for content in chunks_to_send:
443+
await self._response_socket.send(
444+
StreamChunk(
445+
id=query.id,
446+
response_chunk=content,
447+
is_complete=False,
448+
metadata={
449+
"first_chunk": not first_chunk_sent,
450+
"final_chunk": False,
451+
},
452+
)
443453
)
454+
first_chunk_sent = True
455+
if self.http_config.record_worker_events:
456+
EventRecorder.record_event(
457+
SampleEvent.ZMQ_RESPONSE_SENT,
458+
time.monotonic_ns(),
459+
sample_uuid=query.id,
460+
assert_active=True,
461+
)
444462

445463
# Send final complete response
446-
response_output = []
447-
if accumulated_content:
448-
response_output.append(accumulated_content[0])
449-
if len(accumulated_content) > 1:
450-
response_output.append("".join(accumulated_content[1:]))
464+
if reasoning_chunks:
465+
resp_reasoning = [reasoning_chunks[0]]
466+
if len(reasoning_chunks) > 1:
467+
resp_reasoning.append("".join(reasoning_chunks[1:]))
468+
response_output = {
469+
"output": "".join(output_chunks),
470+
"reasoning": resp_reasoning,
471+
}
472+
elif output_chunks:
473+
resp_output = [output_chunks[0]]
474+
if len(output_chunks) > 1:
475+
resp_output.append("".join(output_chunks[1:]))
476+
response_output = {
477+
"output": resp_output,
478+
}
479+
else:
480+
response_output = {"output": []}
451481

452482
await self._response_socket.send(
453483
QueryResult(

src/inference_endpoint/load_generator/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def query_result_complete(self, result: QueryResult) -> None:
186186
SampleEvent.COMPLETE,
187187
timestamp_ns,
188188
sample_uuid=result.id,
189-
data={"output": result.response_output},
189+
data=result.response_output,
190190
)
191191

192192
for hook in self.complete_hooks:

src/inference_endpoint/openai/openai_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class SSEDelta(msgspec.Struct):
4141
"""SSE delta object containing content."""
4242

4343
content: str = ""
44+
reasoning: str = ""
4445

4546

4647
class SSEChoice(msgspec.Struct):
@@ -75,7 +76,7 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
7576
def decode_sse_message(cls, json_bytes: bytes) -> str:
7677
"""Decode SSE message and extract content string."""
7778
msg = msgspec.json.decode(json_bytes, type=SSEMessage)
78-
return msg.choices[0].delta.content
79+
return msg.choices[0].delta
7980

8081
# ========================================================================
8182
# Internal APIs

src/inference_endpoint/openai/openai_msgspec_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
127127
def decode_sse_message(cls, json_bytes: bytes) -> str:
128128
"""Decode SSE message and extract content string."""
129129
msg = cls._sse_decoder.decode(json_bytes)
130-
return msg.choices[0].delta.content
130+
return msg.choices[0].delta
131131

132132
# ========================================================================
133133
# Internal APIs

0 commit comments

Comments
 (0)