11"""Streaming query handler using Responses API."""
22
3- # pylint: disable=too-many-lines
4-
53import asyncio
64import datetime
75from collections .abc import AsyncIterator
108from fastapi import APIRouter , Depends , HTTPException , Request
119from fastapi .responses import StreamingResponse
1210from llama_stack_api import (
13- OpenAIResponseMessage ,
1411 OpenAIResponseObject ,
1512 OpenAIResponseObjectStream ,
1613)
5653 MEDIA_TYPE_EVENT_STREAM ,
5754 MEDIA_TYPE_JSON ,
5855 MEDIA_TYPE_TEXT ,
59- TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS ,
6056)
6157from log import get_logger
6258from metrics import recording
10096 is_context_length_error ,
10197 prepare_input ,
10298 store_query_results ,
103- update_conversation_topic_summary ,
10499 validate_attachments_metadata ,
105100 validate_model_provider_override ,
106101)
118113 prepare_responses_params ,
119114)
120115from utils .shields import (
121- append_turn_to_conversation ,
122116 run_shield_moderation ,
123117 validate_shield_ids_override ,
124118)
125- from utils .stream_interrupts import get_stream_interrupt_registry
119+ from utils .stream_interrupts import (
120+ deregister_stream ,
121+ persist_interrupted_turn ,
122+ register_interrupt_callback ,
123+ )
126124from utils .streaming_sse import (
127125 http_exception_stream_event ,
128126 shield_violation_generator ,
@@ -430,43 +428,6 @@ async def retrieve_response_generator(
430428 raise HTTPException (** error_response .model_dump ()) from e
431429
432430
433- async def _background_update_topic_summary (
434- context : ResponseGeneratorContext ,
435- model : str ,
436- ) -> None :
437- """Generate topic summary and update DB/cache in the background.
438-
439- Runs as a fire-and-forget task after an interrupted turn is persisted.
440- All errors are caught and logged.
441- """
442- try :
443- topic_summary = await asyncio .wait_for (
444- get_topic_summary (
445- context .query_request .query ,
446- context .client ,
447- model ,
448- ),
449- timeout = TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS ,
450- )
451- if topic_summary :
452- update_conversation_topic_summary (
453- context .conversation_id ,
454- topic_summary ,
455- user_id = context .user_id ,
456- skip_userid_check = context .skip_userid_check ,
457- )
458- except asyncio .TimeoutError :
459- logger .warning (
460- "Topic summary timed out for interrupted turn, request %s" ,
461- context .request_id ,
462- )
463- except Exception : # pylint: disable=broad-except
464- logger .exception (
465- "Failed to generate topic summary for interrupted turn, request %s" ,
466- context .request_id ,
467- )
468-
469-
470431async def shutdown_background_topic_summary_tasks () -> None :
471432 """Cancel and await outstanding background topic summary tasks on shutdown.
472433
@@ -485,148 +446,6 @@ async def shutdown_background_topic_summary_tasks() -> None:
485446 await asyncio .gather (* tasks , return_exceptions = True )
486447
487448
488- async def _persist_interrupted_turn (
489- context : ResponseGeneratorContext ,
490- responses_params : ResponsesApiParams ,
491- turn_summary : TurnSummary ,
492- original_input : Optional [ResponseInput ] = None ,
493- ) -> None :
494- """Persist the user query and an interrupted response into the conversation.
495-
496- Called when a streaming request is cancelled so the exchange is not lost.
497- Persists immediately with topic_summary=None so the conversation exists
498- when the client fetches. Topic summary is generated in a background task
499- and updated when ready.
500-
501- Parameters:
502- ----------
503- context: The response generator context.
504- responses_params: The Responses API parameters.
505- turn_summary: TurnSummary with llm_response already set to the
506- interrupted message.
507- original_input: In compacted mode, the original user input before the
508- explicit-input rewrite. When set, the turn is persisted against it
509- (the ``conversation`` parameter was dropped, and
510- ``responses_params.input`` is the explicit rewrite); ``None``
511- otherwise (LCORE-1572).
512- """
513- try :
514- if original_input is not None :
515- await append_turn_items_to_conversation (
516- context .client ,
517- responses_params .conversation ,
518- original_input ,
519- [
520- OpenAIResponseMessage (
521- role = "assistant" , content = INTERRUPTED_RESPONSE_MESSAGE
522- )
523- ],
524- )
525- else :
526- await append_turn_to_conversation (
527- context .client ,
528- responses_params .conversation ,
529- cast (str , responses_params .input ),
530- INTERRUPTED_RESPONSE_MESSAGE ,
531- )
532- except Exception : # pylint: disable=broad-except
533- logger .exception (
534- "Failed to append interrupted turn to conversation for request %s" ,
535- context .request_id ,
536- )
537-
538- try :
539- completed_at = datetime .datetime .now (datetime .UTC ).strftime (
540- "%Y-%m-%dT%H:%M:%SZ"
541- )
542- store_query_results (
543- user_id = context .user_id ,
544- conversation_id = context .conversation_id ,
545- model = responses_params .model ,
546- completed_at = completed_at ,
547- started_at = context .started_at ,
548- summary = turn_summary ,
549- query = context .query_request .query ,
550- skip_userid_check = context .skip_userid_check ,
551- topic_summary = None ,
552- )
553-
554- if (
555- not context .query_request .conversation_id
556- and context .query_request .generate_topic_summary
557- ):
558- task = asyncio .create_task (
559- _background_update_topic_summary (
560- context = context ,
561- model = responses_params .model ,
562- )
563- )
564- _background_topic_summary_tasks .append (task )
565- task .add_done_callback (_background_topic_summary_tasks .remove )
566- except Exception : # pylint: disable=broad-except
567- logger .exception (
568- "Failed to store interrupted query results for request %s" ,
569- context .request_id ,
570- )
571-
572-
573- def _register_interrupt_callback (
574- context : ResponseGeneratorContext ,
575- responses_params : ResponsesApiParams ,
576- turn_summary : TurnSummary ,
577- original_input : Optional [ResponseInput ] = None ,
578- ) -> list [bool ]:
579- """Build an interrupt callback and register the stream for cancellation.
580-
581- The callback is invoked by ``cancel_stream`` when the client
582- interrupts, so persistence runs regardless of where the
583- ``CancelledError`` is raised in the ASGI stack.
584-
585- A mutable one-element list is used as a shared guard so the
586- callback and the in-generator ``CancelledError`` handler never
587- both persist the same turn.
588-
589- Parameters:
590- ----------
591- context: The response generator context.
592- responses_params: The Responses API parameters.
593- turn_summary: TurnSummary populated during streaming.
594-
595- Returns:
596- -------
597- A mutable list ``[False]`` used as a persist-done guard; the
598- caller should check ``guard[0]`` before persisting and set
599- it to ``True`` afterwards.
600- """
601- guard : list [bool ] = [False ]
602-
603- async def _on_interrupt () -> None :
604- if guard [0 ]:
605- return
606- guard [0 ] = True
607- turn_summary .llm_response = INTERRUPTED_RESPONSE_MESSAGE
608- await _persist_interrupted_turn (
609- context , responses_params , turn_summary , original_input
610- )
611-
612- current_task = asyncio .current_task ()
613- if current_task is not None :
614- get_stream_interrupt_registry ().register_stream (
615- request_id = context .request_id ,
616- user_id = context .user_id ,
617- task = current_task ,
618- on_interrupt = _on_interrupt ,
619- )
620- else :
621- logger .warning (
622- "No current asyncio task for request %s; "
623- "stream interruption will not be available" ,
624- context .request_id ,
625- )
626-
627- return guard
628-
629-
630449async def generate_response_with_compaction (
631450 context : ResponseGeneratorContext ,
632451 responses_params : ResponsesApiParams ,
@@ -759,8 +578,12 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi
759578 Yields:
760579 SSE-formatted strings from the wrapped generator
761580 """
762- persist_guard = _register_interrupt_callback (
763- context , responses_params , turn_summary , original_input
581+ persist_guard = register_interrupt_callback (
582+ context ,
583+ responses_params ,
584+ turn_summary ,
585+ _background_topic_summary_tasks ,
586+ original_input ,
764587 )
765588
766589 stream_completed = False
@@ -802,12 +625,16 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi
802625 if not persist_guard [0 ]:
803626 persist_guard [0 ] = True
804627 turn_summary .llm_response = INTERRUPTED_RESPONSE_MESSAGE
805- await _persist_interrupted_turn (
806- context , responses_params , turn_summary , original_input
628+ await persist_interrupted_turn (
629+ context ,
630+ responses_params ,
631+ turn_summary ,
632+ _background_topic_summary_tasks ,
633+ original_input ,
807634 )
808635 yield stream_interrupted_event (context .request_id )
809636 finally :
810- get_stream_interrupt_registry (). deregister_stream (context .request_id )
637+ deregister_stream (context .request_id )
811638
812639 if not stream_completed :
813640 return
0 commit comments