|
32 | 32 | from authorization.middleware import authorize |
33 | 33 | from configuration import configuration |
34 | 34 | from constants import MEDIA_TYPE_JSON |
35 | | -import metrics |
36 | 35 | from models.config import Action |
37 | 36 | from models.context import ResponseGeneratorContext |
38 | 37 | from models.requests import QueryRequest |
|
42 | 41 | get_system_prompt, |
43 | 42 | ) |
44 | 43 | from utils.mcp_headers import mcp_headers_dependency |
| 44 | +from utils.shields import detect_shield_violations, get_available_shields |
45 | 45 | from utils.token_counter import TokenCounter |
46 | 46 | from utils.transcripts import store_transcript |
47 | 47 | from utils.types import TurnSummary, ToolCallSummary |
@@ -247,14 +247,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat |
247 | 247 |
|
248 | 248 | # Check for shield violations in the completed response |
249 | 249 | if latest_response_object: |
250 | | - for output_item in getattr(latest_response_object, "output", []): |
251 | | - item_type = getattr(output_item, "type", None) |
252 | | - if item_type == "message": |
253 | | - refusal = getattr(output_item, "refusal", None) |
254 | | - if refusal: |
255 | | - # Metric for LLM validation errors (shield violations) |
256 | | - metrics.llm_calls_validation_errors_total.inc() |
257 | | - logger.warning("Shield violation detected: %s", refusal) |
| 250 | + detect_shield_violations( |
| 251 | + getattr(latest_response_object, "output", []) |
| 252 | + ) |
258 | 253 |
|
259 | 254 | if not emitted_turn_complete: |
260 | 255 | final_message = summary.llm_response or "".join(text_parts) |
@@ -379,11 +374,7 @@ async def retrieve_response( |
379 | 374 | and the conversation ID. |
380 | 375 | """ |
381 | 376 | # List available shields for Responses API |
382 | | - available_shields = [shield.identifier for shield in await client.shields.list()] |
383 | | - if not available_shields: |
384 | | - logger.info("No available shields. Disabling safety") |
385 | | - else: |
386 | | - logger.info("Available shields: %s", available_shields) |
| 377 | + available_shields = await get_available_shields(client) |
387 | 378 |
|
388 | 379 | # use system prompt from request or default one |
389 | 380 | system_prompt = get_system_prompt(query_request, configuration) |
|
0 commit comments