Skip to content

Commit 42bf698

Browse files
authored
Merge pull request #1918 from asimurka/relocate_streaming_utils
LCORE-2311: Relocated stream interruption utilities
2 parents 199d7d5 + eafb386 commit 42bf698

4 files changed

Lines changed: 419 additions & 256 deletions

File tree

src/app/endpoints/streaming_query.py

Lines changed: 18 additions & 191 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
"""Streaming query handler using Responses API."""
22

3-
# pylint: disable=too-many-lines
4-
53
import asyncio
64
import datetime
75
from collections.abc import AsyncIterator
@@ -10,7 +8,6 @@
108
from fastapi import APIRouter, Depends, HTTPException, Request
119
from fastapi.responses import StreamingResponse
1210
from llama_stack_api import (
13-
OpenAIResponseMessage,
1411
OpenAIResponseObject,
1512
OpenAIResponseObjectStream,
1613
)
@@ -56,7 +53,6 @@
5653
MEDIA_TYPE_EVENT_STREAM,
5754
MEDIA_TYPE_JSON,
5855
MEDIA_TYPE_TEXT,
59-
TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
6056
)
6157
from log import get_logger
6258
from metrics import recording
@@ -100,7 +96,6 @@
10096
is_context_length_error,
10197
prepare_input,
10298
store_query_results,
103-
update_conversation_topic_summary,
10499
validate_attachments_metadata,
105100
validate_model_provider_override,
106101
)
@@ -118,11 +113,14 @@
118113
prepare_responses_params,
119114
)
120115
from utils.shields import (
121-
append_turn_to_conversation,
122116
run_shield_moderation,
123117
validate_shield_ids_override,
124118
)
125-
from utils.stream_interrupts import get_stream_interrupt_registry
119+
from utils.stream_interrupts import (
120+
deregister_stream,
121+
persist_interrupted_turn,
122+
register_interrupt_callback,
123+
)
126124
from utils.streaming_sse import (
127125
http_exception_stream_event,
128126
shield_violation_generator,
@@ -430,43 +428,6 @@ async def retrieve_response_generator(
430428
raise HTTPException(**error_response.model_dump()) from e
431429

432430

433-
async def _background_update_topic_summary(
434-
context: ResponseGeneratorContext,
435-
model: str,
436-
) -> None:
437-
"""Generate topic summary and update DB/cache in the background.
438-
439-
Runs as a fire-and-forget task after an interrupted turn is persisted.
440-
All errors are caught and logged.
441-
"""
442-
try:
443-
topic_summary = await asyncio.wait_for(
444-
get_topic_summary(
445-
context.query_request.query,
446-
context.client,
447-
model,
448-
),
449-
timeout=TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
450-
)
451-
if topic_summary:
452-
update_conversation_topic_summary(
453-
context.conversation_id,
454-
topic_summary,
455-
user_id=context.user_id,
456-
skip_userid_check=context.skip_userid_check,
457-
)
458-
except asyncio.TimeoutError:
459-
logger.warning(
460-
"Topic summary timed out for interrupted turn, request %s",
461-
context.request_id,
462-
)
463-
except Exception: # pylint: disable=broad-except
464-
logger.exception(
465-
"Failed to generate topic summary for interrupted turn, request %s",
466-
context.request_id,
467-
)
468-
469-
470431
async def shutdown_background_topic_summary_tasks() -> None:
471432
"""Cancel and await outstanding background topic summary tasks on shutdown.
472433
@@ -485,148 +446,6 @@ async def shutdown_background_topic_summary_tasks() -> None:
485446
await asyncio.gather(*tasks, return_exceptions=True)
486447

487448

488-
async def _persist_interrupted_turn(
489-
context: ResponseGeneratorContext,
490-
responses_params: ResponsesApiParams,
491-
turn_summary: TurnSummary,
492-
original_input: Optional[ResponseInput] = None,
493-
) -> None:
494-
"""Persist the user query and an interrupted response into the conversation.
495-
496-
Called when a streaming request is cancelled so the exchange is not lost.
497-
Persists immediately with topic_summary=None so the conversation exists
498-
when the client fetches. Topic summary is generated in a background task
499-
and updated when ready.
500-
501-
Parameters:
502-
----------
503-
context: The response generator context.
504-
responses_params: The Responses API parameters.
505-
turn_summary: TurnSummary with llm_response already set to the
506-
interrupted message.
507-
original_input: In compacted mode, the original user input before the
508-
explicit-input rewrite. When set, the turn is persisted against it
509-
(the ``conversation`` parameter was dropped, and
510-
``responses_params.input`` is the explicit rewrite); ``None``
511-
otherwise (LCORE-1572).
512-
"""
513-
try:
514-
if original_input is not None:
515-
await append_turn_items_to_conversation(
516-
context.client,
517-
responses_params.conversation,
518-
original_input,
519-
[
520-
OpenAIResponseMessage(
521-
role="assistant", content=INTERRUPTED_RESPONSE_MESSAGE
522-
)
523-
],
524-
)
525-
else:
526-
await append_turn_to_conversation(
527-
context.client,
528-
responses_params.conversation,
529-
cast(str, responses_params.input),
530-
INTERRUPTED_RESPONSE_MESSAGE,
531-
)
532-
except Exception: # pylint: disable=broad-except
533-
logger.exception(
534-
"Failed to append interrupted turn to conversation for request %s",
535-
context.request_id,
536-
)
537-
538-
try:
539-
completed_at = datetime.datetime.now(datetime.UTC).strftime(
540-
"%Y-%m-%dT%H:%M:%SZ"
541-
)
542-
store_query_results(
543-
user_id=context.user_id,
544-
conversation_id=context.conversation_id,
545-
model=responses_params.model,
546-
completed_at=completed_at,
547-
started_at=context.started_at,
548-
summary=turn_summary,
549-
query=context.query_request.query,
550-
skip_userid_check=context.skip_userid_check,
551-
topic_summary=None,
552-
)
553-
554-
if (
555-
not context.query_request.conversation_id
556-
and context.query_request.generate_topic_summary
557-
):
558-
task = asyncio.create_task(
559-
_background_update_topic_summary(
560-
context=context,
561-
model=responses_params.model,
562-
)
563-
)
564-
_background_topic_summary_tasks.append(task)
565-
task.add_done_callback(_background_topic_summary_tasks.remove)
566-
except Exception: # pylint: disable=broad-except
567-
logger.exception(
568-
"Failed to store interrupted query results for request %s",
569-
context.request_id,
570-
)
571-
572-
573-
def _register_interrupt_callback(
574-
context: ResponseGeneratorContext,
575-
responses_params: ResponsesApiParams,
576-
turn_summary: TurnSummary,
577-
original_input: Optional[ResponseInput] = None,
578-
) -> list[bool]:
579-
"""Build an interrupt callback and register the stream for cancellation.
580-
581-
The callback is invoked by ``cancel_stream`` when the client
582-
interrupts, so persistence runs regardless of where the
583-
``CancelledError`` is raised in the ASGI stack.
584-
585-
A mutable one-element list is used as a shared guard so the
586-
callback and the in-generator ``CancelledError`` handler never
587-
both persist the same turn.
588-
589-
Parameters:
590-
----------
591-
context: The response generator context.
592-
responses_params: The Responses API parameters.
593-
turn_summary: TurnSummary populated during streaming.
594-
595-
Returns:
596-
-------
597-
A mutable list ``[False]`` used as a persist-done guard; the
598-
caller should check ``guard[0]`` before persisting and set
599-
it to ``True`` afterwards.
600-
"""
601-
guard: list[bool] = [False]
602-
603-
async def _on_interrupt() -> None:
604-
if guard[0]:
605-
return
606-
guard[0] = True
607-
turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE
608-
await _persist_interrupted_turn(
609-
context, responses_params, turn_summary, original_input
610-
)
611-
612-
current_task = asyncio.current_task()
613-
if current_task is not None:
614-
get_stream_interrupt_registry().register_stream(
615-
request_id=context.request_id,
616-
user_id=context.user_id,
617-
task=current_task,
618-
on_interrupt=_on_interrupt,
619-
)
620-
else:
621-
logger.warning(
622-
"No current asyncio task for request %s; "
623-
"stream interruption will not be available",
624-
context.request_id,
625-
)
626-
627-
return guard
628-
629-
630449
async def generate_response_with_compaction(
631450
context: ResponseGeneratorContext,
632451
responses_params: ResponsesApiParams,
@@ -759,8 +578,12 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi
759578
Yields:
760579
SSE-formatted strings from the wrapped generator
761580
"""
762-
persist_guard = _register_interrupt_callback(
763-
context, responses_params, turn_summary, original_input
581+
persist_guard = register_interrupt_callback(
582+
context,
583+
responses_params,
584+
turn_summary,
585+
_background_topic_summary_tasks,
586+
original_input,
764587
)
765588

766589
stream_completed = False
@@ -802,12 +625,16 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi
802625
if not persist_guard[0]:
803626
persist_guard[0] = True
804627
turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE
805-
await _persist_interrupted_turn(
806-
context, responses_params, turn_summary, original_input
628+
await persist_interrupted_turn(
629+
context,
630+
responses_params,
631+
turn_summary,
632+
_background_topic_summary_tasks,
633+
original_input,
807634
)
808635
yield stream_interrupted_event(context.request_id)
809636
finally:
810-
get_stream_interrupt_registry().deregister_stream(context.request_id)
637+
deregister_stream(context.request_id)
811638

812639
if not stream_completed:
813640
return

0 commit comments

Comments (0)