Skip to content

Commit 8d2deee

Browse files
committed
Refactor of shield moderation
1 parent 6d7c76a commit 8d2deee

13 files changed

Lines changed: 475 additions & 280 deletions

File tree

src/app/endpoints/query.py

Lines changed: 42 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
UnauthorizedResponse,
3737
UnprocessableEntityResponse,
3838
)
39+
from utils.conversations import append_turn_items_to_conversation
3940
from utils.endpoints import (
4041
check_configuration_loaded,
4142
validate_and_retrieve_conversation,
@@ -59,14 +60,11 @@
5960
get_topic_summary,
6061
prepare_responses_params,
6162
)
62-
from utils.shields import (
63-
append_turn_to_conversation,
64-
run_shield_moderation,
65-
validate_shield_ids_override,
66-
)
63+
from utils.shields import run_shield_moderation, validate_shield_ids_override
6764
from utils.suid import normalize_conversation_id
6865
from utils.types import (
6966
ResponsesApiParams,
67+
ShieldModerationResult,
7068
TurnSummary,
7169
)
7270
from utils.vector_search import build_rag_context
@@ -158,14 +156,21 @@ async def query_endpoint_handler(
158156

159157
client = AsyncLlamaStackClientHolder().get_client()
160158

161-
# Build RAG context from Inline RAG sources
162-
inline_rag_context = await build_rag_context(
163-
client, query_request.query, query_request.vector_store_ids, query_request.solr
164-
)
165-
166159
# Moderation input is the raw user content (query + attachments) without injected RAG
167160
# context, to avoid false positives from retrieved document content.
168161
moderation_input = prepare_input(query_request)
162+
moderation_result = await run_shield_moderation(
163+
client, moderation_input, query_request.shield_ids
164+
)
165+
166+
# Build RAG context from Inline RAG sources
167+
inline_rag_context = await build_rag_context(
168+
client,
169+
moderation_result.decision,
170+
query_request.query,
171+
query_request.vector_store_ids,
172+
query_request.solr,
173+
)
169174

170175
# Prepare API request parameters
171176
responses_params = await prepare_responses_params(
@@ -177,7 +182,7 @@ async def query_endpoint_handler(
177182
stream=False,
178183
store=True,
179184
request_headers=request.headers,
180-
inline_rag_context=inline_rag_context.context_text or None,
185+
inline_rag_context=inline_rag_context.context_text,
181186
)
182187

183188
# Handle Azure token refresh if needed
@@ -189,32 +194,22 @@ async def query_endpoint_handler(
189194
):
190195
client = await update_azure_token(client)
191196

192-
# Build index identification mapping for RAG source resolution
193-
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
194-
rag_id_mapping = configuration.rag_id_mapping
195-
196197
# Retrieve response using Responses API
197-
turn_summary = await retrieve_response(
198-
client,
199-
responses_params,
200-
query_request.shield_ids,
201-
vector_store_ids,
202-
rag_id_mapping,
203-
moderation_input=moderation_input,
204-
)
205-
206-
# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
207-
rag_chunks = inline_rag_context.rag_chunks
208-
tool_rag_chunks = turn_summary.rag_chunks or []
209-
logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
210-
turn_summary.rag_chunks = rag_chunks + tool_rag_chunks
211-
212-
# Add tool-based RAG documents and chunks
213-
rag_documents = inline_rag_context.referenced_documents
214-
tool_rag_documents = turn_summary.referenced_documents or []
215-
turn_summary.referenced_documents = deduplicate_referenced_documents(
216-
rag_documents + tool_rag_documents
217-
)
198+
turn_summary = await retrieve_response(client, responses_params, moderation_result)
199+
200+
if moderation_result.decision == "passed":
201+
# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
202+
rag_chunks = inline_rag_context.rag_chunks
203+
tool_rag_chunks = turn_summary.rag_chunks
204+
logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
205+
turn_summary.rag_chunks = rag_chunks + tool_rag_chunks
206+
207+
# Add tool-based RAG documents and chunks
208+
rag_documents = inline_rag_context.referenced_documents
209+
tool_rag_documents = turn_summary.referenced_documents
210+
turn_summary.referenced_documents = deduplicate_referenced_documents(
211+
rag_documents + tool_rag_documents
212+
)
218213

219214
# Get topic summary for new conversation
220215
if not user_conversation and query_request.generate_topic_summary:
@@ -272,10 +267,7 @@ async def query_endpoint_handler(
272267
async def retrieve_response( # pylint: disable=too-many-locals
273268
client: AsyncLlamaStackClient,
274269
responses_params: ResponsesApiParams,
275-
shield_ids: Optional[list[str]] = None,
276-
vector_store_ids: Optional[list[str]] = None,
277-
rag_id_mapping: Optional[dict[str, str]] = None,
278-
moderation_input: Optional[str] = None,
270+
moderation_result: ShieldModerationResult,
279271
) -> TurnSummary:
280272
"""
281273
Retrieve response from LLMs and agents.
@@ -286,33 +278,21 @@ async def retrieve_response( # pylint: disable=too-many-locals
286278
Parameters:
287279
client: The AsyncLlamaStackClient to use for the request.
288280
responses_params: The Responses API parameters.
289-
shield_ids: Optional list of shield IDs for moderation.
290-
vector_store_ids: Vector store IDs used in the query for source resolution.
291-
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
292-
moderation_input: Text to moderate. Should be the raw user content (query +
293-
attachments) without injected RAG context to avoid false positives.
294-
Falls back to responses_params.input if not provided.
281+
moderation_result: The moderation result.
295282
296283
Returns:
297284
TurnSummary: Summary of the LLM response content
298285
"""
299286
response: Optional[OpenAIResponseObject] = None
300-
try:
301-
moderation_result = await run_shield_moderation(
287+
if moderation_result.decision == "blocked":
288+
await append_turn_items_to_conversation(
302289
client,
303-
moderation_input or cast(str, responses_params.input),
304-
shield_ids,
290+
responses_params.conversation,
291+
responses_params.input,
292+
[moderation_result.refusal_response],
305293
)
306-
if moderation_result.decision == "blocked":
307-
# Handle shield moderation blocking
308-
violation_message = moderation_result.message
309-
await append_turn_to_conversation(
310-
client,
311-
responses_params.conversation,
312-
cast(str, responses_params.input),
313-
violation_message,
314-
)
315-
return TurnSummary(llm_response=violation_message)
294+
return TurnSummary(llm_response=moderation_result.message)
295+
try:
316296
response = await client.responses.create(
317297
**responses_params.model_dump(exclude_none=True)
318298
)
@@ -333,6 +313,8 @@ async def retrieve_response( # pylint: disable=too-many-locals
333313
error_response = handle_known_apistatus_errors(e, responses_params.model)
334314
raise HTTPException(**error_response.model_dump()) from e
335315

316+
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
317+
rag_id_mapping = configuration.rag_id_mapping
336318
return build_turn_summary(
337319
response, responses_params.model, vector_store_ids, rag_id_mapping
338320
)

src/app/endpoints/streaming_query.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
UnauthorizedResponse,
5858
UnprocessableEntityResponse,
5959
)
60+
from utils.conversations import append_turn_items_to_conversation
6061
from utils.endpoints import (
6162
check_configuration_loaded,
6263
validate_and_retrieve_conversation,
@@ -189,10 +190,22 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
189190

190191
client = AsyncLlamaStackClientHolder().get_client()
191192

193+
# Moderation input is the raw user content (query + attachments) without injected RAG
194+
# context, to avoid false positives from retrieved document content.
195+
moderation_input = prepare_input(query_request)
196+
moderation_result = await run_shield_moderation(
197+
client, moderation_input, query_request.shield_ids
198+
)
199+
192200
# Build RAG context from Inline RAG sources
193201
inline_rag_context = await build_rag_context(
194-
client, query_request.query, query_request.vector_store_ids, query_request.solr
202+
client,
203+
moderation_result.decision,
204+
query_request.query,
205+
query_request.vector_store_ids,
206+
query_request.solr,
195207
)
208+
196209
# Prepare API request parameters
197210
responses_params = await prepare_responses_params(
198211
client=client,
@@ -203,7 +216,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
203216
stream=True,
204217
store=True,
205218
request_headers=request.headers,
206-
inline_rag_context=inline_rag_context.context_text or None,
219+
inline_rag_context=inline_rag_context.context_text,
207220
)
208221

209222
# Handle Azure token refresh if needed
@@ -227,6 +240,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
227240
query_request=query_request,
228241
started_at=started_at,
229242
client=client,
243+
moderation_result=moderation_result,
230244
vector_store_ids=extract_vector_store_ids_from_tools(responses_params.tools),
231245
rag_id_mapping=configuration.rag_id_mapping,
232246
)
@@ -240,9 +254,15 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
240254
generator, turn_summary = await retrieve_response_generator(
241255
responses_params=responses_params,
242256
context=context,
243-
inline_rag_documents=inline_rag_context.referenced_documents,
257+
inline_rag_docs=inline_rag_context.referenced_documents,
244258
)
245259

260+
# Combine inline RAG results (BYOK + Solr) with tool-based results
261+
if context.moderation_result.decision == "passed":
262+
turn_summary.referenced_documents = deduplicate_referenced_documents(
263+
inline_rag_context.referenced_documents + turn_summary.referenced_documents
264+
)
265+
246266
response_media_type = (
247267
MEDIA_TYPE_TEXT
248268
if query_request.media_type == MEDIA_TYPE_TEXT
@@ -263,7 +283,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
263283
async def retrieve_response_generator(
264284
responses_params: ResponsesApiParams,
265285
context: ResponseGeneratorContext,
266-
inline_rag_documents: list[ReferencedDocument],
286+
inline_rag_docs: list[ReferencedDocument],
267287
) -> tuple[AsyncIterator[str], TurnSummary]:
268288
"""
269289
Retrieve the appropriate response generator.
@@ -275,40 +295,43 @@ async def retrieve_response_generator(
275295
Args:
276296
responses_params: The Responses API parameters
277297
context: The response generator context
278-
inline_rag_documents: Referenced documents from inline RAG (BYOK + Solr)
279-
298+
inline_rag_docs: Inline RAG (BYOK + Solr) documents
280299
Returns:
281300
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
282301
283302
"""
284303
turn_summary = TurnSummary()
285304
try:
286-
moderation_result = await run_shield_moderation(
287-
context.client,
288-
prepare_input(context.query_request),
289-
context.query_request.shield_ids,
290-
)
291-
if moderation_result.decision == "blocked":
292-
turn_summary.llm_response = moderation_result.message
293-
await append_turn_to_conversation(
305+
if context.moderation_result.decision == "blocked":
306+
turn_summary.llm_response = context.moderation_result.message
307+
await append_turn_items_to_conversation(
294308
context.client,
295309
responses_params.conversation,
296-
cast(str, responses_params.input),
297-
moderation_result.message,
310+
responses_params.input,
311+
[context.moderation_result.refusal_response],
298312
)
299313
media_type = context.query_request.media_type or MEDIA_TYPE_JSON
300314
return (
301-
shield_violation_generator(moderation_result.message, media_type),
315+
shield_violation_generator(
316+
context.moderation_result.message,
317+
media_type,
318+
),
302319
turn_summary,
303320
)
304321
# Retrieve response stream (may raise exceptions)
305322
response = await context.client.responses.create(
306323
**responses_params.model_dump(exclude_none=True)
307324
)
308325
# Store pre-RAG documents for later merging with tool-based RAG
309-
turn_summary.inline_rag_documents = inline_rag_documents
310-
return response_generator(response, context, turn_summary), turn_summary
311-
326+
return (
327+
response_generator(
328+
response,
329+
context,
330+
turn_summary,
331+
inline_rag_docs,
332+
),
333+
turn_summary,
334+
)
312335
# Handle know LLS client errors only at stream creation time and shield execution
313336
except RuntimeError as e: # library mode wraps 413 into runtime error
314337
if "context_length" in str(e).lower():
@@ -559,6 +582,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
559582
turn_response: AsyncIterator[OpenAIResponseObjectStream],
560583
context: ResponseGeneratorContext,
561584
turn_summary: TurnSummary,
585+
inline_rag_docs: list[ReferencedDocument],
562586
) -> AsyncIterator[str]:
563587
"""Generate SSE formatted streaming response.
564588
@@ -570,7 +594,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
570594
turn_response: The streaming response from Llama Stack
571595
context: The response generator context
572596
turn_summary: TurnSummary to populate during streaming
573-
597+
inline_rag_docs: Inline RAG (BYOK + Solr) documents
574598
Yields:
575599
SSE-formatted strings for tokens, tool calls, tool results,
576600
turn completion, and error events.
@@ -741,15 +765,15 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
741765
turn_summary.token_usage = extract_token_usage(
742766
latest_response_object.usage, context.model_id
743767
)
744-
tool_based_documents = parse_referenced_documents(
768+
# Parse tool-based referenced documents from the final response object
769+
tool_rag_docs = parse_referenced_documents(
745770
latest_response_object,
746771
vector_store_ids=context.vector_store_ids,
747772
rag_id_mapping=context.rag_id_mapping,
748773
)
749-
750-
# Merge pre-RAG documents with tool-based documents and deduplicate
774+
# Combine inline RAG results (BYOK + Solr) with tool-based results
751775
turn_summary.referenced_documents = deduplicate_referenced_documents(
752-
turn_summary.inline_rag_documents + tool_based_documents
776+
inline_rag_docs + tool_rag_docs
753777
)
754778

755779

src/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,5 @@
214214
# Environment variable to force StreamHandler instead of RichHandler
215215
# Set to any non-empty value to disable RichHandler
216216
LIGHTSPEED_STACK_DISABLE_RICH_HANDLER_ENV_VAR = "LIGHTSPEED_STACK_DISABLE_RICH_HANDLER"
217+
218+
DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions."

src/models/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from llama_stack_client import AsyncLlamaStackClient
55

66
from models.requests import QueryRequest
7+
from utils.types import ShieldModerationResult
78

89

910
@dataclass
@@ -23,6 +24,7 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes
2324
query_request: The query request object
2425
started_at: Timestamp when the request started (ISO 8601 format)
2526
client: The Llama Stack client for API interactions
27+
moderation_result: The moderation result
2628
vector_store_ids: Vector store IDs used in the query for source resolution.
2729
rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
2830
"""
@@ -42,6 +44,7 @@ class ResponseGeneratorContext: # pylint: disable=too-many-instance-attributes
4244

4345
# Dependencies & State
4446
client: AsyncLlamaStackClient
47+
moderation_result: ShieldModerationResult
4548

4649
# RAG index identification
4750
vector_store_ids: list[str] = field(default_factory=list)

src/models/requests.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ class QueryRequest(BaseModel):
179179
shield_ids: Optional[list[str]] = Field(
180180
None,
181181
description="Optional list of safety shield IDs to apply. "
182-
"If None, all configured shields are used. "
183-
"If provided, must contain at least one valid shield ID (empty list raises 422 error).",
182+
"If None, all configured shields are used. ",
184183
examples=["llama-guard", "custom-shield"],
185184
)
186185

0 commit comments

Comments
 (0)