Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
097d039
Safety shield config
JslYoon Feb 3, 2026
71b5bad
Typos for shield fail gracefully. Added test suite
JslYoon Feb 3, 2026
dc5b4bb
Merge remote-tracking branch 'upstream/main' into JslYoon-safety-shie…
JslYoon Feb 11, 2026
bdf945f
new configuration parameter that will enable or disable safety shield…
JslYoon Feb 13, 2026
f2fa760
Merge branch 'main' into JslYoon-safety-shield-config
JslYoon Feb 17, 2026
7b2ad45
fix lint & shield empty list config
JslYoon Feb 17, 2026
a5ed817
Merge remote-tracking branch 'upstream/main' into JslYoon-safety-shie…
JslYoon Feb 23, 2026
fcd2b86
openapi documentation for safety shielf config
JslYoon Feb 23, 2026
a2c6646
Merge branch 'main' into JslYoon-safety-shield-config
JslYoon Feb 25, 2026
93c0085
fix integration test
JslYoon Feb 25, 2026
2fa7a00
LCORE-948: Restart prow e2e pod (#1181)
radofuchs Feb 25, 2026
417b509
LCORE-1282: unify version usage in LCore
tisnik Feb 25, 2026
3413438
Created and documented LCORE OpenResponses specification
asimurka Feb 24, 2026
3c7693f
LCORE-1374: contribution guide
tisnik Feb 25, 2026
07321b7
add interrupted user query to conversation
Jdubrick Feb 24, 2026
876cf50
Removed check for shield model if shield is not llama-guard
are-ces Feb 25, 2026
db682c3
Merge branch 'main' into JslYoon-safety-shield-config
JslYoon Feb 25, 2026
08153d9
fix lint
JslYoon Feb 25, 2026
74a9f5c
Merge branch 'main' into JslYoon-safety-shield-config
JslYoon Feb 25, 2026
90ac5a7
Merge branch 'main' into JslYoon-safety-shield-config
JslYoon Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/app/endpoints/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals
generate_topic_summary=True,
media_type=None,
vector_store_ids=vector_store_ids,
shield_ids=None,
)

# Get LLM client and select model
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
4 changes: 3 additions & 1 deletion src/app/endpoints/streaming_query_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
)

# Run shield moderation before calling LLM
moderation_result = await run_shield_moderation(client, input_text)
moderation_result = await run_shield_moderation(
client, input_text, query_request.shield_ids
)
if moderation_result.blocked:
violation_message = moderation_result.message or ""
await append_turn_to_conversation(
Expand Down
9 changes: 9 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class QueryRequest(BaseModel):
generate_topic_summary: Whether to generate topic summary for new conversations.
media_type: The optional media type for response format (application/json or text/plain).
vector_store_ids: The optional list of specific vector store IDs to query for RAG.
shield_ids: The optional list of safety shield IDs to apply.

Example:
```python
Expand Down Expand Up @@ -166,6 +167,14 @@ class QueryRequest(BaseModel):
examples=["ocp_docs", "knowledge_base", "vector_db_1"],
)

shield_ids: Optional[list[str]] = Field(
None,
description="Optional list of safety shield IDs to apply. "
"If None, all configured shields are used. "
"If empty list, all shields are skipped.",
examples=["llama-guard", "custom-shield"],
)

# provides examples for /docs endpoint
model_config = {
"extra": "forbid",
Expand Down
38 changes: 34 additions & 4 deletions src/utils/shields.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Utility functions for working with Llama Stack shields."""

import logging
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
from typing import Any, cast
from typing import Any, Optional, cast

from fastapi import HTTPException
from llama_stack_client import AsyncLlamaStackClient, BadRequestError
from llama_stack_client.types import CreateResponse

import metrics
from models.responses import NotFoundResponse
from models.responses import NotFoundResponse, UnprocessableEntityResponse
from utils.types import ShieldModerationResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -63,26 +63,56 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
async def run_shield_moderation(
client: AsyncLlamaStackClient,
input_text: str,
shield_ids: Optional[list[str]] = None,
) -> ShieldModerationResult:
"""
Run shield moderation on input text.

Iterates through all configured shields and runs moderation checks.
Iterates through configured shields and runs moderation checks.
Raises HTTPException if shield model is not found.

Parameters:
client: The Llama Stack client.
input_text: The text to moderate.
shield_ids: Optional list of shield IDs to use. If None, uses all shields.
If empty list, skips all shields.

Returns:
ShieldModerationResult: Result indicating if content was blocked and the message.

Raises:
HTTPException: If shield's provider_resource_id is not configured or model not found.
"""
all_shields = await client.shields.list()

# Filter shields based on shield_ids parameter
if shield_ids is not None:
if len(shield_ids) == 0:
logger.info("shield_ids=[] provided, skipping all shields")
return ShieldModerationResult(blocked=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Passing shield_ids=[] silently bypasses all safety shields.

When shield overrides are enabled (the default), a caller can send shield_ids: [] to skip every shield with no moderation at all. This is a different semantic from "choose specific shields" and could be an unintended bypass vector. Consider either:

  1. Treating an empty list the same as None (i.e., run all shields), or
  2. Rejecting an empty list with a 422, similar to the "no valid shields found" check below.
♻️ Option 2: Reject empty list
     if shield_ids is not None:
         if len(shield_ids) == 0:
-            logger.info("shield_ids=[] provided, skipping all shields")
-            return ShieldModerationResult(blocked=False)
+            response = UnprocessableEntityResponse(
+                response="Invalid shield configuration",
+                cause="shield_ids cannot be an empty list. Omit the field to use all shields.",
+            )
+            raise HTTPException(**response.model_dump())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/utils/shields.py` around lines 128 - 131, The current branch treating
shield_ids == [] as "skip all shields" is unsafe; replace that behavior by
rejecting an explicit empty list like the later "no valid shields found" check:
in the function handling shield_ids, remove the logger.info + return
ShieldModerationResult(blocked=False) for len(shield_ids) == 0 and instead raise
a 422 validation error (HTTPException or the project’s validation error type)
with a clear message (e.g., "shield_ids provided but no shields selected"),
ensuring the raised error follows the same format and status code used by the
"no valid shields found" logic; update any tests to expect a 422 for an explicit
empty list.


shields_to_run = [s for s in all_shields if s.identifier in shield_ids]

# Log warning if requested shield not found
requested = set(shield_ids)
available = {s.identifier for s in shields_to_run}
missing = requested - available
if missing:
logger.warning("Requested shields not found: %s", missing)

# Reject if no requested shields were found (prevents accidental bypass)
if not shields_to_run:
response = UnprocessableEntityResponse(
response="Invalid shield configuration",
cause=f"Requested shield_ids not found: {sorted(missing)}",
)
raise HTTPException(**response.model_dump())
else:
shields_to_run = list(all_shields)

available_models = {model.id for model in await client.models.list()}

for shield in await client.shields.list():
for shield in shields_to_run:
if (
not shield.provider_resource_id
or shield.provider_resource_id not in available_models
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/utils/test_shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,75 @@ async def test_returns_blocked_on_bad_request_error(
assert result.shield_model == "moderation-model"
mock_metric.inc.assert_called_once()

@pytest.mark.asyncio
async def test_shield_ids_empty_list_skips_all_shields(
self, mocker: MockerFixture
) -> None:
"""Test that shield_ids=[] explicitly skips all shields (intentional bypass)."""
mock_client = mocker.Mock()
shield = mocker.Mock()
shield.identifier = "shield-1"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

result = await run_shield_moderation(mock_client, "test input", shield_ids=[])

assert result.blocked is False
mock_client.shields.list.assert_called_once()

@pytest.mark.asyncio
async def test_shield_ids_raises_exception_when_no_shields_found(
self, mocker: MockerFixture
) -> None:
"""Test shield_ids raises HTTPException when no requested shields exist."""
mock_client = mocker.Mock()
shield = mocker.Mock()
shield.identifier = "shield-1"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])

with pytest.raises(HTTPException) as exc_info:
await run_shield_moderation(
mock_client, "test input", shield_ids=["typo-shield"]
)

assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
assert "Invalid shield configuration" in exc_info.value.detail["response"] # type: ignore
assert "typo-shield" in exc_info.value.detail["cause"] # type: ignore

@pytest.mark.asyncio
async def test_shield_ids_filters_to_specific_shield(
self, mocker: MockerFixture
) -> None:
"""Test that shield_ids filters to only specified shields."""
mock_client = mocker.Mock()

shield1 = mocker.Mock()
shield1.identifier = "shield-1"
shield1.provider_resource_id = "model-1"
shield2 = mocker.Mock()
shield2.identifier = "shield-2"
shield2.provider_resource_id = "model-2"
mock_client.shields.list = mocker.AsyncMock(return_value=[shield1, shield2])

model1 = mocker.Mock()
model1.id = "model-1"
mock_client.models.list = mocker.AsyncMock(return_value=[model1])

moderation_result = mocker.Mock()
moderation_result.results = [mocker.Mock(flagged=False)]
mock_client.moderations.create = mocker.AsyncMock(
return_value=moderation_result
)

result = await run_shield_moderation(
mock_client, "test input", shield_ids=["shield-1"]
)

assert result.blocked is False
assert mock_client.moderations.create.call_count == 1
mock_client.moderations.create.assert_called_with(
input="test input", model="model-1"
)


class TestAppendTurnToConversation: # pylint: disable=too-few-public-methods
"""Tests for append_turn_to_conversation function."""
Expand Down
Loading