Skip to content

Commit 3d09842

Browse files
committed
refactor: ensure consistency across detector APIs and reduce code duplication
- Add SYSTEM_ERROR_LABELS constant to base.py (single source of truth) - Add SERVER_ERROR handling for HTTP 5xx responses in detections_api.py - Add config incomplete checks before accessing rails.config in actions.py - Add detections_api_generate_block_message action for dynamic block messages - Fix env var names in tests (DETECTOR_API_* instead of DETECTIONS_API_*) - Update test assertion for SERVER_ERROR label on HTTP 500 All tests passing (109 tests)
1 parent 22fb78a commit 3d09842

5 files changed

Lines changed: 81 additions & 32 deletions

File tree

nemoguardrails/library/detector_clients/actions.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,15 @@
2222
from typing import Any, Dict, Optional
2323

2424
from nemoguardrails.actions import action
25-
from nemoguardrails.library.detector_clients.base import AggregatedDetectorResult, DetectorResult
25+
from nemoguardrails.library.detector_clients.base import (
26+
SYSTEM_ERROR_LABELS,
27+
AggregatedDetectorResult,
28+
DetectorResult,
29+
)
2630
from nemoguardrails.library.detector_clients.detections_api import DetectionsAPIClient
2731

2832
log = logging.getLogger(__name__)
2933

30-
""" System error labels indicate infrastructure/configuration issues,
31-
not content violations. Detectors with these labels failed to execute
32-
properly and should be treated as unavailable. """
33-
SYSTEM_ERROR_LABELS = {
34-
"ERROR",
35-
"HTTP_ERROR",
36-
"TIMEOUT",
37-
"NOT_FOUND",
38-
"VALIDATION_ERROR",
39-
"INVALID_RESPONSE",
40-
"CONFIG_ERROR",
41-
}
42-
4334

4435
async def _run_detections_api_detector(detector_name: str, detector_config: Any, text: str) -> DetectorResult:
4536
"""
@@ -299,3 +290,43 @@ async def detections_api_check_detector(
299290
log.info(f"Detections API {detector_name}: {'allowed' if result.allowed else 'blocked'} (score={result.score:.3f})")
300291

301292
return result.dict()
293+
294+
295+
@action()
296+
async def detections_api_generate_block_message(context: Optional[Dict] = None, **kwargs) -> str:
297+
"""
298+
Generate detailed block message with detector information.
299+
300+
Creates user-friendly messages explaining why content was blocked.
301+
Prioritizes system errors over content violations.
302+
303+
Args:
304+
context: NeMo context containing input_result from detector checks
305+
**kwargs: Additional arguments (ignored)
306+
307+
Returns:
308+
Human-readable block message string
309+
"""
310+
if context is None:
311+
return "Input blocked due to content policy violation."
312+
313+
input_result = context.get("input_result", {})
314+
315+
# Check for system errors first
316+
unavailable = input_result.get("unavailable_detectors", [])
317+
if unavailable:
318+
return f"Service temporarily unavailable. Detector(s) not reachable: {', '.join(unavailable)}"
319+
320+
# Check for content blocks
321+
blocking = input_result.get("blocking_detectors", [])
322+
if not blocking:
323+
return "Input blocked due to content policy violation."
324+
325+
# Single detector blocked
326+
if len(blocking) == 1:
327+
det = blocking[0]
328+
return f"Input blocked by {det['detector']} detector (score: {det['score']:.2f})"
329+
330+
# Multiple detectors blocked
331+
detector_names = [d["detector"] for d in blocking]
332+
return f"Input blocked by {len(blocking)} detectors: {', '.join(detector_names)}"

nemoguardrails/library/detector_clients/base.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@
3434
_http_session: Optional[aiohttp.ClientSession] = None
3535
_session_lock = asyncio.Lock()
3636

37+
# System error labels indicate infrastructure/configuration issues,
38+
# not content violations. Detectors with these labels failed to execute
39+
# properly and should be treated as unavailable.
40+
SYSTEM_ERROR_LABELS = {
41+
"ERROR",
42+
"HTTP_ERROR",
43+
"TIMEOUT",
44+
"NOT_FOUND",
45+
"VALIDATION_ERROR",
46+
"SERVER_ERROR",
47+
"INVALID_RESPONSE",
48+
"CONFIG_ERROR",
49+
"CONFIG_INCOMPLETE",
50+
}
51+
3752

3853
class DetectorResult(BaseModel):
3954
"""Standardized result from detector execution"""
@@ -132,14 +147,14 @@ def _get_ssl_context(self) -> Union[ssl.SSLContext, bool, None]:
132147
deployment environments (development, staging, production).
133148
134149
Priority order:
135-
1. Custom CA certificate file (if DETECTIONS_API_CA_CERT is set)
136-
2. SSL verification toggle (if DETECTIONS_API_VERIFY_SSL is set)
150+
1. Custom CA certificate file (if DETECTOR_API_CA_CERT is set)
151+
2. SSL verification toggle (if DETECTOR_API_VERIFY_SSL is set)
137152
3. Default system CA certificates
138153
139154
Environment Variables:
140-
DETECTIONS_API_CA_CERT: Path to custom CA certificate file (PEM format)
155+
DETECTOR_API_CA_CERT: Path to custom CA certificate file (PEM format)
141156
Common in Kubernetes/OpenShift with mounted secrets
142-
DETECTIONS_API_VERIFY_SSL: Set to "false" to disable SSL verification
157+
DETECTOR_API_VERIFY_SSL: Set to "false" to disable SSL verification
143158
WARNING: Only for development/testing!
144159
145160
Returns:
@@ -148,17 +163,17 @@ def _get_ssl_context(self) -> Union[ssl.SSLContext, bool, None]:
148163
None: Use default system CA certificates
149164
"""
150165
# Check for custom CA certificate file (Kubernetes secret volume)
151-
ca_cert_file = os.getenv("DETECTIONS_API_CA_CERT")
166+
ca_cert_file = os.getenv("DETECTOR_API_CA_CERT")
152167
if ca_cert_file and os.path.exists(ca_cert_file):
153168
ssl_context = ssl.create_default_context(cafile=ca_cert_file)
154169
log.info(f"Using custom CA certificate from {ca_cert_file}")
155170
return ssl_context
156171

157172
# Option to disable SSL verification (development/testing only)
158-
verify_ssl = os.getenv("DETECTIONS_API_VERIFY_SSL", "true").lower()
173+
verify_ssl = os.getenv("DETECTOR_API_VERIFY_SSL", "true").lower()
159174
if verify_ssl == "false":
160175
log.warning(
161-
"SSL verification disabled via DETECTIONS_API_VERIFY_SSL=false. "
176+
"SSL verification disabled via DETECTOR_API_VERIFY_SSL=false. "
162177
"This is NOT recommended for production environments!"
163178
)
164179
return False
@@ -206,13 +221,13 @@ async def _call_endpoint(
206221
token = self.api_key
207222
else:
208223
# Check for file-based secret (Kubernetes volume mount)
209-
secret_file = os.getenv("DETECTIONS_API_KEY_FILE")
224+
secret_file = os.getenv("DETECTOR_API_KEY_FILE")
210225
if secret_file and os.path.exists(secret_file):
211226
with open(secret_file, "r") as f:
212227
token = f.read().strip()
213228
else:
214229
# Fallback to environment variable
215-
token = os.getenv("DETECTIONS_API_KEY")
230+
token = os.getenv("DETECTOR_API_KEY")
216231
if token:
217232
request_headers["Authorization"] = f"Bearer {token}"
218233

nemoguardrails/library/detector_clients/detections_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def parse_response(self, response: Any, http_status: int) -> DetectorResult:
115115
elif http_status == 422:
116116
label = "VALIDATION_ERROR"
117117
reason = f"Invalid request to {self.detector_name}"
118+
elif http_status >= 500:
119+
label = "SERVER_ERROR"
120+
reason = f"Detections API server error (HTTP {http_status})"
118121
else:
119122
label = "ERROR"
120123
reason = f"HTTP {http_status} error from {self.detector_name}"

tests/test_detector_clients_base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ async def test_post_request_with_env_api_key(self):
486486
mock_session.post = Mock(return_value=mock_post_cm)
487487

488488
with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session):
489-
with patch.dict("os.environ", {"DETECTIONS_API_KEY": "env-key-456"}):
489+
with patch.dict("os.environ", {"DETECTOR_API_KEY": "env-key-456"}):
490490
await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30)
491491

492492
# Verify env var key used
@@ -508,8 +508,8 @@ def test_ssl_context_with_custom_ca_cert(self, tmp_path):
508508
client = ConcreteDetectorClient(mock_config, "test-detector")
509509

510510
# Mock ssl.create_default_context to avoid needing valid cert
511-
with patch.dict("os.environ", {"DETECTIONS_API_CA_CERT": str(ca_cert_file)}):
512-
with patch("ssl.create_default_context") as mock_ssl:
511+
with patch.dict("os.environ", {"DETECTOR_API_CA_CERT": str(ca_cert_file)}):
512+
with patch("nemoguardrails.library.detector_clients.base.ssl.create_default_context") as mock_ssl:
513513
mock_ssl_context = Mock(spec=ssl.SSLContext)
514514
mock_ssl.return_value = mock_ssl_context
515515

@@ -525,7 +525,7 @@ def test_ssl_context_with_nonexistent_ca_cert_file(self):
525525

526526
client = ConcreteDetectorClient(mock_config, "test-detector")
527527

528-
with patch.dict("os.environ", {"DETECTIONS_API_CA_CERT": "/nonexistent/path/ca-cert.pem"}):
528+
with patch.dict("os.environ", {"DETECTOR_API_CA_CERT": "/nonexistent/path/ca-cert.pem"}):
529529
ssl_context = client._get_ssl_context()
530530

531531
# Should fall through to default behavior (None)
@@ -537,7 +537,7 @@ def test_ssl_verification_disabled(self):
537537

538538
client = ConcreteDetectorClient(mock_config, "test-detector")
539539

540-
with patch.dict("os.environ", {"DETECTIONS_API_VERIFY_SSL": "false"}):
540+
with patch.dict("os.environ", {"DETECTOR_API_VERIFY_SSL": "false"}):
541541
ssl_context = client._get_ssl_context()
542542

543543
# Should return False to disable verification
@@ -584,7 +584,7 @@ async def test_api_key_from_file(self, tmp_path):
584584
mock_session.post = Mock(return_value=mock_post_cm)
585585

586586
with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session):
587-
with patch.dict("os.environ", {"DETECTIONS_API_KEY_FILE": str(api_key_file)}):
587+
with patch.dict("os.environ", {"DETECTOR_API_KEY_FILE": str(api_key_file)}):
588588
await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30)
589589

590590
# Verify Authorization header used file-based key
@@ -612,7 +612,7 @@ async def test_api_key_file_not_exists(self):
612612
with patch("nemoguardrails.library.detector_clients.base._http_session", mock_session):
613613
with patch.dict(
614614
"os.environ",
615-
{"DETECTIONS_API_KEY_FILE": "/nonexistent/api-key", "DETECTIONS_API_KEY": "env-var-key-123"},
615+
{"DETECTOR_API_KEY_FILE": "/nonexistent/api-key", "DETECTOR_API_KEY": "env-var-key-123"},
616616
):
617617
await client._call_endpoint(endpoint="http://test.com/api", payload={"text": "test"}, timeout=30)
618618

tests/test_detector_clients_detections_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_parse_response_http_422(self):
335335
assert result.metadata["http_status"] == 422
336336

337337
def test_parse_response_http_500(self):
338-
"""Test HTTP 500 returns ERROR"""
338+
"""Test HTTP 500 returns SERVER_ERROR"""
339339
mock_config = Mock()
340340
mock_config.inference_endpoint = "http://test.com"
341341
mock_config.detector_id = "test-id"
@@ -345,7 +345,7 @@ def test_parse_response_http_500(self):
345345
result = client.parse_response({}, 500)
346346

347347
assert result.allowed is False
348-
assert result.label == "ERROR"
348+
assert result.label == "SERVER_ERROR"
349349
assert "HTTP 500" in result.reason
350350
assert result.metadata["http_status"] == 500
351351

0 commit comments

Comments
 (0)