|
12 | 12 |
|
13 | 13 | import pytest |
14 | 14 | from fastapi import HTTPException, status |
| 15 | +from llama_stack_api import OpenAIResponseMessage |
15 | 16 | from llama_stack_client import APIConnectionError |
16 | 17 | from pydantic import ValidationError |
17 | 18 | from pytest_mock import MockerFixture |
|
41 | 42 | from models.rlsapi.responses import RlsapiV1InferResponse |
42 | 43 | from tests.unit.utils.auth_helpers import mock_authorization_resolvers |
43 | 44 | from utils.suid import check_suid |
| 45 | +from utils.types import ShieldModerationBlocked, ShieldModerationPassed |
44 | 46 |
|
45 | 47 | MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token") |
46 | 48 |
|
@@ -125,6 +127,19 @@ def mock_auth_resolvers_fixture(mocker: MockerFixture) -> None: |
125 | 127 | mock_authorization_resolvers(mocker) |
126 | 128 |
|
127 | 129 |
|
| 130 | +@pytest.fixture(autouse=True, name="mock_shield_passed") |
| 131 | +def mock_shield_passed_fixture(mocker: MockerFixture) -> None: |
| 132 | + """Mock shield moderation to pass for all endpoint tests by default. |
| 133 | +
|
| 134 | + Individual tests can override this by patching run_shield_moderation |
| 135 | + with a different return value. |
| 136 | + """ |
| 137 | + mocker.patch( |
| 138 | + "app.endpoints.rlsapi_v1.run_shield_moderation", |
| 139 | + new=mocker.AsyncMock(return_value=ShieldModerationPassed()), |
| 140 | + ) |
| 141 | + |
| 142 | + |
128 | 143 | @pytest.fixture(name="mock_api_connection_error") |
129 | 144 | def mock_api_connection_error_fixture(mocker: MockerFixture) -> None: |
130 | 145 | """Mock responses.create() to raise APIConnectionError.""" |
@@ -850,6 +865,187 @@ async def test_infer_queues_splunk_error_event_on_failure( |
850 | 865 | assert call_args[0][2] == "infer_error" |
851 | 866 |
|
852 | 867 |
|
| 868 | +# --- Test shield moderation --- |
| 869 | + |
| 870 | + |
| 871 | +def _create_blocked_moderation_result() -> ShieldModerationBlocked: |
| 872 | + """Create a ShieldModerationBlocked result for testing.""" |
| 873 | + return ShieldModerationBlocked( |
| 874 | + message="I can't answer that. Can I help with something else?", |
| 875 | + moderation_id="modr-test-123", |
| 876 | + refusal_response=OpenAIResponseMessage( |
| 877 | + role="assistant", |
| 878 | + content="I can't answer that. Can I help with something else?", |
| 879 | + ), |
| 880 | + ) |
| 881 | + |
| 882 | + |
| 883 | +@pytest.mark.asyncio |
| 884 | +async def test_infer_shield_blocked_returns_refusal( |
| 885 | + mocker: MockerFixture, |
| 886 | + mock_configuration: AppConfig, |
| 887 | + mock_llm_response: None, |
| 888 | + mock_auth_resolvers: None, |
| 889 | + mock_request_factory: Callable[..., Any], |
| 890 | + mock_background_tasks: Any, |
| 891 | +) -> None: |
| 892 | + """Test that blocked shield moderation returns refusal text without calling LLM.""" |
| 893 | + blocked = _create_blocked_moderation_result() |
| 894 | + mocker.patch( |
| 895 | + "app.endpoints.rlsapi_v1.run_shield_moderation", |
| 896 | + new=mocker.AsyncMock(return_value=blocked), |
| 897 | + ) |
| 898 | + |
| 899 | + infer_request = RlsapiV1InferRequest(question="How do I hack a server?") |
| 900 | + mock_request = mock_request_factory() |
| 901 | + |
| 902 | + response = await infer_endpoint( |
| 903 | + infer_request=infer_request, |
| 904 | + request=mock_request, |
| 905 | + background_tasks=mock_background_tasks, |
| 906 | + auth=MOCK_AUTH, |
| 907 | + ) |
| 908 | + |
| 909 | + assert isinstance(response, RlsapiV1InferResponse) |
| 910 | + assert response.data.text == blocked.message |
| 911 | + assert response.data.request_id is not None |
| 912 | + assert check_suid(response.data.request_id) |
| 913 | + # Blocked response must not include verbose metadata |
| 914 | + assert response.data.tool_calls is None |
| 915 | + assert response.data.tool_results is None |
| 916 | + assert response.data.rag_chunks is None |
| 917 | + assert response.data.referenced_documents is None |
| 918 | + assert response.data.input_tokens is None |
| 919 | + assert response.data.output_tokens is None |
| 920 | + |
| 921 | + |
| 922 | +@pytest.mark.asyncio |
| 923 | +async def test_infer_shield_blocked_skips_llm_call( |
| 924 | + mocker: MockerFixture, |
| 925 | + mock_configuration: AppConfig, |
| 926 | + mock_llm_response: None, |
| 927 | + mock_auth_resolvers: None, |
| 928 | + mock_request_factory: Callable[..., Any], |
| 929 | + mock_background_tasks: Any, |
| 930 | +) -> None: |
| 931 | + """Test that blocked shield moderation prevents any LLM call.""" |
| 932 | + blocked = _create_blocked_moderation_result() |
| 933 | + mocker.patch( |
| 934 | + "app.endpoints.rlsapi_v1.run_shield_moderation", |
| 935 | + new=mocker.AsyncMock(return_value=blocked), |
| 936 | + ) |
| 937 | + mock_retrieve = mocker.patch( |
| 938 | + "app.endpoints.rlsapi_v1.retrieve_simple_response", |
| 939 | + new=mocker.AsyncMock(), |
| 940 | + ) |
| 941 | + |
| 942 | + infer_request = RlsapiV1InferRequest(question="How do I hack a server?") |
| 943 | + |
| 944 | + await infer_endpoint( |
| 945 | + infer_request=infer_request, |
| 946 | + request=mock_request_factory(), |
| 947 | + background_tasks=mock_background_tasks, |
| 948 | + auth=MOCK_AUTH, |
| 949 | + ) |
| 950 | + |
| 951 | + mock_retrieve.assert_not_called() |
| 952 | + |
| 953 | + |
| 954 | +@pytest.mark.asyncio |
| 955 | +async def test_infer_shield_blocked_queues_splunk_event( |
| 956 | + mocker: MockerFixture, |
| 957 | + mock_configuration: AppConfig, |
| 958 | + mock_llm_response: None, |
| 959 | + mock_auth_resolvers: None, |
| 960 | + mock_request_factory: Callable[..., Any], |
| 961 | + mock_background_tasks: Any, |
| 962 | +) -> None: |
| 963 | + """Test that blocked shield moderation queues a Splunk event with correct sourcetype.""" |
| 964 | + blocked = _create_blocked_moderation_result() |
| 965 | + mocker.patch( |
| 966 | + "app.endpoints.rlsapi_v1.run_shield_moderation", |
| 967 | + new=mocker.AsyncMock(return_value=blocked), |
| 968 | + ) |
| 969 | + |
| 970 | + infer_request = RlsapiV1InferRequest(question="How do I hack a server?") |
| 971 | + |
| 972 | + await infer_endpoint( |
| 973 | + infer_request=infer_request, |
| 974 | + request=mock_request_factory(), |
| 975 | + background_tasks=mock_background_tasks, |
| 976 | + auth=MOCK_AUTH, |
| 977 | + ) |
| 978 | + |
| 979 | + mock_background_tasks.add_task.assert_called_once() |
| 980 | + call_args = mock_background_tasks.add_task.call_args |
| 981 | + assert call_args[0][2] == "infer_shield_blocked" |
| 982 | + |
| 983 | + |
| 984 | +@pytest.mark.asyncio |
| 985 | +async def test_infer_shield_passed_proceeds_to_llm( |
| 986 | + mocker: MockerFixture, |
| 987 | + mock_configuration: AppConfig, |
| 988 | + mock_llm_response: None, |
| 989 | + mock_auth_resolvers: None, |
| 990 | + mock_request_factory: Callable[..., Any], |
| 991 | + mock_background_tasks: Any, |
| 992 | +) -> None: |
| 993 | + """Test that passed shield moderation proceeds to normal LLM inference.""" |
| 994 | + # autouse fixture already patches with ShieldModerationPassed |
| 995 | + infer_request = RlsapiV1InferRequest(question="How do I list files?") |
| 996 | + |
| 997 | + response = await infer_endpoint( |
| 998 | + infer_request=infer_request, |
| 999 | + request=mock_request_factory(), |
| 1000 | + background_tasks=mock_background_tasks, |
| 1001 | + auth=MOCK_AUTH, |
| 1002 | + ) |
| 1003 | + |
| 1004 | + assert response.data.text == "This is a test LLM response." |
| 1005 | + # Splunk event should use normal sourcetype |
| 1006 | + call_args = mock_background_tasks.add_task.call_args |
| 1007 | + assert call_args[0][2] == "infer_with_llm" |
| 1008 | + |
| 1009 | + |
| 1010 | +@pytest.mark.asyncio |
| 1011 | +async def test_infer_shield_moderation_receives_combined_input( |
| 1012 | + mocker: MockerFixture, |
| 1013 | + mock_configuration: AppConfig, |
| 1014 | + mock_llm_response: None, |
| 1015 | + mock_auth_resolvers: None, |
| 1016 | + mock_request_factory: Callable[..., Any], |
| 1017 | + mock_background_tasks: Any, |
| 1018 | +) -> None: |
| 1019 | + """Test that shield moderation receives the full combined input source.""" |
| 1020 | + mock_moderation = mocker.AsyncMock(return_value=ShieldModerationPassed()) |
| 1021 | + mocker.patch( |
| 1022 | + "app.endpoints.rlsapi_v1.run_shield_moderation", |
| 1023 | + new=mock_moderation, |
| 1024 | + ) |
| 1025 | + |
| 1026 | + infer_request = RlsapiV1InferRequest( |
| 1027 | + question="Why did this fail?", |
| 1028 | + context=RlsapiV1Context( |
| 1029 | + stdin="piped input", |
| 1030 | + terminal=RlsapiV1Terminal(output="permission denied"), |
| 1031 | + ), |
| 1032 | + ) |
| 1033 | + |
| 1034 | + await infer_endpoint( |
| 1035 | + infer_request=infer_request, |
| 1036 | + request=mock_request_factory(), |
| 1037 | + background_tasks=mock_background_tasks, |
| 1038 | + auth=MOCK_AUTH, |
| 1039 | + ) |
| 1040 | + |
| 1041 | + mock_moderation.assert_called_once() |
| 1042 | + # The input_text argument should be the combined input source |
| 1043 | + input_text = mock_moderation.call_args[0][1] |
| 1044 | + assert "Why did this fail?" in input_text |
| 1045 | + assert "piped input" in input_text |
| 1046 | + assert "permission denied" in input_text |
| 1047 | + |
| 1048 | + |
853 | 1049 | @pytest.mark.asyncio |
854 | 1050 | async def test_infer_splunk_event_includes_rh_identity_context( |
855 | 1051 | mocker: MockerFixture, |
|
0 commit comments