Skip to content

Commit 71b93ea

Browse files
authored
Merge pull request #1221 from are-ces/fix-custom-shield-clean
LCORE-1209: Custom shields not compatible with LCORE
2 parents 8f7f1ef + 1de20a4 commit 71b93ea

2 files changed

Lines changed: 42 additions & 4 deletions

File tree

src/utils/shields.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from llama_stack_client.types import CreateResponse
88

99
import metrics
10+
from log import get_logger
1011
from models.responses import (
1112
NotFoundResponse,
1213
)
1314
from utils.types import ShieldModerationResult
14-
from log import get_logger
1515

1616
logger = get_logger(__name__)
1717

@@ -83,7 +83,12 @@ async def run_shield_moderation(
8383

8484
shields = await client.shields.list()
8585
for shield in shields:
86-
if (
86+
# Only validate provider_resource_id against models for llama-guard.
87+
# Llama Stack does not verify that the llama-guard model is registered,
88+
# so we check it here to fail fast with a clear error.
89+
# Custom shield providers (e.g. lightspeed_question_validity) configure
90+
# their model internally, so provider_resource_id is not a model ID.
91+
if shield.provider_id == "llama-guard" and (
8792
not shield.provider_resource_id
8893
or shield.provider_resource_id not in available_models
8994
):

tests/unit/utils/test_shields.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,48 @@ async def test_returns_blocked_with_default_message_when_no_user_message(
228228
assert result.message == DEFAULT_VIOLATION_MESSAGE
229229
assert result.shield_model == "moderation-model"
230230

231+
@pytest.mark.asyncio
232+
async def test_skips_model_check_for_non_llama_guard_shields(
233+
self, mocker: MockerFixture
234+
) -> None:
235+
"""Test that non-llama-guard shields skip model validation and proceed to moderation."""
236+
mock_client = mocker.Mock()
237+
238+
# Setup custom shield (not llama-guard) with provider_resource_id not in models
239+
shield = mocker.Mock()
240+
shield.identifier = "custom-shield"
241+
shield.provider_id = "lightspeed_question_validity"
242+
shield.provider_resource_id = "not-a-model-id"
243+
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
244+
245+
# No matching models - should NOT raise for non-llama-guard
246+
mock_client.models.list = mocker.AsyncMock(return_value=[])
247+
248+
# Setup moderation result (not flagged)
249+
moderation_result = mocker.Mock()
250+
moderation_result.results = [mocker.Mock(flagged=False)]
251+
mock_client.moderations.create = mocker.AsyncMock(
252+
return_value=moderation_result
253+
)
254+
255+
result = await run_shield_moderation(mock_client, "test input")
256+
257+
assert result.blocked is False
258+
mock_client.moderations.create.assert_called_once_with(
259+
input="test input", model="not-a-model-id"
260+
)
261+
231262
@pytest.mark.asyncio
232263
async def test_raises_http_exception_when_shield_model_not_found(
233264
self, mocker: MockerFixture
234265
) -> None:
235266
"""Test that run_shield_moderation raises HTTPException when shield model not in models."""
236267
mock_client = mocker.Mock()
237268

238-
# Setup shield with provider_resource_id
269+
# Setup llama-guard shield with provider_resource_id not in models
239270
shield = mocker.Mock()
240271
shield.identifier = "test-shield"
272+
shield.provider_id = "llama-guard"
241273
shield.provider_resource_id = "missing-model"
242274
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
243275

@@ -259,9 +291,10 @@ async def test_raises_http_exception_when_shield_has_no_provider_resource_id(
259291
"""Test that run_shield_moderation raises HTTPException when no provider_resource_id."""
260292
mock_client = mocker.Mock()
261293

262-
# Setup shield without provider_resource_id
294+
# Setup llama-guard shield without provider_resource_id
263295
shield = mocker.Mock()
264296
shield.identifier = "test-shield"
297+
shield.provider_id = "llama-guard"
265298
shield.provider_resource_id = None
266299
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
267300

0 commit comments

Comments
 (0)