Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from llama_stack_client.types import CreateResponse

import metrics
from log import get_logger
from models.responses import (
NotFoundResponse,
)
from utils.types import ShieldModerationResult
from log import get_logger

logger = get_logger(__name__)

Expand Down Expand Up @@ -83,7 +83,12 @@ async def run_shield_moderation(

shields = await client.shields.list()
for shield in shields:
if (
# Only validate provider_resource_id against models for llama-guard.
# Llama Stack does not verify that the llama-guard model is registered,
# so we check it here to fail fast with a clear error.
# Custom shield providers (e.g. lightspeed_question_validity) configure
# their model internally, so provider_resource_id is not a model ID.
if shield.provider_id == "llama-guard" and (
not shield.provider_resource_id
or shield.provider_resource_id not in available_models
):
Expand Down
37 changes: 35 additions & 2 deletions tests/unit/utils/test_shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,48 @@ async def test_returns_blocked_with_default_message_when_no_user_message(
assert result.message == DEFAULT_VIOLATION_MESSAGE
assert result.shield_model == "moderation-model"

@pytest.mark.asyncio
async def test_skips_model_check_for_non_llama_guard_shields(
self, mocker: MockerFixture
) -> None:
"""Test that non-llama-guard shields skip model validation and proceed to moderation."""
mock_client = mocker.Mock()

# Setup custom shield (not llama-guard) with provider_resource_id not in models
shield = mocker.Mock()
shield.identifier = "custom-shield"
shield.provider_id = "lightspeed_question_validity"
shield.provider_resource_id = "not-a-model-id"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

# No matching models - should NOT raise for non-llama-guard
mock_client.models.list = mocker.AsyncMock(return_value=[])

# Setup moderation result (not flagged)
moderation_result = mocker.Mock()
moderation_result.results = [mocker.Mock(flagged=False)]
mock_client.moderations.create = mocker.AsyncMock(
return_value=moderation_result
)

result = await run_shield_moderation(mock_client, "test input")

assert result.blocked is False
mock_client.moderations.create.assert_called_once_with(
input="test input", model="not-a-model-id"
)

@pytest.mark.asyncio
async def test_raises_http_exception_when_shield_model_not_found(
self, mocker: MockerFixture
) -> None:
"""Test that run_shield_moderation raises HTTPException when shield model not in models."""
mock_client = mocker.Mock()

# Setup shield with provider_resource_id
# Setup llama-guard shield with provider_resource_id not in models
shield = mocker.Mock()
shield.identifier = "test-shield"
shield.provider_id = "llama-guard"
shield.provider_resource_id = "missing-model"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

Expand All @@ -259,9 +291,10 @@ async def test_raises_http_exception_when_shield_has_no_provider_resource_id(
"""Test that run_shield_moderation raises HTTPException when no provider_resource_id."""
mock_client = mocker.Mock()

# Setup shield without provider_resource_id
# Setup llama-guard shield without provider_resource_id
shield = mocker.Mock()
shield.identifier = "test-shield"
shield.provider_id = "llama-guard"
shield.provider_resource_id = None
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

Expand Down
Loading