Skip to content

Commit 9e06a51

Browse files
committed
refactor: extract _check_shield_moderation helper, fix integration tests
Move the moderation logic out of infer_endpoint into a private helper to avoid increasing the function's cyclomatic complexity (stays at C 13). Add shield moderation mock to the integration tests so they don't hit the real run_shield_moderation path with non-async client mocks. Signed-off-by: Major Hayden <major@redhat.com>
1 parent b8cb192 commit 9e06a51

2 files changed

Lines changed: 72 additions & 30 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,63 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
319319
background_tasks.add_task(send_splunk_event, event, sourcetype)
320320

321321

322+
async def _check_shield_moderation(
323+
input_text: str,
324+
request_id: str,
325+
background_tasks: BackgroundTasks,
326+
infer_request: RlsapiV1InferRequest,
327+
request: Request,
328+
) -> Optional[RlsapiV1InferResponse]:
329+
"""Run shield moderation and return a refusal response if blocked.
330+
331+
Uses all configured shields in Llama Stack. When no shields are
332+
registered, moderation is a no-op and returns None immediately.
333+
334+
Args:
335+
input_text: The combined user input to moderate.
336+
request_id: Unique identifier for the request.
337+
background_tasks: FastAPI background tasks for async Splunk event sending.
338+
infer_request: The original inference request (for Splunk event context).
339+
request: The FastAPI request object (for Splunk event context).
340+
341+
Returns:
342+
An RlsapiV1InferResponse containing the refusal message if the input
343+
was blocked, or None if moderation passed.
344+
"""
345+
client = AsyncLlamaStackClientHolder().get_client()
346+
moderation_result = await run_shield_moderation(client, input_text)
347+
348+
if moderation_result.decision != "blocked":
349+
return None
350+
351+
logger.info(
352+
"Request %s blocked by shield moderation: %s",
353+
request_id,
354+
moderation_result.message,
355+
)
356+
_queue_splunk_event(
357+
background_tasks,
358+
infer_request,
359+
request,
360+
request_id,
361+
moderation_result.message,
362+
0.0,
363+
"infer_shield_blocked",
364+
)
365+
return RlsapiV1InferResponse(
366+
data=RlsapiV1InferData(
367+
text=moderation_result.message,
368+
request_id=request_id,
369+
tool_calls=None,
370+
tool_results=None,
371+
rag_chunks=None,
372+
referenced_documents=None,
373+
input_tokens=None,
374+
output_tokens=None,
375+
)
376+
)
377+
378+
322379
def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-positional-arguments
323380
background_tasks: BackgroundTasks,
324381
infer_request: RlsapiV1InferRequest,
@@ -451,36 +508,11 @@ async def infer_endpoint( # pylint: disable=R0914
451508

452509
# Run shield moderation on user input before inference.
453510
# Uses all configured shields; no-op when no shields are registered.
454-
client = AsyncLlamaStackClientHolder().get_client()
455-
moderation_result = await run_shield_moderation(client, input_source)
456-
457-
if moderation_result.decision == "blocked":
458-
logger.info(
459-
"Request %s blocked by shield moderation: %s",
460-
request_id,
461-
moderation_result.message,
462-
)
463-
_queue_splunk_event(
464-
background_tasks,
465-
infer_request,
466-
request,
467-
request_id,
468-
moderation_result.message,
469-
0.0,
470-
"infer_shield_blocked",
471-
)
472-
return RlsapiV1InferResponse(
473-
data=RlsapiV1InferData(
474-
text=moderation_result.message,
475-
request_id=request_id,
476-
tool_calls=None,
477-
tool_results=None,
478-
rag_chunks=None,
479-
referenced_documents=None,
480-
input_tokens=None,
481-
output_tokens=None,
482-
)
483-
)
511+
blocked_response = await _check_shield_moderation(
512+
input_source, request_id, background_tasks, infer_request, request
513+
)
514+
if blocked_response is not None:
515+
return blocked_response
484516

485517
start_time = time.monotonic()
486518

tests/integration/endpoints/test_rlsapi_v1_integration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from models.rlsapi.responses import RlsapiV1InferResponse
3333
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
3434
from utils.suid import check_suid
35+
from utils.types import ShieldModerationPassed
3536
from version import __version__
3637

3738
# ==========================================
@@ -80,6 +81,15 @@ def mock_authorization_fixture(mocker: MockerFixture) -> None:
8081
mock_authorization_resolvers(mocker)
8182

8283

84+
@pytest.fixture(autouse=True, name="mock_shield_passed")
85+
def mock_shield_passed_fixture(mocker: MockerFixture) -> None:
86+
"""Mock shield moderation to pass for all integration tests."""
87+
mocker.patch(
88+
"app.endpoints.rlsapi_v1.run_shield_moderation",
89+
new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
90+
)
91+
92+
8393
def _create_mock_response_output(mocker: MockerFixture, text: str) -> Any:
8494
"""Create a mock Responses API output item with assistant message."""
8595
mock_output_item = mocker.Mock()

0 commit comments

Comments
 (0)