Skip to content

Commit b7c6d08

Browse files
authored
Merge pull request #1462 from major/rlsapi-v1-shield-moderation
RSPEED-2809: Add shield moderation to rlsapi_v1 /infer endpoint
2 parents 4961334 + a63664e commit b7c6d08

3 files changed

Lines changed: 278 additions & 3 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
extract_token_usage,
4949
get_mcp_tools,
5050
)
51+
from utils.shields import run_shield_moderation
5152
from utils.suid import get_suid
5253

5354
logger = get_logger(__name__)
@@ -318,6 +319,63 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
318319
background_tasks.add_task(send_splunk_event, event, sourcetype)
319320

320321

322+
async def _check_shield_moderation(
323+
input_text: str,
324+
request_id: str,
325+
background_tasks: BackgroundTasks,
326+
infer_request: RlsapiV1InferRequest,
327+
request: Request,
328+
) -> Optional[RlsapiV1InferResponse]:
329+
"""Run shield moderation and return a refusal response if blocked.
330+
331+
Uses all configured shields in Llama Stack. When no shields are
332+
registered, moderation is a no-op and returns None immediately.
333+
334+
Args:
335+
input_text: The combined user input to moderate.
336+
request_id: Unique identifier for the request.
337+
background_tasks: FastAPI background tasks for async Splunk event sending.
338+
infer_request: The original inference request (for Splunk event context).
339+
request: The FastAPI request object (for Splunk event context).
340+
341+
Returns:
342+
An RlsapiV1InferResponse containing the refusal message if the input
343+
was blocked, or None if moderation passed.
344+
"""
345+
client = AsyncLlamaStackClientHolder().get_client()
346+
moderation_result = await run_shield_moderation(client, input_text)
347+
348+
if moderation_result.decision != "blocked":
349+
return None
350+
351+
logger.info(
352+
"Request %s blocked by shield moderation: %s",
353+
request_id,
354+
moderation_result.message,
355+
)
356+
_queue_splunk_event(
357+
background_tasks,
358+
infer_request,
359+
request,
360+
request_id,
361+
moderation_result.message,
362+
0.0,
363+
"infer_shield_blocked",
364+
)
365+
return RlsapiV1InferResponse(
366+
data=RlsapiV1InferData(
367+
text=moderation_result.message,
368+
request_id=request_id,
369+
tool_calls=None,
370+
tool_results=None,
371+
rag_chunks=None,
372+
referenced_documents=None,
373+
input_tokens=None,
374+
output_tokens=None,
375+
)
376+
)
377+
378+
321379
def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-positional-arguments
322380
background_tasks: BackgroundTasks,
323381
infer_request: RlsapiV1InferRequest,
@@ -441,13 +499,24 @@ async def infer_endpoint( # pylint: disable=R0914
441499
logger.info("Processing rlsapi v1 /infer request %s", request_id)
442500

443501
input_source = infer_request.get_input_source()
444-
model_id = await _get_default_model_id()
445-
provider, model = extract_provider_and_model_from_model_id(model_id)
446-
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
447502
logger.debug(
448503
"Request %s: Combined input source length: %d", request_id, len(input_source)
449504
)
450505

506+
# Run shield moderation on user input before inference.
507+
# Uses all configured shields; no-op when no shields are registered.
508+
# Runs before model/tool discovery so blocked requests short-circuit
509+
# without incurring external I/O.
510+
blocked_response = await _check_shield_moderation(
511+
input_source, request_id, background_tasks, infer_request, request
512+
)
513+
if blocked_response is not None:
514+
return blocked_response
515+
516+
model_id = await _get_default_model_id()
517+
provider, model = extract_provider_and_model_from_model_id(model_id)
518+
mcp_tools: list[Any] = await get_mcp_tools(request_headers=request.headers)
519+
451520
start_time = time.monotonic()
452521

453522
# Check if verbose metadata should be returned

tests/integration/endpoints/test_rlsapi_v1_integration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from models.rlsapi.responses import RlsapiV1InferResponse
3333
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
3434
from utils.suid import check_suid
35+
from utils.types import ShieldModerationPassed
3536
from version import __version__
3637

3738
# ==========================================
@@ -80,6 +81,15 @@ def mock_authorization_fixture(mocker: MockerFixture) -> None:
8081
mock_authorization_resolvers(mocker)
8182

8283

84+
@pytest.fixture(autouse=True, name="mock_shield_passed")
85+
def mock_shield_passed_fixture(mocker: MockerFixture) -> None:
86+
"""Mock shield moderation to pass for all integration tests."""
87+
mocker.patch(
88+
"app.endpoints.rlsapi_v1.run_shield_moderation",
89+
new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
90+
)
91+
92+
8393
def _create_mock_response_output(mocker: MockerFixture, text: str) -> Any:
8494
"""Create a mock Responses API output item with assistant message."""
8595
mock_output_item = mocker.Mock()

tests/unit/app/endpoints/test_rlsapi_v1.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytest
1414
from fastapi import HTTPException, status
15+
from llama_stack_api import OpenAIResponseMessage
1516
from llama_stack_client import APIConnectionError
1617
from pydantic import ValidationError
1718
from pytest_mock import MockerFixture
@@ -41,6 +42,7 @@
4142
from models.rlsapi.responses import RlsapiV1InferResponse
4243
from tests.unit.utils.auth_helpers import mock_authorization_resolvers
4344
from utils.suid import check_suid
45+
from utils.types import ShieldModerationBlocked, ShieldModerationPassed
4446

4547
MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token")
4648

@@ -125,6 +127,19 @@ def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None:
125127
mock_authorization_resolvers(mocker)
126128

127129

130+
@pytest.fixture(autouse=True, name="mock_shield_passed")
131+
def mock_shield_passed_fixture(mocker: MockerFixture) -> None:
132+
"""Mock shield moderation to pass for all endpoint tests by default.
133+
134+
Individual tests can override this by patching run_shield_moderation
135+
with a different return value.
136+
"""
137+
mocker.patch(
138+
"app.endpoints.rlsapi_v1.run_shield_moderation",
139+
new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
140+
)
141+
142+
128143
@pytest.fixture(name="mock_api_connection_error")
129144
def mock_api_connection_error_fixture(mocker: MockerFixture) -> None:
130145
"""Mock responses.create() to raise APIConnectionError."""
@@ -850,6 +865,187 @@ async def test_infer_queues_splunk_error_event_on_failure(
850865
assert call_args[0][2] == "infer_error"
851866

852867

868+
# --- Test shield moderation ---
869+
870+
871+
def _create_blocked_moderation_result() -> ShieldModerationBlocked:
872+
"""Create a ShieldModerationBlocked result for testing."""
873+
return ShieldModerationBlocked(
874+
message="I can't answer that. Can I help with something else?",
875+
moderation_id="modr-test-123",
876+
refusal_response=OpenAIResponseMessage(
877+
role="assistant",
878+
content="I can't answer that. Can I help with something else?",
879+
),
880+
)
881+
882+
883+
@pytest.mark.asyncio
884+
async def test_infer_shield_blocked_returns_refusal(
885+
mocker: MockerFixture,
886+
mock_configuration: AppConfig,
887+
mock_llm_response: None,
888+
mock_auth_resolvers: None,
889+
mock_request_factory: Callable[..., Any],
890+
mock_background_tasks: Any,
891+
) -> None:
892+
"""Test that blocked shield moderation returns refusal text without calling LLM."""
893+
blocked = _create_blocked_moderation_result()
894+
mocker.patch(
895+
"app.endpoints.rlsapi_v1.run_shield_moderation",
896+
new=mocker.AsyncMock(return_value=blocked),
897+
)
898+
899+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
900+
mock_request = mock_request_factory()
901+
902+
response = await infer_endpoint(
903+
infer_request=infer_request,
904+
request=mock_request,
905+
background_tasks=mock_background_tasks,
906+
auth=MOCK_AUTH,
907+
)
908+
909+
assert isinstance(response, RlsapiV1InferResponse)
910+
assert response.data.text == blocked.message
911+
assert response.data.request_id is not None
912+
assert check_suid(response.data.request_id)
913+
# Blocked response must not include verbose metadata
914+
assert response.data.tool_calls is None
915+
assert response.data.tool_results is None
916+
assert response.data.rag_chunks is None
917+
assert response.data.referenced_documents is None
918+
assert response.data.input_tokens is None
919+
assert response.data.output_tokens is None
920+
921+
922+
@pytest.mark.asyncio
923+
async def test_infer_shield_blocked_skips_llm_call(
924+
mocker: MockerFixture,
925+
mock_configuration: AppConfig,
926+
mock_llm_response: None,
927+
mock_auth_resolvers: None,
928+
mock_request_factory: Callable[..., Any],
929+
mock_background_tasks: Any,
930+
) -> None:
931+
"""Test that blocked shield moderation prevents any LLM call."""
932+
blocked = _create_blocked_moderation_result()
933+
mocker.patch(
934+
"app.endpoints.rlsapi_v1.run_shield_moderation",
935+
new=mocker.AsyncMock(return_value=blocked),
936+
)
937+
mock_retrieve = mocker.patch(
938+
"app.endpoints.rlsapi_v1.retrieve_simple_response",
939+
new=mocker.AsyncMock(),
940+
)
941+
942+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
943+
944+
await infer_endpoint(
945+
infer_request=infer_request,
946+
request=mock_request_factory(),
947+
background_tasks=mock_background_tasks,
948+
auth=MOCK_AUTH,
949+
)
950+
951+
mock_retrieve.assert_not_called()
952+
953+
954+
@pytest.mark.asyncio
955+
async def test_infer_shield_blocked_queues_splunk_event(
956+
mocker: MockerFixture,
957+
mock_configuration: AppConfig,
958+
mock_llm_response: None,
959+
mock_auth_resolvers: None,
960+
mock_request_factory: Callable[..., Any],
961+
mock_background_tasks: Any,
962+
) -> None:
963+
"""Test that blocked shield moderation queues a Splunk event with correct sourcetype."""
964+
blocked = _create_blocked_moderation_result()
965+
mocker.patch(
966+
"app.endpoints.rlsapi_v1.run_shield_moderation",
967+
new=mocker.AsyncMock(return_value=blocked),
968+
)
969+
970+
infer_request = RlsapiV1InferRequest(question="How do I hack a server?")
971+
972+
await infer_endpoint(
973+
infer_request=infer_request,
974+
request=mock_request_factory(),
975+
background_tasks=mock_background_tasks,
976+
auth=MOCK_AUTH,
977+
)
978+
979+
mock_background_tasks.add_task.assert_called_once()
980+
call_args = mock_background_tasks.add_task.call_args
981+
assert call_args[0][2] == "infer_shield_blocked"
982+
983+
984+
@pytest.mark.asyncio
985+
async def test_infer_shield_passed_proceeds_to_llm(
986+
mocker: MockerFixture,
987+
mock_configuration: AppConfig,
988+
mock_llm_response: None,
989+
mock_auth_resolvers: None,
990+
mock_request_factory: Callable[..., Any],
991+
mock_background_tasks: Any,
992+
) -> None:
993+
"""Test that passed shield moderation proceeds to normal LLM inference."""
994+
# autouse fixture already patches with ShieldModerationPassed
995+
infer_request = RlsapiV1InferRequest(question="How do I list files?")
996+
997+
response = await infer_endpoint(
998+
infer_request=infer_request,
999+
request=mock_request_factory(),
1000+
background_tasks=mock_background_tasks,
1001+
auth=MOCK_AUTH,
1002+
)
1003+
1004+
assert response.data.text == "This is a test LLM response."
1005+
# Splunk event should use normal sourcetype
1006+
call_args = mock_background_tasks.add_task.call_args
1007+
assert call_args[0][2] == "infer_with_llm"
1008+
1009+
1010+
@pytest.mark.asyncio
1011+
async def test_infer_shield_moderation_receives_combined_input(
1012+
mocker: MockerFixture,
1013+
mock_configuration: AppConfig,
1014+
mock_llm_response: None,
1015+
mock_auth_resolvers: None,
1016+
mock_request_factory: Callable[..., Any],
1017+
mock_background_tasks: Any,
1018+
) -> None:
1019+
"""Test that shield moderation receives the full combined input source."""
1020+
mock_moderation = mocker.AsyncMock(return_value=ShieldModerationPassed())
1021+
mocker.patch(
1022+
"app.endpoints.rlsapi_v1.run_shield_moderation",
1023+
new=mock_moderation,
1024+
)
1025+
1026+
infer_request = RlsapiV1InferRequest(
1027+
question="Why did this fail?",
1028+
context=RlsapiV1Context(
1029+
stdin="piped input",
1030+
terminal=RlsapiV1Terminal(output="permission denied"),
1031+
),
1032+
)
1033+
1034+
await infer_endpoint(
1035+
infer_request=infer_request,
1036+
request=mock_request_factory(),
1037+
background_tasks=mock_background_tasks,
1038+
auth=MOCK_AUTH,
1039+
)
1040+
1041+
mock_moderation.assert_called_once()
1042+
# The input_text argument should be the combined input source
1043+
input_text = mock_moderation.call_args[0][1]
1044+
assert "Why did this fail?" in input_text
1045+
assert "piped input" in input_text
1046+
assert "permission denied" in input_text
1047+
1048+
8531049
@pytest.mark.asyncio
8541050
async def test_infer_splunk_event_includes_rh_identity_context(
8551051
mocker: MockerFixture,

0 commit comments

Comments
 (0)