Skip to content

Commit fb25b8c

Browse files
committed
Add shields support to the responses API implementation
It includes both streaming and non-streaming support, leveraging the refusal field on the response.
1 parent a1b6f9c commit fb25b8c

4 files changed

Lines changed: 534 additions & 9 deletions

File tree

src/app/endpoints/query_v2.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
317317
given query, handling shield configuration, tool usage, and
318318
attachment validation.
319319
320-
This function configures system prompts and toolgroups
320+
This function configures system prompts, shields, and toolgroups
321321
(including RAG and MCP integration) as needed based on
322322
the query request and system configuration. It
323323
validates attachments, manages conversation and session
@@ -337,8 +337,12 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
337337
and the conversation ID, the list of parsed referenced documents,
338338
and token usage information.
339339
"""
340-
# TODO(ltomasbo): implement shields support once available in Responses API
341-
logger.info("Shields are not yet supported in Responses API. Disabling safety")
340+
# List available shields for Responses API
341+
available_shields = [shield.identifier for shield in await client.shields.list()]
342+
if not available_shields:
343+
logger.info("No available shields. Disabling safety")
344+
else:
345+
logger.info("Available shields: %s", available_shields)
342346

343347
# use system prompt from request or default one
344348
system_prompt = get_system_prompt(query_request, configuration)
@@ -376,6 +380,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
376380
if query_request.conversation_id:
377381
create_kwargs["previous_response_id"] = query_request.conversation_id
378382

383+
# Add shields to extra_body if available
384+
if available_shields:
385+
create_kwargs["extra_body"] = {"guardrails": available_shields}
386+
379387
response = await client.responses.create(**create_kwargs)
380388
response = cast(OpenAIResponseObject, response)
381389

@@ -401,6 +409,15 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
401409
if tool_summary:
402410
tool_calls.append(tool_summary)
403411

412+
# Check for shield violations
413+
item_type = getattr(output_item, "type", None)
414+
if item_type == "message":
415+
refusal = getattr(output_item, "refusal", None)
416+
if refusal:
417+
# Metric for LLM validation errors (shield violations)
418+
metrics.llm_calls_validation_errors_total.inc()
419+
logger.warning("Shield violation detected: %s", refusal)
420+
404421
logger.info(
405422
"Response processing complete - Tool calls: %d, Response length: %d chars",
406423
len(tool_calls),

src/app/endpoints/streaming_query_v2.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from authorization.middleware import authorize
3535
from configuration import configuration
3636
from constants import MEDIA_TYPE_JSON
37+
import metrics
3738
from models.cache_entry import CacheEntry
3839
from models.config import Action
3940
from models.context import ResponseGeneratorContext
@@ -247,6 +248,18 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
247248
elif event_type == "response.completed":
248249
# Capture the response object for token usage extraction
249250
latest_response_object = getattr(chunk, "response", None)
251+
252+
# Check for shield violations in the completed response
253+
if latest_response_object:
254+
for output_item in getattr(latest_response_object, "output", []):
255+
item_type = getattr(output_item, "type", None)
256+
if item_type == "message":
257+
refusal = getattr(output_item, "refusal", None)
258+
if refusal:
259+
# Metric for LLM validation errors (shield violations)
260+
metrics.llm_calls_validation_errors_total.inc()
261+
logger.warning("Shield violation detected: %s", refusal)
262+
250263
if not emitted_turn_complete:
251264
final_message = summary.llm_response or "".join(text_parts)
252265
yield format_stream_data(
@@ -394,11 +407,11 @@ async def retrieve_response(
394407
Asynchronously retrieves a streaming response and conversation
395408
ID from the Llama Stack agent for a given user query.
396409
397-
This function configures input/output shields, system prompt,
398-
and tool usage based on the request and environment. It
399-
prepares the agent with appropriate headers and toolgroups,
400-
validates attachments if present, and initiates a streaming
401-
turn with the user's query and any provided documents.
410+
This function configures shields, system prompt, and tool usage
411+
based on the request and environment. It prepares the agent with
412+
appropriate headers and toolgroups, validates attachments if
413+
present, and initiates a streaming turn with the user's query
414+
and any provided documents.
402415
403416
Parameters:
404417
model_id (str): Identifier of the model to use for the query.
@@ -411,7 +424,12 @@ async def retrieve_response(
411424
tuple: A tuple containing the streaming response object
412425
and the conversation ID.
413426
"""
414-
logger.info("Shields are not yet supported in Responses API.")
427+
# List available shields for Responses API
428+
available_shields = [shield.identifier for shield in await client.shields.list()]
429+
if not available_shields:
430+
logger.info("No available shields. Disabling safety")
431+
else:
432+
logger.info("Available shields: %s", available_shields)
415433

416434
# use system prompt from request or default one
417435
system_prompt = get_system_prompt(query_request, configuration)
@@ -448,6 +466,10 @@ async def retrieve_response(
448466
if query_request.conversation_id:
449467
create_params["previous_response_id"] = query_request.conversation_id
450468

469+
# Add shields to extra_body if available
470+
if available_shields:
471+
create_params["extra_body"] = {"guardrails": available_shields}
472+
451473
response = await client.responses.create(**create_params)
452474
response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response)
453475

tests/unit/app/endpoints/test_query_v2.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker):
109109
mock_vector_stores = mocker.Mock()
110110
mock_vector_stores.data = []
111111
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
112+
# Mock shields.list
113+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
112114

113115
# Ensure system prompt resolution does not require real config
114116
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
@@ -143,6 +145,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
143145
mock_vector_stores = mocker.Mock()
144146
mock_vector_stores.data = [mocker.Mock(id="dbA")]
145147
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
148+
# Mock shields.list
149+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
146150

147151
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
148152
mock_cfg = mocker.Mock()
@@ -207,6 +211,8 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker):
207211
mock_vector_stores = mocker.Mock()
208212
mock_vector_stores.data = []
209213
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
214+
# Mock shields.list
215+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
210216

211217
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
212218
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -252,6 +258,8 @@ async def test_retrieve_response_with_usage_info(mocker):
252258
mock_vector_stores = mocker.Mock()
253259
mock_vector_stores.data = []
254260
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
261+
# Mock shields.list
262+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
255263

256264
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
257265
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -289,6 +297,8 @@ async def test_retrieve_response_with_usage_dict(mocker):
289297
mock_vector_stores = mocker.Mock()
290298
mock_vector_stores.data = []
291299
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
300+
# Mock shields.list
301+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
292302

293303
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
294304
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -326,6 +336,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker):
326336
mock_vector_stores = mocker.Mock()
327337
mock_vector_stores.data = []
328338
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
339+
# Mock shields.list
340+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
329341

330342
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
331343
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -354,6 +366,8 @@ async def test_retrieve_response_validates_attachments(mocker):
354366
mock_vector_stores = mocker.Mock()
355367
mock_vector_stores.data = []
356368
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)
369+
# Mock shields.list
370+
mock_client.shields.list = mocker.AsyncMock(return_value=[])
357371

358372
mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
359373
mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
@@ -459,3 +473,177 @@ def _raise(*_args, **_kwargs):
459473
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
460474
assert "Unable to connect to Llama Stack" in str(exc.value.detail)
461475
fail_metric.inc.assert_called_once()
476+
477+
478+
@pytest.mark.asyncio
async def test_retrieve_response_with_shields_available(mocker):
    """Test that shields are listed and passed to responses API when available.

    ``client.shields.list`` is mocked to return two shields; their identifiers
    are expected to be forwarded to ``client.responses.create`` as
    ``extra_body={"guardrails": [...]}``.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return available shields
    shield1 = mocker.Mock()
    shield1.identifier = "shield-1"
    shield2 = mocker.Mock()
    shield2.identifier = "shield-2"
    mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])

    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "Safe response"
    # A bare Mock auto-creates a truthy ``refusal`` attribute on access, which
    # would make this "safe" message look like a shield violation. Pin it to
    # None so the test exercises the non-violation path.
    output_item.refusal = None

    response_obj = mocker.Mock()
    response_obj.id = "resp-shields"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    qr = QueryRequest(query="hello")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-shields", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-shields"
    assert summary.llm_response == "Safe response"

    # Verify that shields were passed in extra_body
    kwargs = mock_client.responses.create.call_args.kwargs
    assert "extra_body" in kwargs
    assert "guardrails" in kwargs["extra_body"]
    assert kwargs["extra_body"]["guardrails"] == ["shield-1", "shield-2"]
521+
522+
523+
@pytest.mark.asyncio
async def test_retrieve_response_with_no_shields_available(mocker):
    """Test that no extra_body is added when no shields are available.

    ``client.shields.list`` is mocked to return an empty list, so the keyword
    arguments passed to ``client.responses.create`` must not contain
    ``extra_body``.
    """
    mock_client = mocker.Mock()

    # Mock shields.list to return no shields
    mock_client.shields.list = mocker.AsyncMock(return_value=[])

    output_item = mocker.Mock()
    output_item.type = "message"
    output_item.role = "assistant"
    output_item.content = "Response without shields"
    # A bare Mock auto-creates a truthy ``refusal`` attribute on access, which
    # would make this ordinary message look like a shield violation. Pin it to
    # None so the test exercises the non-violation path.
    output_item.refusal = None

    response_obj = mocker.Mock()
    response_obj.id = "resp-no-shields"
    response_obj.output = [output_item]
    response_obj.usage = None

    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
    mock_vector_stores = mocker.Mock()
    mock_vector_stores.data = []
    mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_vector_stores)

    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))

    qr = QueryRequest(query="hello")
    summary, conv_id, _referenced_docs, _token_usage = await retrieve_response(
        mock_client, "model-no-shields", qr, token="tkn", provider_id="test-provider"
    )

    assert conv_id == "resp-no-shields"
    assert summary.llm_response == "Response without shields"

    # Verify that no extra_body was added
    kwargs = mock_client.responses.create.call_args.kwargs
    assert "extra_body" not in kwargs
560+
561+
562+
@pytest.mark.asyncio
async def test_retrieve_response_detects_shield_violation(mocker):
    """Check that a refusal on a message output increments the validation metric."""
    # Patch collaborators of the endpoint module and the metric under observation.
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
    violation_counter = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Assistant message carrying a refusal, i.e. a shield violation.
    refused_message = mocker.Mock(
        type="message",
        role="assistant",
        content="I cannot help with that request",
        refusal="Content violates safety policy",
    )
    api_response = mocker.Mock(id="resp-violation", output=[refused_message], usage=None)

    # Client exposing one shield, an empty vector store list and the canned response.
    client = mocker.Mock()
    client.shields.list = mocker.AsyncMock(
        return_value=[mocker.Mock(identifier="safety-shield")]
    )
    client.responses.create = mocker.AsyncMock(return_value=api_response)
    client.vector_stores.list = mocker.AsyncMock(return_value=mocker.Mock(data=[]))

    request = QueryRequest(query="dangerous query")
    result = await retrieve_response(
        client, "model-violation", request, token="tkn", provider_id="test-provider"
    )
    turn_summary, conversation_id, _docs, _usage = result

    assert conversation_id == "resp-violation"
    assert turn_summary.llm_response == "I cannot help with that request"

    # The refusal must have been counted exactly once.
    violation_counter.inc.assert_called_once()
605+
606+
607+
@pytest.mark.asyncio
async def test_retrieve_response_no_violation_with_shields(mocker):
    """Check that the validation metric stays untouched when no refusal is present."""
    # Patch collaborators of the endpoint module and the metric under observation.
    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
    violation_counter = mocker.patch("metrics.llm_calls_validation_errors_total")

    # Assistant message with refusal explicitly None, i.e. no violation.
    safe_message = mocker.Mock(
        type="message",
        role="assistant",
        content="Safe response",
        refusal=None,
    )
    api_response = mocker.Mock(id="resp-safe", output=[safe_message], usage=None)

    # Client exposing one shield, an empty vector store list and the canned response.
    client = mocker.Mock()
    client.shields.list = mocker.AsyncMock(
        return_value=[mocker.Mock(identifier="safety-shield")]
    )
    client.responses.create = mocker.AsyncMock(return_value=api_response)
    client.vector_stores.list = mocker.AsyncMock(return_value=mocker.Mock(data=[]))

    request = QueryRequest(query="safe query")
    result = await retrieve_response(
        client, "model-safe", request, token="tkn", provider_id="test-provider"
    )
    turn_summary, conversation_id, _docs, _usage = result

    assert conversation_id == "resp-safe"
    assert turn_summary.llm_response == "Safe response"

    # No refusal means the metric must not have been incremented.
    violation_counter.inc.assert_not_called()

0 commit comments

Comments
 (0)