Skip to content

Commit 066b0e5

Browse files
committed
Shields impl using Moderations API
1 parent 3ac487c commit 066b0e5

20 files changed

Lines changed: 790 additions & 271 deletions

docker-compose-library.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ services:
1818
- TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-}
1919
# OpenAI
2020
- OPENAI_API_KEY=${OPENAI_API_KEY}
21-
- E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4-turbo}
21+
- E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini}
2222
# Azure
2323
- AZURE_API_KEY=${AZURE_API_KEY:-}
2424
# RHAIIS

docker-compose.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ services:
1616
- TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-}
1717
# OpenAI
1818
- OPENAI_API_KEY=${OPENAI_API_KEY}
19-
- E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL}
19+
- E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini}
2020
# Azure
2121
- AZURE_API_KEY=${AZURE_API_KEY}
2222
# RHAIIS

run.yaml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,15 @@ storage:
131131
namespace: prompts
132132
backend: kv_default
133133
registered_resources:
134-
models: []
135-
shields: []
134+
models:
135+
- model_id: gpt-4o-mini
136+
provider_id: openai
137+
model_type: llm
138+
provider_model_id: gpt-4o-mini
139+
shields:
140+
- shield_id: llama-guard
141+
provider_id: llama-guard
142+
provider_shield_id: openai/gpt-4o-mini
136143
vector_dbs: []
137144
datasets: []
138145
scoring_fns: []

src/app/endpoints/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from typing import Annotated, Any, Optional, cast
99

1010
from fastapi import APIRouter, Depends, HTTPException, Request
11-
from litellm.exceptions import RateLimitError
1211
from llama_stack_client import (
1312
APIConnectionError,
1413
APIStatusError,
15-
AsyncLlamaStackClient, # type: ignore
14+
AsyncLlamaStackClient,
15+
RateLimitError, # type: ignore
1616
)
1717
from llama_stack_client.types import Shield, UserMessage # type: ignore
1818
from llama_stack_client.types.alpha.agents.turn import Turn

src/app/endpoints/query_v2.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from llama_stack.apis.agents.openai_responses import (
1111
OpenAIResponseObject,
1212
)
13-
from llama_stack_client import AsyncLlamaStackClient # type: ignore
13+
from llama_stack_client import AsyncLlamaStackClient
1414

1515
import metrics
1616
from app.endpoints.query import (
@@ -42,7 +42,10 @@
4242
)
4343
from utils.mcp_headers import mcp_headers_dependency
4444
from utils.responses import extract_text_from_response_output_item
45-
from utils.shields import detect_shield_violations, get_available_shields
45+
from utils.shields import (
46+
append_turn_to_conversation,
47+
run_shield_moderation,
48+
)
4649
from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id
4750
from utils.token_counter import TokenCounter
4851
from utils.types import RAGChunk, ToolCallSummary, ToolResultSummary, TurnSummary
@@ -322,9 +325,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
322325
and the conversation ID, the list of parsed referenced documents,
323326
and token usage information.
324327
"""
325-
# List available shields for Responses API
326-
available_shields = await get_available_shields(client)
327-
328328
# use system prompt from request or default one
329329
system_prompt = get_system_prompt(query_request, configuration)
330330
logger.debug("Using system prompt: %s", system_prompt)
@@ -370,6 +370,26 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
370370
conversation_id,
371371
)
372372

373+
# Run shield moderation before calling LLM
374+
moderation_result = await run_shield_moderation(client, input_text)
375+
if moderation_result.blocked:
376+
violation_message = moderation_result.message or ""
377+
await append_turn_to_conversation(
378+
client, llama_stack_conv_id, input_text, violation_message
379+
)
380+
summary = TurnSummary(
381+
llm_response=violation_message,
382+
tool_calls=[],
383+
tool_results=[],
384+
rag_chunks=[],
385+
)
386+
return (
387+
summary,
388+
normalize_conversation_id(conversation_id),
389+
[],
390+
TokenCounter(),
391+
)
392+
373393
# Create OpenAI response using responses API
374394
create_kwargs: dict[str, Any] = {
375395
"input": input_text,
@@ -381,10 +401,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
381401
"conversation": llama_stack_conv_id,
382402
}
383403

384-
# Add shields to extra_body if available
385-
if available_shields:
386-
create_kwargs["extra_body"] = {"guardrails": available_shields}
387-
388404
response = await client.responses.create(**create_kwargs)
389405
response = cast(OpenAIResponseObject, response)
390406
logger.info("Response: %s", response)
@@ -410,9 +426,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
410426
if tool_result:
411427
tool_results.append(tool_result)
412428

413-
# Check for shield violations across all output items
414-
detect_shield_violations(response.output)
415-
416429
logger.info(
417430
"Response processing complete - Tool calls: %d, Response length: %d chars",
418431
len(tool_calls),

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
from fastapi import APIRouter, Depends, Request
1313
from fastapi.responses import StreamingResponse
14-
from litellm.exceptions import RateLimitError
1514
from llama_stack_client import (
1615
APIConnectionError,
17-
AsyncLlamaStackClient, # type: ignore
16+
AsyncLlamaStackClient,
17+
RateLimitError, # type: ignore
1818
)
1919
from llama_stack_client.types import UserMessage # type: ignore
2020
from llama_stack_client.types.alpha.agents.agent_turn_response_stream_chunk import (

src/app/endpoints/streaming_query_v2.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
from fastapi import APIRouter, Depends, Request
77
from fastapi.responses import StreamingResponse
88
from llama_stack.apis.agents.openai_responses import (
9+
OpenAIResponseContentPartOutputText,
10+
OpenAIResponseMessage,
911
OpenAIResponseObject,
1012
OpenAIResponseObjectStream,
13+
OpenAIResponseObjectStreamResponseCompleted,
14+
OpenAIResponseObjectStreamResponseContentPartAdded,
15+
OpenAIResponseObjectStreamResponseOutputTextDelta,
16+
OpenAIResponseOutputMessageContentOutputText,
1117
)
1218
from llama_stack_client import AsyncLlamaStackClient
1319

@@ -53,7 +59,10 @@
5359
from utils.quota import consume_tokens, get_available_quotas
5460
from utils.suid import normalize_conversation_id, to_llama_stack_conversation_id
5561
from utils.mcp_headers import mcp_headers_dependency
56-
from utils.shields import detect_shield_violations, get_available_shields
62+
from utils.shields import (
63+
append_turn_to_conversation,
64+
run_shield_moderation,
65+
)
5766
from utils.token_counter import TokenCounter
5867
from utils.transcripts import store_transcript
5968
from utils.types import ToolCallSummary, TurnSummary
@@ -234,12 +243,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
234243
# Capture the response object for token usage extraction
235244
latest_response_object = getattr(chunk, "response", None)
236245

237-
# Check for shield violations in the completed response
238-
if latest_response_object:
239-
output = getattr(latest_response_object, "output", None)
240-
if output is not None:
241-
detect_shield_violations(output)
242-
243246
if not emitted_turn_complete:
244247
final_message = summary.llm_response or "".join(text_parts)
245248
if not final_message:
@@ -394,9 +397,6 @@ async def retrieve_response( # pylint: disable=too-many-locals
394397
tuple: A tuple containing the streaming response object
395398
and the conversation ID.
396399
"""
397-
# List available shields for Responses API
398-
available_shields = await get_available_shields(client)
399-
400400
# use system prompt from request or default one
401401
system_prompt = get_system_prompt(query_request, configuration)
402402
logger.debug("Using system prompt: %s", system_prompt)
@@ -441,6 +441,18 @@ async def retrieve_response( # pylint: disable=too-many-locals
441441
conversation_id,
442442
)
443443

444+
# Run shield moderation before calling LLM
445+
moderation_result = await run_shield_moderation(client, input_text)
446+
if moderation_result.blocked:
447+
violation_message = moderation_result.message or ""
448+
await append_turn_to_conversation(
449+
client, llama_stack_conv_id, input_text, violation_message
450+
)
451+
return (
452+
create_violation_stream(violation_message, moderation_result.shield_model),
453+
normalize_conversation_id(conversation_id),
454+
)
455+
444456
create_params: dict[str, Any] = {
445457
"input": input_text,
446458
"model": model_id,
@@ -451,14 +463,55 @@ async def retrieve_response( # pylint: disable=too-many-locals
451463
"conversation": llama_stack_conv_id,
452464
}
453465

454-
# Add shields to extra_body if available
455-
if available_shields:
456-
create_params["extra_body"] = {"guardrails": available_shields}
457-
458466
response = await client.responses.create(**create_params)
459467
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
460-
# async for chunk in response_stream:
461-
# logger.error("Chunk: %s", chunk.model_dump_json())
462-
# Return the normalized conversation_id
463-
# The response_generator will emit it in the start event
468+
464469
return response_stream, normalize_conversation_id(conversation_id)
470+
471+
472+
async def create_violation_stream(
473+
message: str,
474+
shield_model: str | None = None,
475+
) -> AsyncIterator[OpenAIResponseObjectStream]:
476+
"""Create a minimal response stream for shield violations."""
477+
response_id = "resp_shield_violation"
478+
item_id = "msg_shield_violation"
479+
480+
# Content part added (triggers empty initial token)
481+
yield OpenAIResponseObjectStreamResponseContentPartAdded(
482+
content_index=0,
483+
response_id=response_id,
484+
item_id=item_id,
485+
output_index=0,
486+
part=OpenAIResponseContentPartOutputText(text=""),
487+
sequence_number=0,
488+
)
489+
490+
# Text delta
491+
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
492+
content_index=0,
493+
delta=message,
494+
item_id=item_id,
495+
output_index=0,
496+
sequence_number=1,
497+
)
498+
499+
# Completed response
500+
yield OpenAIResponseObjectStreamResponseCompleted(
501+
response=OpenAIResponseObject(
502+
id=response_id,
503+
created_at=0,
504+
model=shield_model or "shield",
505+
output=[
506+
OpenAIResponseMessage(
507+
id=item_id,
508+
content=[
509+
OpenAIResponseOutputMessageContentOutputText(text=message)
510+
],
511+
role="assistant",
512+
status="completed",
513+
)
514+
],
515+
status="completed",
516+
)
517+
)

src/utils/shields.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
"""Utility functions for working with Llama Stack shields."""
22

33
import logging
4-
from typing import Any
4+
from typing import Any, cast
55

6-
from llama_stack_client import AsyncLlamaStackClient
6+
from fastapi import HTTPException
7+
from llama_stack_client import AsyncLlamaStackClient, BadRequestError
8+
from llama_stack_client.types import CreateResponse
79

810
import metrics
11+
from models.responses import NotFoundResponse
12+
from utils.types import ShieldModerationResult
913

1014
logger = logging.getLogger(__name__)
1115

16+
DEFAULT_VIOLATION_MESSAGE = "I cannot process this request due to policy restrictions."
17+
1218

1319
async def get_available_shields(client: AsyncLlamaStackClient) -> list[str]:
1420
"""
@@ -52,3 +58,100 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
5258
logger.warning("Shield violation detected: %s", refusal)
5359
return True
5460
return False
61+
62+
63+
async def run_shield_moderation(
64+
client: AsyncLlamaStackClient,
65+
input_text: str,
66+
) -> ShieldModerationResult:
67+
"""
68+
Run shield moderation on input text.
69+
70+
Iterates through all configured shields and runs moderation checks.
71+
Raises HTTPException if shield model is not found.
72+
73+
Parameters:
74+
client: The Llama Stack client.
75+
input_text: The text to moderate.
76+
77+
Returns:
78+
ShieldModerationResult: Result indicating if content was blocked and the message.
79+
80+
Raises:
81+
HTTPException: If shield's provider_resource_id is not configured or model not found.
82+
"""
83+
available_models = {model.identifier for model in await client.models.list()}
84+
85+
for shield in await client.shields.list():
86+
if (
87+
not shield.provider_resource_id
88+
or shield.provider_resource_id not in available_models
89+
):
90+
response = NotFoundResponse(
91+
resource="Shield model", resource_id=shield.provider_resource_id or ""
92+
)
93+
raise HTTPException(**response.model_dump())
94+
95+
try:
96+
moderation = await client.moderations.create(
97+
input=input_text, model=shield.provider_resource_id
98+
)
99+
moderation_result = cast(CreateResponse, moderation)
100+
101+
if moderation_result.results and moderation_result.results[0].flagged:
102+
result = moderation_result.results[0]
103+
metrics.llm_calls_validation_errors_total.inc()
104+
logger.warning(
105+
"Shield '%s' flagged content: categories=%s",
106+
shield.identifier,
107+
result.categories,
108+
)
109+
violation_message = result.user_message or DEFAULT_VIOLATION_MESSAGE
110+
return ShieldModerationResult(
111+
blocked=True,
112+
message=violation_message,
113+
shield_model=shield.provider_resource_id,
114+
)
115+
116+
# Known Llama Stack bug: BadRequestError is raised when violation is present
117+
# in the shield LLM response but has wrong format that cannot be parsed.
118+
except BadRequestError:
119+
logger.warning(
120+
"Shield '%s' returned BadRequestError, treating as blocked",
121+
shield.identifier,
122+
)
123+
metrics.llm_calls_validation_errors_total.inc()
124+
return ShieldModerationResult(
125+
blocked=True,
126+
message=DEFAULT_VIOLATION_MESSAGE,
127+
shield_model=shield.provider_resource_id,
128+
)
129+
130+
return ShieldModerationResult(blocked=False)
131+
132+
133+
async def append_turn_to_conversation(
134+
client: AsyncLlamaStackClient,
135+
conversation_id: str,
136+
user_message: str,
137+
assistant_message: str,
138+
) -> None:
139+
"""
140+
Append a user/assistant turn to a conversation after shield violation.
141+
142+
Used to record the conversation turn when a shield blocks the request,
143+
storing both the user's original message and the violation response.
144+
145+
Parameters:
146+
client: The Llama Stack client.
147+
conversation_id: The Llama Stack conversation ID.
148+
user_message: The user's input message.
149+
assistant_message: The shield violation response message.
150+
"""
151+
await client.conversations.items.create(
152+
conversation_id,
153+
items=[
154+
{"type": "message", "role": "user", "content": user_message},
155+
{"type": "message", "role": "assistant", "content": assistant_message},
156+
],
157+
)

0 commit comments

Comments
 (0)