Skip to content

Commit 3700804

Browse files
committed
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests.
1 parent d464fbc commit 3700804

6 files changed

Lines changed: 697 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity
9+
10+
__all__ = ["QuestionValidity"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from pydantic_ai_lightspeed.capabilities.question_validity._capacity import (
4+
QuestionValidity,
5+
)
6+
7+
__all__ = ["QuestionValidity"]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (it can be customized to any
5+
topic as well), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import copy
13+
from collections.abc import Sequence
14+
from dataclasses import dataclass
15+
from string import Template
16+
17+
from pydantic_ai import AgentRunResult, RunContext
18+
from pydantic_ai._agent_graph import GraphAgentState
19+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
20+
from pydantic_ai.direct import model_request
21+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
22+
from pydantic_ai.models import Model
23+
24+
from log import get_logger
25+
26+
logger = get_logger(__name__)
27+
28+
# Classification prompt rendered with string.Template. Placeholders:
#   ${allowed}  - the one-word verdict for on-topic questions
#   ${rejected} - the one-word verdict for off-topic questions
#   ${message}  - the user question being classified
DEFAULT_MODEL_PROMPT = """
Instructions:
- You are a question classifying tool
- You are an expert in kubernetes and openshift
- Your job is to determine whether or not a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
- Do not explain your answer, just provide the one-word response. Do not give any other response.
- If the given question is an empty string, answer with the word ${rejected}


Example Question:
Why is the sky blue?
Example Response:
${rejected}

Example Question:
Why is the grass green?
Example Response:
${rejected}

Example Question:
Why is sand yellow?
Example Response:
${rejected}

Example Question:
Can you help configure my cluster to automatically scale?
Example Response:
${allowed}

Question:
${message}
Response:
"""
62+
63+
# Default text used as the run output when a question is not classified as allowed.
DEFAULT_INVALID_QUESTION_RESPONSE = """
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
please ask me a question related to OpenShift.
"""

# One-word verdicts substituted into the classifier prompt's ${rejected} and
# ${allowed} placeholders; the model's answer is compared against SUBJECT_ALLOWED.
SUBJECT_REJECTED = "REJECTED"
SUBJECT_ALLOWED = "ALLOWED"
70+
71+
72+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
73+
"""Extract and combine all text content into a string from an UserContent sequence"""
74+
str_arr: list[str] = []
75+
for c in user_content:
76+
match c:
77+
case str() as s:
78+
str_arr.append(s)
79+
case TextContent(content=c):
80+
str_arr.append(c)
81+
82+
return "\n".join(str_arr)
83+
84+
85+
def _remove_conversation_from_settings(model: Model) -> Model:
86+
"""Create a duplicate Model instance with conversation being removed from extra body"""
87+
_model = copy.copy(model)
88+
if settings := _model.settings:
89+
_settings = copy.copy(settings)
90+
if extra_body := _settings.get("extra_body"):
91+
if isinstance(extra_body, dict) and "conversation" in extra_body:
92+
_extra_body = {
93+
k: v for k, v in extra_body.items() if k != "conversation"
94+
}
95+
_settings["extra_body"] = _extra_body
96+
_model._settings = _settings
97+
return _model
98+
99+
100+
@dataclass
101+
class QuestionValidity(AbstractCapability):
102+
"""Block or modify user input based on a guardrail check.
103+
104+
The guard function receives the user prompt and returns True if safe.
105+
106+
Example:
107+
```python
108+
from pydantic_ai import Agent
109+
from pydantic_ai.models.openai import OpenAIResponsesModel
110+
111+
model = OpenAIResponsesModel("gpt-4o-mini")
112+
agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
113+
```
114+
"""
115+
116+
model: Model
117+
"""The model to use for the question validity check."""
118+
119+
model_prompt: str = DEFAULT_MODEL_PROMPT
120+
"""The prompt to use for the question validity check."""
121+
122+
invalid_question_response: str = DEFAULT_INVALID_QUESTION_RESPONSE
123+
"""The response to use when the question is determined to be invalid."""
124+
125+
def __post_init__(self) -> None:
126+
self.model = _remove_conversation_from_settings(self.model)
127+
128+
def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
129+
match message:
130+
case str() as s:
131+
_message = s
132+
case Sequence() as seq:
133+
_message = _extract_message_str_from_user_content(seq)
134+
case None:
135+
_message = ""
136+
137+
return Template(self.model_prompt).substitute(
138+
message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
139+
)
140+
141+
async def wrap_run(
142+
self, ctx: RunContext, *, handler: WrapRunHandler
143+
) -> AgentRunResult:
144+
prompt = self._build_prompt(ctx.prompt)
145+
146+
result = await model_request(
147+
model=self.model,
148+
messages=[ModelRequest.user_text_prompt(prompt)],
149+
)
150+
151+
# Include token usage from the question validity request
152+
ctx.usage.incr(result.usage)
153+
154+
if result.text == SUBJECT_ALLOWED:
155+
return await handler() # proceed with the real run
156+
else:
157+
# short-circuit: return the rejection message with shield usage tracked
158+
state = GraphAgentState(usage=ctx.usage)
159+
return AgentRunResult(output=self.invalid_question_response, _state=state)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)