-
Notifications
You must be signed in to change notification settings - Fork 86
LCORE-1259: Safety Shield configuration for query & streaming query #1100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
097d039
71b5bad
dc5b4bb
bdf945f
f2fa760
7b2ad45
a5ed817
fcd2b86
a2c6646
93c0085
2fa7a00
417b509
3413438
3c7693f
07321b7
876cf50
db682c3
08153d9
74a9f5c
90ac5a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| """Utility functions for working with Llama Stack shields.""" | ||
|
|
||
| import logging | ||
| from typing import Any, cast | ||
| from typing import Any, Optional, cast | ||
|
|
||
| from fastapi import HTTPException | ||
| from llama_stack_client import AsyncLlamaStackClient, BadRequestError | ||
| from llama_stack_client.types import CreateResponse | ||
|
|
||
| import metrics | ||
| from models.responses import NotFoundResponse | ||
| from models.responses import NotFoundResponse, UnprocessableEntityResponse | ||
| from utils.types import ShieldModerationResult | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -63,26 +63,56 @@ def detect_shield_violations(output_items: list[Any]) -> bool: | |
| async def run_shield_moderation( | ||
| client: AsyncLlamaStackClient, | ||
| input_text: str, | ||
| shield_ids: Optional[list[str]] = None, | ||
| ) -> ShieldModerationResult: | ||
| """ | ||
| Run shield moderation on input text. | ||
|
|
||
| Iterates through all configured shields and runs moderation checks. | ||
| Iterates through configured shields and runs moderation checks. | ||
| Raises HTTPException if shield model is not found. | ||
|
|
||
| Parameters: | ||
| client: The Llama Stack client. | ||
| input_text: The text to moderate. | ||
| shield_ids: Optional list of shield IDs to use. If None, uses all shields. | ||
| If empty list, skips all shields. | ||
|
|
||
| Returns: | ||
| ShieldModerationResult: Result indicating if content was blocked and the message. | ||
|
|
||
| Raises: | ||
| HTTPException: If shield's provider_resource_id is not configured or model not found. | ||
| """ | ||
| all_shields = await client.shields.list() | ||
|
|
||
| # Filter shields based on shield_ids parameter | ||
| if shield_ids is not None: | ||
| if len(shield_ids) == 0: | ||
| logger.info("shield_ids=[] provided, skipping all shields") | ||
| return ShieldModerationResult(blocked=False) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing When shield overrides are enabled (the default), a caller can send
♻️ Option 2: Reject empty list if shield_ids is not None:
if len(shield_ids) == 0:
- logger.info("shield_ids=[] provided, skipping all shields")
- return ShieldModerationResult(blocked=False)
+ response = UnprocessableEntityResponse(
+ response="Invalid shield configuration",
+ cause="shield_ids cannot be an empty list. Omit the field to use all shields.",
+ )
+ raise HTTPException(**response.model_dump())🤖 Prompt for AI Agents |
||
|
|
||
| shields_to_run = [s for s in all_shields if s.identifier in shield_ids] | ||
|
|
||
| # Log warning if requested shield not found | ||
| requested = set(shield_ids) | ||
| available = {s.identifier for s in shields_to_run} | ||
| missing = requested - available | ||
| if missing: | ||
| logger.warning("Requested shields not found: %s", missing) | ||
|
|
||
| # Reject if no requested shields were found (prevents accidental bypass) | ||
| if not shields_to_run: | ||
| response = UnprocessableEntityResponse( | ||
| response="Invalid shield configuration", | ||
| cause=f"Requested shield_ids not found: {sorted(missing)}", | ||
| ) | ||
| raise HTTPException(**response.model_dump()) | ||
| else: | ||
| shields_to_run = list(all_shields) | ||
|
|
||
| available_models = {model.id for model in await client.models.list()} | ||
|
|
||
| for shield in await client.shields.list(): | ||
| for shield in shields_to_run: | ||
| if ( | ||
| not shield.provider_resource_id | ||
| or shield.provider_resource_id not in available_models | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.