Skip to content

Commit 9599ca5

Browse files
committed
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests.
1 parent d464fbc commit 9599ca5

6 files changed

Lines changed: 701 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity
9+
10+
__all__ = ["QuestionValidity"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from pydantic_ai_lightspeed.capabilities.question_validity._capacity import (
4+
QuestionValidity,
5+
)
6+
7+
__all__ = ["QuestionValidity"]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (it can be customized to any
5+
topic as well), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from collections.abc import Sequence
13+
from copy import copy
14+
from dataclasses import dataclass
15+
from string import Template
16+
17+
from pydantic_ai import AgentRunResult, RunContext
18+
from pydantic_ai._agent_graph import GraphAgentState
19+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
20+
from pydantic_ai.direct import model_request
21+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
22+
from pydantic_ai.models import Model
23+
24+
from log import get_logger
25+
26+
logger = get_logger(__name__)
27+
28+
# Classification prompt sent to the validity model. It is rendered with
# string.Template, so ${allowed}, ${rejected} and ${message} are placeholders
# filled in at request time.
DEFAULT_MODEL_PROMPT = """
Instructions:
- You are a question classifying tool
- You are an expert in kubernetes and openshift
- Your job is to determine whether or not a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
- Do not explain your answer, just provide the one-word response. Do not give any other response.
- If the given question is an empty string, answer with the word ${rejected}


Example Question:
Why is the sky blue?
Example Response:
${rejected}

Example Question:
Why is the grass green?
Example Response:
${rejected}

Example Question:
Why is sand yellow?
Example Response:
${rejected}

Example Question:
Can you help configure my cluster to automatically scale?
Example Response:
${allowed}

Question:
${message}
Response:
"""

# Canned reply returned to the user when a question is classified as off-topic.
DEFAULT_INVALID_QUESTION_RESPONSE = """
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
please ask me a question related to OpenShift.
"""

# One-word verdicts the classifier model is instructed to answer with.
SUBJECT_REJECTED = "REJECTED"
SUBJECT_ALLOWED = "ALLOWED"
70+
71+
72+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
73+
"""Extract and combine all text content into a string from an UserContent sequence"""
74+
str_arr: list[str] = []
75+
for c in user_content:
76+
match c:
77+
case str() as s:
78+
str_arr.append(s)
79+
case TextContent(content=c):
80+
str_arr.append(c)
81+
82+
return "\n".join(str_arr)
83+
84+
85+
def _remove_conversation_from_settings(model: Model) -> Model:
86+
"""Return a Model with 'conversation' removed from extra_body settings.
87+
88+
Only creates a shallow copy if 'conversation' exists in extra_body; otherwise returns the original model unchanged.
89+
"""
90+
if settings := model.settings:
91+
if extra_body := settings.get("extra_body"):
92+
if isinstance(extra_body, dict) and "conversation" in extra_body:
93+
_extra_body = {
94+
k: v for k, v in extra_body.items() if k != "conversation"
95+
}
96+
_settings = copy(settings)
97+
_settings["extra_body"] = _extra_body
98+
_model = copy(model)
99+
_model._settings = _settings
100+
return _model
101+
return model
102+
103+
104+
@dataclass
class QuestionValidity(AbstractCapability):
    """Reject off-topic user questions before the main agent runs.

    For every run, the user prompt is rendered into ``model_prompt`` and sent
    to ``model`` as a single classification request. If the model answers with
    the allowed verdict the wrapped run proceeds; otherwise the run is
    short-circuited and ``invalid_question_response`` is returned as the
    output. Token usage of the classification request is added to the run's
    usage in both cases.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai.models.openai import OpenAIResponsesModel

        model = OpenAIResponsesModel("gpt-4o-mini")
        agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
        ```
    """

    model: Model
    """The model to use for the question validity check."""

    model_prompt: str = DEFAULT_MODEL_PROMPT
    """The prompt to use for the question validity check."""

    invalid_question_response: str = DEFAULT_INVALID_QUESTION_RESPONSE
    """The response to use when the question is determined to be invalid."""

    def __post_init__(self) -> None:
        """Swap in a model whose extra_body settings carry no 'conversation' key."""
        # NOTE(review): presumably this keeps the classification request from
        # being attached to the caller's conversation -- confirm with the author.
        self.model = _remove_conversation_from_settings(self.model)

    def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
        """Render the classification prompt for ``message``.

        Args:
            message: The user prompt of the run: a plain string, a sequence of
                user-content parts (only their text is used), or ``None``,
                which is treated as an empty question.

        Returns:
            ``model_prompt`` with ``${message}``, ``${allowed}`` and
            ``${rejected}`` substituted.

        Raises:
            KeyError: If ``model_prompt`` contains an unknown placeholder.
            ValueError: If ``model_prompt`` contains an invalid placeholder.
        """
        # Exhaustive if/elif so the question text is always bound, whatever
        # type the prompt turns out to be.
        if message is None:
            question = ""
        elif isinstance(message, str):
            question = message
        else:
            question = _extract_message_str_from_user_content(message)

        return Template(self.model_prompt).substitute(
            message=question, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
        )

    @staticmethod
    def _is_allowed(verdict: str | None) -> bool:
        """Return True only when the classifier's reply is the allowed verdict.

        Surrounding whitespace, wrapping quotes/backticks, trailing sentence
        punctuation and letter case are ignored, because models frequently
        decorate a one-word answer. Anything else, including an empty or
        missing reply, counts as rejected so the guardrail fails closed.
        """
        if not verdict:
            return False
        normalized = verdict.strip().strip("\"'`.!").strip()
        return normalized.casefold() == SUBJECT_ALLOWED.casefold()

    async def wrap_run(
        self, ctx: RunContext, *, handler: WrapRunHandler
    ) -> AgentRunResult:
        """Classify the run's prompt and either continue or short-circuit.

        Args:
            ctx: Run context; its prompt is classified and its usage is
                incremented with the classification request's usage.
            handler: Callable that performs the real agent run.

        Returns:
            The result of ``handler()`` when the question is allowed, otherwise
            a result whose output is ``invalid_question_response``.
        """
        prompt = self._build_prompt(ctx.prompt)

        result = await model_request(
            model=self.model,
            messages=[ModelRequest.user_text_prompt(prompt)],
        )

        # Include token usage from the question validity request
        ctx.usage.incr(result.usage)

        if self._is_allowed(result.text):
            return await handler()  # proceed with the real run

        # short-circuit: return the rejection message with shield usage tracked
        state = GraphAgentState(usage=ctx.usage)
        return AgentRunResult(output=self.invalid_question_response, _state=state)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)