Skip to content

Commit 559a48f

Browse files
committed
fix lint & cleanup
Signed-off-by: Lucas <lyoon@redhat.com>
1 parent f2fa760 commit 559a48f

4 files changed

Lines changed: 19 additions & 11 deletions

File tree

src/app/endpoints/query.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
"""Handler for REST API call to provide answer to query using Response API."""
44

5-
import logging
65
import datetime
76
from typing import Annotated, Any, Optional, cast
87

@@ -196,7 +195,11 @@ async def query_endpoint_handler(
196195

197196
# Retrieve response using Responses API
198197
turn_summary = await retrieve_response(
199-
client, responses_params, query_request.shield_ids, vector_store_ids, rag_id_mapping
198+
client,
199+
responses_params,
200+
query_request.shield_ids,
201+
vector_store_ids,
202+
rag_id_mapping,
200203
)
201204

202205
if pre_rag_chunks:

src/models/requests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class QueryRequest(BaseModel):
171171
None,
172172
description="Optional list of safety shield IDs to apply. "
173173
"If None, all configured shields are used. "
174-
"If empty list, all shields are skipped.",
174+
"If provided, must contain at least one valid shield ID (empty list raises 422 error).",
175175
examples=["llama-guard", "custom-shield"],
176176
)
177177

src/utils/shields.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utility functions for working with Llama Stack shields."""
22

3-
import logging
43
from typing import Any, Optional, cast
54

65
from fastapi import HTTPException
@@ -128,8 +127,11 @@ async def run_shield_moderation(
128127
# Filter shields based on shield_ids parameter
129128
if shield_ids is not None:
130129
if len(shield_ids) == 0:
131-
logger.info("shield_ids=[] provided, skipping all shields")
132-
return ShieldModerationResult(blocked=False)
130+
response = UnprocessableEntityResponse(
131+
response="Invalid shield configuration",
132+
cause="shield_ids provided but no shields selected. Remove the parameter to use default shields.",
133+
)
134+
raise HTTPException(**response.model_dump())
133135

134136
shields_to_run = [s for s in all_shields if s.identifier in shield_ids]
135137

tests/unit/utils/test_shields.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,19 +307,22 @@ async def test_returns_blocked_on_bad_request_error(
307307
mock_metric.inc.assert_called_once()
308308

309309
@pytest.mark.asyncio
310-
async def test_shield_ids_empty_list_skips_all_shields(
310+
async def test_shield_ids_empty_list_raises_422(
311311
self, mocker: MockerFixture
312312
) -> None:
313-
"""Test that shield_ids=[] explicitly skips all shields (intentional bypass)."""
313+
"""Test that shield_ids=[] raises HTTPException 422 (prevents bypass)."""
314314
mock_client = mocker.Mock()
315315
shield = mocker.Mock()
316316
shield.identifier = "shield-1"
317317
mock_client.shields.list = mocker.AsyncMock(return_value=[shield])
318318

319-
result = await run_shield_moderation(mock_client, "test input", shield_ids=[])
319+
with pytest.raises(HTTPException) as exc_info:
320+
await run_shield_moderation(mock_client, "test input", shield_ids=[])
320321

321-
assert result.blocked is False
322-
mock_client.shields.list.assert_called_once()
322+
assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
323+
assert "shield_ids provided but no shields selected" in str(
324+
exc_info.value.detail
325+
)
323326

324327
@pytest.mark.asyncio
325328
async def test_shield_ids_raises_exception_when_no_shields_found(

0 commit comments

Comments
 (0)