66from fastapi import APIRouter , Depends , Request
77from fastapi .responses import StreamingResponse
88from llama_stack .apis .agents .openai_responses import (
9+ OpenAIResponseContentPartOutputText ,
10+ OpenAIResponseMessage ,
911 OpenAIResponseObject ,
1012 OpenAIResponseObjectStream ,
13+ OpenAIResponseObjectStreamResponseCompleted ,
14+ OpenAIResponseObjectStreamResponseContentPartAdded ,
15+ OpenAIResponseObjectStreamResponseOutputTextDelta ,
16+ OpenAIResponseOutputMessageContentOutputText ,
1117)
1218from llama_stack_client import AsyncLlamaStackClient
1319
5359from utils .quota import consume_tokens , get_available_quotas
5460from utils .suid import normalize_conversation_id , to_llama_stack_conversation_id
5561from utils .mcp_headers import mcp_headers_dependency
56- from utils .shields import detect_shield_violations , get_available_shields
62+ from utils .shields import (
63+ append_turn_to_conversation ,
64+ run_shield_moderation ,
65+ )
5766from utils .token_counter import TokenCounter
5867from utils .transcripts import store_transcript
5968from utils .types import ToolCallSummary , TurnSummary
@@ -234,12 +243,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
234243 # Capture the response object for token usage extraction
235244 latest_response_object = getattr (chunk , "response" , None )
236245
237- # Check for shield violations in the completed response
238- if latest_response_object :
239- output = getattr (latest_response_object , "output" , None )
240- if output is not None :
241- detect_shield_violations (output )
242-
243246 if not emitted_turn_complete :
244247 final_message = summary .llm_response or "" .join (text_parts )
245248 if not final_message :
@@ -394,9 +397,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
394397 tuple: A tuple containing the streaming response object
395398 and the conversation ID.
396399 """
397- # List available shields for Responses API
398- available_shields = await get_available_shields (client )
399-
400400 # use system prompt from request or default one
401401 system_prompt = get_system_prompt (query_request , configuration )
402402 logger .debug ("Using system prompt: %s" , system_prompt )
@@ -441,6 +441,18 @@ async def retrieve_response( # pylint: disable=too-many-locals
441441 conversation_id ,
442442 )
443443
444+ # Run shield moderation before calling LLM
445+ moderation_result = await run_shield_moderation (client , input_text )
446+ if moderation_result .blocked :
447+ violation_message = moderation_result .message or ""
448+ await append_turn_to_conversation (
449+ client , llama_stack_conv_id , input_text , violation_message
450+ )
451+ return (
452+ create_violation_stream (violation_message , moderation_result .shield_model ),
453+ normalize_conversation_id (conversation_id ),
454+ )
455+
444456 create_params : dict [str , Any ] = {
445457 "input" : input_text ,
446458 "model" : model_id ,
@@ -451,14 +463,55 @@ async def retrieve_response( # pylint: disable=too-many-locals
451463 "conversation" : llama_stack_conv_id ,
452464 }
453465
454- # Add shields to extra_body if available
455- if available_shields :
456- create_params ["extra_body" ] = {"guardrails" : available_shields }
457-
458466 response = await client .responses .create (** create_params )
459467 response_stream = cast (AsyncIterator [OpenAIResponseObjectStream ], response )
460- # async for chunk in response_stream:
461- # logger.error("Chunk: %s", chunk.model_dump_json())
462- # Return the normalized conversation_id
463- # The response_generator will emit it in the start event
468+
464469 return response_stream , normalize_conversation_id (conversation_id )
470+
471+
async def create_violation_stream(
    message: str,
    shield_model: str | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
    """Create a minimal response stream for shield violations.

    The stream consists of exactly three events, yielded in this order:

    1. A content-part-added event carrying an empty output-text part.
    2. A single output-text delta whose ``delta`` is the whole ``message``.
    3. A completed event wrapping a response object that contains one
       assistant message with ``message`` as its only text content.

    Args:
        message: Violation text to deliver as the assistant's reply.
        shield_model: Value reported as the ``model`` of the completed
            response; the literal ``"shield"`` is used when this is ``None``
            or empty.

    Yields:
        The three stream events described above.
    """
    # Imported locally so the block does not depend on a module-level import.
    import time  # pylint: disable=import-outside-toplevel

    # Fixed identifiers shared by every event of the synthetic response.
    response_id = "resp_shield_violation"
    item_id = "msg_shield_violation"

    # Content part added (triggers empty initial token)
    yield OpenAIResponseObjectStreamResponseContentPartAdded(
        content_index=0,
        response_id=response_id,
        item_id=item_id,
        output_index=0,
        part=OpenAIResponseContentPartOutputText(text=""),
        sequence_number=0,
    )

    # Text delta: the full violation message is sent as one chunk.
    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
        content_index=0,
        delta=message,
        item_id=item_id,
        output_index=0,
        sequence_number=1,
    )

    # Completed response. The creation time is the current Unix timestamp
    # rather than a hard-coded 0, which would read as 1970-01-01.
    yield OpenAIResponseObjectStreamResponseCompleted(
        response=OpenAIResponseObject(
            id=response_id,
            created_at=int(time.time()),
            model=shield_model or "shield",
            output=[
                OpenAIResponseMessage(
                    id=item_id,
                    content=[
                        OpenAIResponseOutputMessageContentOutputText(text=message)
                    ],
                    role="assistant",
                    status="completed",
                )
            ],
            status="completed",
        )
    )
0 commit comments