Skip to content

Commit d5466e8

Browse files
committed
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests.
1 parent 7c41630 commit d5466e8

8 files changed

Lines changed: 799 additions & 0 deletions

File tree

src/constants.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,46 @@
250250
"I cannot process this request due to policy restrictions."
251251
)
252252

253+
# The Default model prompt and the default invalid question response for QuestionValidityConfig
254+
DEFAULT_MODEL_PROMPT: Final[str] = """
255+
Instructions:
256+
- You are a question classifying tool
257+
- You are an expert in kubernetes and openshift
258+
- Your job is to determine where or a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
259+
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
260+
- Do not explain your answer, just provide the one-word response. Do not give any other response.
261+
- If the given question is an empty string, answer with the word ${rejected}
262+
263+
264+
Example Question:
265+
Why is the sky blue?
266+
Example Response:
267+
${rejected}
268+
269+
Example Question:
270+
Why is the grass green?
271+
Example Response:
272+
${rejected}
273+
274+
Example Question:
275+
Why is sand yellow?
276+
Example Response:
277+
${rejected}
278+
279+
Example Question:
280+
Can you help configure my cluster to automatically scale?
281+
Example Response:
282+
${allowed}
283+
284+
Question:
285+
${message}
286+
Response:
287+
"""
288+
DEFAULT_INVALID_QUESTION_RESPONSE: Final[str] = """
289+
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
290+
please ask me a question related to OpenShift.
291+
"""
292+
253293
# Placeholder slug used in responses when the server substituted its own
254294
# system prompt for the client's instructions. Avoids leaking the actual
255295
# server prompt back to the client.

src/models/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,6 +2176,24 @@ class SkillsConfiguration(ConfigurationBase):
21762176
)
21772177

21782178

2179+
class QuestionValidityConfig(ConfigurationBase):
2180+
"""Configuration for the question validity guardrail."""
2181+
2182+
model_id: str = Field(
2183+
..., title="Model id", description="The model_id to use for the guard"
2184+
)
2185+
model_prompt: str = Field(
2186+
default=constants.DEFAULT_MODEL_PROMPT,
2187+
title="Model prompt",
2188+
description="The default prompt sent to the LLM used to validate the Users' question.",
2189+
)
2190+
invalid_question_response: str = Field(
2191+
default=constants.DEFAULT_INVALID_QUESTION_RESPONSE,
2192+
title="Invalid question response",
2193+
description="The default response when the Users' question is determined to be invalid.",
2194+
)
2195+
2196+
21792197
class Configuration(ConfigurationBase):
21802198
"""Global service configuration."""
21812199

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity
9+
10+
__all__ = ["QuestionValidity"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from pydantic_ai_lightspeed.capabilities.question_validity._capability import (
4+
QuestionValidity,
5+
)
6+
7+
__all__ = ["QuestionValidity"]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (It can be customized to any
5+
topic as well), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from collections.abc import Sequence
13+
from dataclasses import dataclass, field
14+
from string import Template
15+
16+
from pydantic_ai import AgentRunResult, RunContext
17+
from pydantic_ai._agent_graph import GraphAgentState
18+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
19+
from pydantic_ai.direct import model_request
20+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
21+
from pydantic_ai.models import Model
22+
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
23+
24+
from client import AsyncLlamaStackClientHolder
25+
from log import get_logger
26+
from models.config import (
27+
QuestionValidityConfig,
28+
)
29+
from pydantic_ai_lightspeed.llamastack import LlamaStackResponsesModel
30+
from utils.pydantic_ai import llama_stack_provider_from_client
31+
32+
logger = get_logger(__name__)
33+
34+
SUBJECT_REJECTED = "REJECTED"
35+
SUBJECT_ALLOWED = "ALLOWED"
36+
37+
38+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
39+
"""Extract and combine all text content into a string from a UserContent sequence.
40+
41+
Parameters:
42+
user_content: A sequence of user content items to extract text from.
43+
44+
Returns:
45+
A single string with all text content joined by newlines.
46+
"""
47+
str_arr: list[str] = []
48+
for c in user_content:
49+
match c:
50+
case str() as s:
51+
str_arr.append(s)
52+
case TextContent(content=c):
53+
str_arr.append(c)
54+
55+
return "\n".join(str_arr)
56+
57+
58+
def _create_model_from_llama_stack_client(model_id: str) -> LlamaStackResponsesModel:
59+
"""Create a LlamaStackResponsesModel from the shared Llama Stack client.
60+
61+
Parameters:
62+
model_id: The model identifier to use for the responses model.
63+
64+
Returns:
65+
A configured LlamaStackResponsesModel instance.
66+
"""
67+
client = AsyncLlamaStackClientHolder().get_client()
68+
provider = llama_stack_provider_from_client(client)
69+
settings = OpenAIResponsesModelSettings(openai_store=False)
70+
return LlamaStackResponsesModel(model_id, provider=provider, settings=settings)
71+
72+
73+
@dataclass
74+
class QuestionValidity(AbstractCapability[None]):
75+
"""Block or modify user input based on a guardrail check.
76+
77+
The guard function receives the user prompt and returns True if safe.
78+
79+
Example:
80+
```python
81+
from pydantic_ai import Agent
82+
from pydantic_ai.models.openai import OpenAIResponsesModel
83+
84+
model = OpenAIResponsesModel("gpt-4o-mini")
85+
agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
86+
```
87+
"""
88+
89+
config: QuestionValidityConfig
90+
_model: Model = field(init=False)
91+
92+
def __post_init__(self) -> None:
93+
"""Initialize the model instance from the configured model ID."""
94+
self._model = _create_model_from_llama_stack_client(self.config.model_id)
95+
96+
def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
97+
"""Build the classification prompt from the user message.
98+
99+
Parameters:
100+
message: The user input as a string, sequence of user content, or None.
101+
102+
Returns:
103+
The rendered prompt string ready to send to the validity model.
104+
"""
105+
match message:
106+
case str() as s:
107+
_message = s
108+
case Sequence() as seq:
109+
_message = _extract_message_str_from_user_content(seq)
110+
case None:
111+
_message = ""
112+
113+
return Template(self.config.model_prompt).substitute(
114+
message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
115+
)
116+
117+
async def wrap_run(
118+
self, ctx: RunContext, *, handler: WrapRunHandler
119+
) -> AgentRunResult:
120+
"""Run the question validity check before delegating to the main agent.
121+
122+
Sends the user prompt to the validity model for classification.
123+
If the question is allowed, the handler proceeds normally.
124+
Otherwise, a rejection response is returned and the main agent
125+
is bypassed.
126+
127+
Parameters:
128+
ctx: The run context containing the user prompt and usage tracker.
129+
handler: The handler that invokes the main agent run.
130+
131+
Returns:
132+
The agent run result, either from the main agent or a rejection.
133+
"""
134+
prompt = self._build_prompt(ctx.prompt)
135+
136+
result = await model_request(
137+
model=self._model,
138+
messages=[ModelRequest.user_text_prompt(prompt)],
139+
)
140+
141+
# Include token usage from the question validity request
142+
ctx.usage.incr(result.usage)
143+
144+
if result.text is not None and result.text.strip() == SUBJECT_ALLOWED:
145+
return await handler() # proceed with the real run
146+
147+
# short-circuit: return the rejection message with shield usage tracked
148+
state = GraphAgentState(usage=ctx.usage)
149+
return AgentRunResult(
150+
output=self.config.invalid_question_response, _state=state
151+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)