Skip to content

Commit b8cb192

Browse files
committed
feat: add shield moderation to rlsapi_v1 /infer endpoint
Wire run_shield_moderation() into the rlsapi_v1 /infer endpoint so CLA requests go through the same safety checks as responses, query, and streaming_query. When a shield blocks input, the refusal message is returned as normal response text and the LLM call is skipped entirely. No-op when no shields are configured. RSPEED-2809 Signed-off-by: Major Hayden <major@redhat.com>
1 parent 873f605 commit b8cb192

2 files changed

Lines changed: 230 additions & 0 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
extract_token_usage,
4949
get_mcp_tools,
5050
)
51+
from utils.shields import run_shield_moderation
5152
from utils.suid import get_suid
5253

5354
logger = get_logger(__name__)
@@ -448,6 +449,39 @@ async def infer_endpoint( # pylint: disable=R0914
448449
"Request %s: Combined input source length: %d", request_id, len(input_source)
449450
)
450451

452+
# Run shield moderation on user input before inference.
453+
# Uses all configured shields; no-op when no shields are registered.
454+
client = AsyncLlamaStackClientHolder().get_client()
455+
moderation_result = await run_shield_moderation(client, input_source)
456+
457+
if moderation_result.decision == "blocked":
458+
logger.info(
459+
"Request %s blocked by shield moderation: %s",
460+
request_id,
461+
moderation_result.message,
462+
)
463+
_queue_splunk_event(
464+
background_tasks,
465+
infer_request,
466+
request,
467+
request_id,
468+
moderation_result.message,
469+
0.0,
470+
"infer_shield_blocked",
471+
)
472+
return RlsapiV1InferResponse(
473+
data=RlsapiV1InferData(
474+
text=moderation_result.message,
475+
request_id=request_id,
476+
tool_calls=None,
477+
tool_results=None,
478+
rag_chunks=None,
479+
referenced_documents=None,
480+
input_tokens=None,
481+
output_tokens=None,
482+
)
483+
)
484+
451485
start_time = time.monotonic()
452486

453487
# Check if verbose metadata should be returned

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytest
1414
from fastapi import HTTPException, status
15+
from llama_stack_api import OpenAIResponseMessage
1516
from llama_stack_client import APIConnectionError
1617
from pydantic import ValidationError
1718
from pytest_mock import MockerFixture
@@ -41,6 +42,7 @@
4142
from models.rlsapi.responses import RlsapiV1InferResponse
4243
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
4344
from utils.suid import check_suid
45+
from utils.types import ShieldModerationBlocked, ShieldModerationPassed
4446

4547
MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")
4648

@@ -125,6 +127,19 @@ def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
125127
mock_authorization_resolvers(mocker)
126128

127129

130+
@pytest.fixture(autouse=True, name="mock_shield_passed")
131+
def mock_shield_passed_fixture(mocker: MockerFixture) -> None:
132+
"""Mock shield moderation to pass for all endpoint tests by default.
133+
134+
Individual tests can override this by patching run_shield_moderation
135+
with a different return value.
136+
"""
137+
mocker.patch(
138+
"app.endpoints.rlsapi_v1.run_shield_moderation",
139+
new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
140+
)
141+
142+
128143
@pytest.fixture(name="mock_api_connection_error")
129144
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
130145
"""Mock responses.create() to raise APIConnectionError."""
@@ -850,6 +865,187 @@ async def test_infer_queues_splunk_error_event_on_failure(
850865
assert call_args[0][2] == "infer_error"
851866

852867

868+
# --- Test shield moderation ---
869+
870+
871+
def _create_blocked_moderation_result() -> ShieldModerationBlocked:
872+
"""Create a ShieldModerationBlocked result for testing."""
873+
return ShieldModerationBlocked(
874+
message="I can't answer that. Can I help with something else?",
875+
moderation_id="modr-test-123",
876+
refusal_response=OpenAIResponseMessage(
877+
role="assistant",
878+
content="I can't answer that. Can I help with something else?",
879+
),
880+
)
881+
882+
883+
@pytest.mark.asyncio
884+
async def test_infer_shield_blocked_returns_refusal(
885+
mocker: MockerFixture,
886+
mock_configuration: AppConfig,
887+
mock_llm_response: None,
888+
mock_auth_resolvers: None,
889+
mock_request_factory: Callable[..., Any],
890+
mock_background_tasks: Any,
891+
) -> None:
892+
"""Test that blocked shield moderation returns refusal text without calling LLM."""
893+
blocked = _create_blocked_moderation_result()
894+
mocker.patch(
895+
"app.endpoints.rlsapi_v1.run_shield_moderation",
896+
new=mocker.AsyncMock(return_value=blocked),
897+
)
898+
899+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
900+
mock_request = mock_request_factory()
901+
902+
response = await infer_endpoint(
903+
infer_request=infer_request,
904+
request=mock_request,
905+
background_tasks=mock_background_tasks,
906+
auth=MOCK_AUTH,
907+
)
908+
909+
assert isinstance(response, RlsapiV1InferResponse)
910+
assert response.data.text == blocked.message
911+
assert response.data.request_id is not None
912+
assert check_suid(response.data.request_id)
913+
# Blocked response must not include verbose metadata
914+
assert response.data.tool_calls is None
915+
assert response.data.tool_results is None
916+
assert response.data.rag_chunks is None
917+
assert response.data.referenced_documents is None
918+
assert response.data.input_tokens is None
919+
assert response.data.output_tokens is None
920+
921+
922+
@pytest.mark.asyncio
923+
async def test_infer_shield_blocked_skips_llm_call(
924+
mocker: MockerFixture,
925+
mock_configuration: AppConfig,
926+
mock_llm_response: None,
927+
mock_auth_resolvers: None,
928+
mock_request_factory: Callable[..., Any],
929+
mock_background_tasks: Any,
930+
) -> None:
931+
"""Test that blocked shield moderation prevents any LLM call."""
932+
blocked = _create_blocked_moderation_result()
933+
mocker.patch(
934+
"app.endpoints.rlsapi_v1.run_shield_moderation",
935+
new=mocker.AsyncMock(return_value=blocked),
936+
)
937+
mock_retrieve = mocker.patch(
938+
"app.endpoints.rlsapi_v1.retrieve_simple_response",
939+
new=mocker.AsyncMock(),
940+
)
941+
942+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
943+
944+
await infer_endpoint(
945+
infer_request=infer_request,
946+
request=mock_request_factory(),
947+
background_tasks=mock_background_tasks,
948+
auth=MOCK_AUTH,
949+
)
950+
951+
mock_retrieve.assert_not_called()
952+
953+
954+
@pytest.mark.asyncio
955+
async def test_infer_shield_blocked_queues_splunk_event(
956+
mocker: MockerFixture,
957+
mock_configuration: AppConfig,
958+
mock_llm_response: None,
959+
mock_auth_resolvers: None,
960+
mock_request_factory: Callable[..., Any],
961+
mock_background_tasks: Any,
962+
) -> None:
963+
"""Test that blocked shield moderation queues a Splunk event with correct sourcetype."""
964+
blocked = _create_blocked_moderation_result()
965+
mocker.patch(
966+
"app.endpoints.rlsapi_v1.run_shield_moderation",
967+
new=mocker.AsyncMock(return_value=blocked),
968+
)
969+
970+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
971+
972+
await infer_endpoint(
973+
infer_request=infer_request,
974+
request=mock_request_factory(),
975+
background_tasks=mock_background_tasks,
976+
auth=MOCK_AUTH,
977+
)
978+
979+
mock_background_tasks.add_task.assert_called_once()
980+
call_args = mock_background_tasks.add_task.call_args
981+
assert call_args[0][2] == "infer_shield_blocked"
982+
983+
984+
@pytest.mark.asyncio
985+
async def test_infer_shield_passed_proceeds_to_llm(
986+
mocker: MockerFixture,
987+
mock_configuration: AppConfig,
988+
mock_llm_response: None,
989+
mock_auth_resolvers: None,
990+
mock_request_factory: Callable[..., Any],
991+
mock_background_tasks: Any,
992+
) -> None:
993+
"""Test that passed shield moderation proceeds to normal LLM inference."""
994+
# autouse fixture already patches with ShieldModerationPassed
995+
infer_request = RlsapiV1InferRequest(question="How do I list files?")
996+
997+
response = await infer_endpoint(
998+
infer_request=infer_request,
999+
request=mock_request_factory(),
1000+
background_tasks=mock_background_tasks,
1001+
auth=MOCK_AUTH,
1002+
)
1003+
1004+
assert response.data.text == "This is a test LLM response."
1005+
# Splunk event should use normal sourcetype
1006+
call_args = mock_background_tasks.add_task.call_args
1007+
assert call_args[0][2] == "infer_with_llm"
1008+
1009+
1010+
@pytest.mark.asyncio
1011+
async def test_infer_shield_moderation_receives_combined_input(
1012+
mocker: MockerFixture,
1013+
mock_configuration: AppConfig,
1014+
mock_llm_response: None,
1015+
mock_auth_resolvers: None,
1016+
mock_request_factory: Callable[..., Any],
1017+
mock_background_tasks: Any,
1018+
) -> None:
1019+
"""Test that shield moderation receives the full combined input source."""
1020+
mock_moderation = mocker.AsyncMock(return_value=ShieldModerationPassed())
1021+
mocker.patch(
1022+
"app.endpoints.rlsapi_v1.run_shield_moderation",
1023+
new=mock_moderation,
1024+
)
1025+
1026+
infer_request = RlsapiV1InferRequest(
1027+
question="Why did this fail?",
1028+
context=RlsapiV1Context(
1029+
stdin="piped input",
1030+
terminal=RlsapiV1Terminal(output="permission denied"),
1031+
),
1032+
)
1033+
1034+
await infer_endpoint(
1035+
infer_request=infer_request,
1036+
request=mock_request_factory(),
1037+
background_tasks=mock_background_tasks,
1038+
auth=MOCK_AUTH,
1039+
)
1040+
1041+
mock_moderation.assert_called_once()
1042+
# The input_text argument should be the combined input source
1043+
input_text = mock_moderation.call_args[0][1]
1044+
assert "Why did this fail?" in input_text
1045+
assert "piped input" in input_text
1046+
assert "permission denied" in input_text
1047+
1048+
8531049
@pytest.mark.asyncio
8541050
async def test_infer_splunk_event_includes_rh_identity_context(
8551051
mocker: MockerFixture,

0 commit comments

Comments
 (0)