Skip to content

Commit fe2e532

Browse files
committed
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests.
1 parent d464fbc commit fe2e532

6 files changed

Lines changed: 693 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from .question_validity import QuestionValidity
9+
10+
# Public API of this package: the QuestionValidity capability imported above.
__all__ = ["QuestionValidity"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from ._capacity import QuestionValidity
4+
5+
# Public API of this package: the QuestionValidity capability imported above.
__all__ = ["QuestionValidity"]
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (it can be customized to any
5+
topic as well), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import copy
13+
import logging
14+
from collections.abc import Sequence
15+
from dataclasses import dataclass
16+
from string import Template
17+
18+
from pydantic_ai import AgentRunResult, RunContext
19+
from pydantic_ai._agent_graph import GraphAgentState
20+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
21+
from pydantic_ai.direct import model_request
22+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
23+
from pydantic_ai.models import Model
24+
25+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
26+
27+
# Default classification prompt. It is a ``string.Template`` source with three
# placeholders: ``${allowed}`` and ``${rejected}`` (the one-word verdicts the
# classifier must answer with) and ``${message}`` (the user's question).
DEFAULT_MODEL_PROMPT = """
Instructions:
- You are a question classifying tool
- You are an expert in kubernetes and openshift
- Your job is to determine whether or not a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
- Do not explain your answer, just provide the one-word response. Do not give any other response.
- If the given question is an empty string, answer with the word ${rejected}


Example Question:
Why is the sky blue?
Example Response:
${rejected}

Example Question:
Why is the grass green?
Example Response:
${rejected}

Example Question:
Why is sand yellow?
Example Response:
${rejected}

Example Question:
Can you help configure my cluster to automatically scale?
Example Response:
${allowed}

Question:
${message}
Response:
"""

# Default answer returned to the user when the question is classified as
# off-topic.
DEFAULT_INVALID_QUESTION_RESPONSE = """
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
please ask me a question related to OpenShift.
"""

# One-word verdicts substituted into the prompt and expected back from the
# classifier model.
SUBJECT_REJECTED = "REJECTED"
SUBJECT_ALLOWED = "ALLOWED"
69+
70+
71+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
    """Join the textual parts of a multi-part user prompt into one string.

    Plain strings contribute themselves and ``TextContent`` items contribute
    their ``content``; any other kind of item is skipped. The collected
    pieces are joined with newlines, so input without text yields ``""``.
    """
    pieces: list[str] = []
    for item in user_content:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, TextContent):
            pieces.append(item.content)
    return "\n".join(pieces)
81+
82+
83+
def _remove_conversation_from_settings(model: Model) -> Model:
    """Return a shallow copy of ``model`` whose settings carry no conversation.

    The copy receives its own shallow copy of the settings mapping; when that
    mapping has a dict ``extra_body`` containing a ``"conversation"`` key, the
    copy's ``extra_body`` is rebuilt without it. The model passed in and its
    settings are left untouched.
    """
    clone = copy.copy(model)
    current = clone.settings
    if not current:
        # No settings (or empty settings): nothing to strip.
        return clone

    trimmed = copy.copy(current)
    body = trimmed.get("extra_body")
    if body and isinstance(body, dict) and "conversation" in body:
        trimmed["extra_body"] = {
            key: value for key, value in body.items() if key != "conversation"
        }
    clone._settings = trimmed
    return clone
96+
97+
98+
@dataclass
99+
class QuestionValidity(AbstractCapability):
100+
"""Block or modify user input based on a guardrail check.
101+
102+
The guard function receives the user prompt and returns True if safe.
103+
104+
Example:
105+
```python
106+
from pydantic_ai import Agent
107+
from pydantic_ai.models.openai import OpenAIResponsesModel
108+
109+
model = OpenAIResponsesModel("gpt-4o-mini")
110+
agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
111+
```
112+
"""
113+
114+
model: Model
115+
"""The model to use for the question validity check."""
116+
117+
model_prompt: str = DEFAULT_MODEL_PROMPT
118+
"""The prompt to use for the question validity check."""
119+
120+
invalid_question_response: str = DEFAULT_INVALID_QUESTION_RESPONSE
121+
"""The response to use when the question is determined to be invalid."""
122+
123+
def __post_init__(self):
124+
self.model = _remove_conversation_from_settings(self.model)
125+
126+
def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
127+
match message:
128+
case str() as s:
129+
_message = s
130+
case Sequence() as seq:
131+
_message = _extract_message_str_from_user_content(seq)
132+
case None:
133+
_message = ""
134+
135+
return Template(self.model_prompt).substitute(
136+
message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
137+
)
138+
139+
async def wrap_run(
140+
self, ctx: RunContext, *, handler: WrapRunHandler
141+
) -> AgentRunResult:
142+
prompt = self._build_prompt(ctx.prompt)
143+
144+
result = await model_request(
145+
model=self.model,
146+
messages=[ModelRequest.user_text_prompt(prompt)],
147+
)
148+
149+
# Include token usage from the question validity request
150+
ctx.usage.incr(result.usage)
151+
152+
if result.text == SUBJECT_ALLOWED:
153+
return await handler() # proceed with the real run
154+
else:
155+
# short-circuit: return the rejection message with shield usage tracked
156+
state = GraphAgentState(usage=ctx.usage)
157+
return AgentRunResult(output=self.invalid_question_response, _state=state)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)