@@ -109,6 +109,8 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker):
109109 mock_vector_stores = mocker .Mock ()
110110 mock_vector_stores .data = []
111111 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
112+ # Mock shields.list
113+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
112114
113115 # Ensure system prompt resolution does not require real config
114116 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
@@ -143,6 +145,8 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
143145 mock_vector_stores = mocker .Mock ()
144146 mock_vector_stores .data = [mocker .Mock (id = "dbA" )]
145147 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
148+ # Mock shields.list
149+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
146150
147151 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
148152 mock_cfg = mocker .Mock ()
@@ -207,6 +211,8 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker):
207211 mock_vector_stores = mocker .Mock ()
208212 mock_vector_stores .data = []
209213 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
214+ # Mock shields.list
215+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
210216
211217 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
212218 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -252,6 +258,8 @@ async def test_retrieve_response_with_usage_info(mocker):
252258 mock_vector_stores = mocker .Mock ()
253259 mock_vector_stores .data = []
254260 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
261+ # Mock shields.list
262+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
255263
256264 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
257265 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -289,6 +297,8 @@ async def test_retrieve_response_with_usage_dict(mocker):
289297 mock_vector_stores = mocker .Mock ()
290298 mock_vector_stores .data = []
291299 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
300+ # Mock shields.list
301+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
292302
293303 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
294304 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -326,6 +336,8 @@ async def test_retrieve_response_with_empty_usage_dict(mocker):
326336 mock_vector_stores = mocker .Mock ()
327337 mock_vector_stores .data = []
328338 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
339+ # Mock shields.list
340+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
329341
330342 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
331343 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -354,6 +366,8 @@ async def test_retrieve_response_validates_attachments(mocker):
354366 mock_vector_stores = mocker .Mock ()
355367 mock_vector_stores .data = []
356368 mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
369+ # Mock shields.list
370+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
357371
358372 mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
359373 mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
@@ -459,3 +473,177 @@ def _raise(*_args, **_kwargs):
459473 assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
460474 assert "Unable to connect to Llama Stack" in str (exc .value .detail )
461475 fail_metric .inc .assert_called_once ()
476+
477+
478+ @pytest .mark .asyncio
479+ async def test_retrieve_response_with_shields_available (mocker ):
480+ """Test that shields are listed and passed to responses API when available."""
481+ mock_client = mocker .Mock ()
482+
483+ # Mock shields.list to return available shields
484+ shield1 = mocker .Mock ()
485+ shield1 .identifier = "shield-1"
486+ shield2 = mocker .Mock ()
487+ shield2 .identifier = "shield-2"
488+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 , shield2 ])
489+
490+ output_item = mocker .Mock ()
491+ output_item .type = "message"
492+ output_item .role = "assistant"
493+ output_item .content = "Safe response"
494+
495+ response_obj = mocker .Mock ()
496+ response_obj .id = "resp-shields"
497+ response_obj .output = [output_item ]
498+ response_obj .usage = None
499+
500+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
501+ mock_vector_stores = mocker .Mock ()
502+ mock_vector_stores .data = []
503+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
504+
505+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
506+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
507+
508+ qr = QueryRequest (query = "hello" )
509+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
510+ mock_client , "model-shields" , qr , token = "tkn" , provider_id = "test-provider"
511+ )
512+
513+ assert conv_id == "resp-shields"
514+ assert summary .llm_response == "Safe response"
515+
516+ # Verify that shields were passed in extra_body
517+ kwargs = mock_client .responses .create .call_args .kwargs
518+ assert "extra_body" in kwargs
519+ assert "guardrails" in kwargs ["extra_body" ]
520+ assert kwargs ["extra_body" ]["guardrails" ] == ["shield-1" , "shield-2" ]
521+
522+
523+ @pytest .mark .asyncio
524+ async def test_retrieve_response_with_no_shields_available (mocker ):
525+ """Test that no extra_body is added when no shields are available."""
526+ mock_client = mocker .Mock ()
527+
528+ # Mock shields.list to return no shields
529+ mock_client .shields .list = mocker .AsyncMock (return_value = [])
530+
531+ output_item = mocker .Mock ()
532+ output_item .type = "message"
533+ output_item .role = "assistant"
534+ output_item .content = "Response without shields"
535+
536+ response_obj = mocker .Mock ()
537+ response_obj .id = "resp-no-shields"
538+ response_obj .output = [output_item ]
539+ response_obj .usage = None
540+
541+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
542+ mock_vector_stores = mocker .Mock ()
543+ mock_vector_stores .data = []
544+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
545+
546+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
547+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
548+
549+ qr = QueryRequest (query = "hello" )
550+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
551+ mock_client , "model-no-shields" , qr , token = "tkn" , provider_id = "test-provider"
552+ )
553+
554+ assert conv_id == "resp-no-shields"
555+ assert summary .llm_response == "Response without shields"
556+
557+ # Verify that no extra_body was added
558+ kwargs = mock_client .responses .create .call_args .kwargs
559+ assert "extra_body" not in kwargs
560+
561+
562+ @pytest .mark .asyncio
563+ async def test_retrieve_response_detects_shield_violation (mocker ):
564+ """Test that shield violations are detected and metrics are incremented."""
565+ mock_client = mocker .Mock ()
566+
567+ # Mock shields.list to return available shields
568+ shield1 = mocker .Mock ()
569+ shield1 .identifier = "safety-shield"
570+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 ])
571+
572+ # Create output with shield violation (refusal)
573+ output_item = mocker .Mock ()
574+ output_item .type = "message"
575+ output_item .role = "assistant"
576+ output_item .content = "I cannot help with that request"
577+ output_item .refusal = "Content violates safety policy"
578+
579+ response_obj = mocker .Mock ()
580+ response_obj .id = "resp-violation"
581+ response_obj .output = [output_item ]
582+ response_obj .usage = None
583+
584+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
585+ mock_vector_stores = mocker .Mock ()
586+ mock_vector_stores .data = []
587+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
588+
589+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
590+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
591+
592+ # Mock the validation error metric
593+ validation_metric = mocker .patch ("metrics.llm_calls_validation_errors_total" )
594+
595+ qr = QueryRequest (query = "dangerous query" )
596+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
597+ mock_client , "model-violation" , qr , token = "tkn" , provider_id = "test-provider"
598+ )
599+
600+ assert conv_id == "resp-violation"
601+ assert summary .llm_response == "I cannot help with that request"
602+
603+ # Verify that the validation error metric was incremented
604+ validation_metric .inc .assert_called_once ()
605+
606+
607+ @pytest .mark .asyncio
608+ async def test_retrieve_response_no_violation_with_shields (mocker ):
609+ """Test that no metric is incremented when there's no shield violation."""
610+ mock_client = mocker .Mock ()
611+
612+ # Mock shields.list to return available shields
613+ shield1 = mocker .Mock ()
614+ shield1 .identifier = "safety-shield"
615+ mock_client .shields .list = mocker .AsyncMock (return_value = [shield1 ])
616+
617+ # Create output without shield violation
618+ output_item = mocker .Mock ()
619+ output_item .type = "message"
620+ output_item .role = "assistant"
621+ output_item .content = "Safe response"
622+ output_item .refusal = None # No violation
623+
624+ response_obj = mocker .Mock ()
625+ response_obj .id = "resp-safe"
626+ response_obj .output = [output_item ]
627+ response_obj .usage = None
628+
629+ mock_client .responses .create = mocker .AsyncMock (return_value = response_obj )
630+ mock_vector_stores = mocker .Mock ()
631+ mock_vector_stores .data = []
632+ mock_client .vector_stores .list = mocker .AsyncMock (return_value = mock_vector_stores )
633+
634+ mocker .patch ("app.endpoints.query_v2.get_system_prompt" , return_value = "PROMPT" )
635+ mocker .patch ("app.endpoints.query_v2.configuration" , mocker .Mock (mcp_servers = []))
636+
637+ # Mock the validation error metric
638+ validation_metric = mocker .patch ("metrics.llm_calls_validation_errors_total" )
639+
640+ qr = QueryRequest (query = "safe query" )
641+ summary , conv_id , _referenced_docs , _token_usage = await retrieve_response (
642+ mock_client , "model-safe" , qr , token = "tkn" , provider_id = "test-provider"
643+ )
644+
645+ assert conv_id == "resp-safe"
646+ assert summary .llm_response == "Safe response"
647+
648+ # Verify that the validation error metric was NOT incremented
649+ validation_metric .inc .assert_not_called ()
0 commit comments