Skip to content

Commit a25393f

Browse files
committed
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests.
1 parent d464fbc commit a25393f

8 files changed

Lines changed: 715 additions & 0 deletions

File tree

src/constants.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,47 @@
248248
"I cannot process this request due to policy restrictions."
249249
)
250250

251+
252+
# Prompt template for the question validity classifier.
#
# The template is rendered with ``string.Template`` and uses three
# placeholders:
#   ${allowed}  - the one-word verdict for an on-topic question
#   ${rejected} - the one-word verdict for an off-topic question
#   ${message}  - the user's question being classified
# A literal dollar sign in the template must be written as ``$$``.
DEFAULT_MODEL_PROMPT: Final[str] = """
Instructions:
- You are a question classifying tool
- You are an expert in kubernetes and openshift
- Your job is to determine whether or not a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
- Do not explain your answer, just provide the one-word response. Do not give any other response.
- If the given question is an empty string, answer with the word ${rejected}


Example Question:
Why is the sky blue?
Example Response:
${rejected}

Example Question:
Why is the grass green?
Example Response:
${rejected}

Example Question:
Why is sand yellow?
Example Response:
${rejected}

Example Question:
Can you help configure my cluster to automatically scale?
Example Response:
${allowed}

Question:
${message}
Response:
"""

# Default reply used when a question is classified as invalid (off-topic).
DEFAULT_INVALID_QUESTION_RESPONSE: Final[str] = """
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
please ask me a question related to OpenShift.
"""
291+
251292
# Placeholder slug used in responses when the server substituted its own
252293
# system prompt for the client's instructions. Avoids leaking the actual
253294
# server prompt back to the client.

src/models/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,24 @@ class SkillsConfiguration(ConfigurationBase):
20602060
)
20612061

20622062

2063+
class QuestionValidityConfig(ConfigurationBase):
    """Settings for the guardrail that screens user questions for validity.

    Attributes:
        model_id: Required id of the model used for the guard.
        model_prompt: Prompt sent to the validating LLM; defaults to
            ``constants.DEFAULT_MODEL_PROMPT``.
        invalid_question_response: Reply used when a question is judged
            invalid; defaults to
            ``constants.DEFAULT_INVALID_QUESTION_RESPONSE``.
    """

    model_id: str = Field(
        ...,
        title="Model id",
        description="The model_id to use for the guard",
    )

    model_prompt: str = Field(
        title="Model prompt",
        description="The default prompt sent to the LLM used to validate the Users' question.",
        default=constants.DEFAULT_MODEL_PROMPT,
    )

    invalid_question_response: str = Field(
        title="Invalid question response",
        description="The default response when the Users' question is determined to be invalid.",
        default=constants.DEFAULT_INVALID_QUESTION_RESPONSE,
    )
2079+
2080+
20632081
class Configuration(ConfigurationBase):
20642082
"""Global service configuration."""
20652083

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity
9+
10+
__all__ = ["QuestionValidity"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from pydantic_ai_lightspeed.capabilities.question_validity._capability import (
4+
QuestionValidity,
5+
)
6+
7+
__all__ = ["QuestionValidity"]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (the topic can be customized through
5+
the configured model prompt), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from collections.abc import Sequence
13+
from dataclasses import dataclass, field
14+
from string import Template
15+
16+
from pydantic_ai import AgentRunResult, RunContext
17+
from pydantic_ai._agent_graph import GraphAgentState
18+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
19+
from pydantic_ai.direct import model_request
20+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
21+
from pydantic_ai.models import Model, infer_model
22+
23+
from log import get_logger
24+
from models.config import (
25+
QuestionValidityConfig,
26+
)
27+
28+
logger = get_logger(__name__)
29+
30+
SUBJECT_REJECTED = "REJECTED"
31+
SUBJECT_ALLOWED = "ALLOWED"
32+
33+
34+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
35+
"""Extract and combine all text content into a string from a UserContent sequence.
36+
37+
Parameters:
38+
user_content: A sequence of user content items to extract text from.
39+
40+
Returns:
41+
A single string with all text content joined by newlines.
42+
"""
43+
str_arr: list[str] = []
44+
for c in user_content:
45+
match c:
46+
case str() as s:
47+
str_arr.append(s)
48+
case TextContent(content=c):
49+
str_arr.append(c)
50+
51+
return "\n".join(str_arr)
52+
53+
54+
@dataclass
55+
class QuestionValidity(AbstractCapability[None]):
56+
"""Block or modify user input based on a guardrail check.
57+
58+
The guard function receives the user prompt and returns True if safe.
59+
60+
Example:
61+
```python
62+
from pydantic_ai import Agent
63+
from pydantic_ai.models.openai import OpenAIResponsesModel
64+
65+
model = OpenAIResponsesModel("gpt-4o-mini")
66+
agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
67+
```
68+
"""
69+
70+
config: QuestionValidityConfig
71+
_model: Model = field(init=False)
72+
73+
def __post_init__(self) -> None:
74+
"""Initialize the model instance from the configured model ID."""
75+
self._model = infer_model(self.config.model_id)
76+
77+
def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
78+
"""Build the classification prompt from the user message.
79+
80+
Parameters:
81+
message: The user input as a string, sequence of user content, or None.
82+
83+
Returns:
84+
The rendered prompt string ready to send to the validity model.
85+
"""
86+
match message:
87+
case str() as s:
88+
_message = s
89+
case Sequence() as seq:
90+
_message = _extract_message_str_from_user_content(seq)
91+
case None:
92+
_message = ""
93+
94+
return Template(self.config.model_prompt).substitute(
95+
message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
96+
)
97+
98+
async def wrap_run(
99+
self, ctx: RunContext, *, handler: WrapRunHandler
100+
) -> AgentRunResult:
101+
"""Run the question validity check before delegating to the main agent.
102+
103+
Sends the user prompt to the validity model for classification.
104+
If the question is allowed, the handler proceeds normally.
105+
Otherwise, a rejection response is returned and the main agent
106+
is bypassed.
107+
108+
Parameters:
109+
ctx: The run context containing the user prompt and usage tracker.
110+
handler: The handler that invokes the main agent run.
111+
112+
Returns:
113+
The agent run result, either from the main agent or a rejection.
114+
"""
115+
prompt = self._build_prompt(ctx.prompt)
116+
117+
result = await model_request(
118+
model=self._model,
119+
messages=[ModelRequest.user_text_prompt(prompt)],
120+
)
121+
122+
# Include token usage from the question validity request
123+
ctx.usage.incr(result.usage)
124+
125+
if result.text is not None and result.text.strip() == SUBJECT_ALLOWED:
126+
return await handler() # proceed with the real run
127+
128+
# short-circuit: return the rejection message with shield usage tracked
129+
state = GraphAgentState(usage=ctx.usage)
130+
return AgentRunResult(
131+
output=self.config.invalid_question_response, _state=state
132+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)