Skip to content

Commit 4908c93

Browse files
authored
Merge pull request #1913 from Jazzcort/question-validity-for-pydantic-ai
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI
2 parents e052d2d + d5466e8 commit 4908c93

10 files changed

Lines changed: 807 additions & 8 deletions

File tree

src/constants.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,46 @@
263263
"I cannot process this request due to policy restrictions."
264264
)
265265

266+
# The Default model prompt and the default invalid question response for QuestionValidityConfig
267+
DEFAULT_MODEL_PROMPT: Final[str] = """
268+
Instructions:
269+
- You are a question classifying tool
270+
- You are an expert in kubernetes and openshift
271+
- Your job is to determine where or a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response.
272+
- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}.
273+
- Do not explain your answer, just provide the one-word response. Do not give any other response.
274+
- If the given question is an empty string, answer with the word ${rejected}
275+
276+
277+
Example Question:
278+
Why is the sky blue?
279+
Example Response:
280+
${rejected}
281+
282+
Example Question:
283+
Why is the grass green?
284+
Example Response:
285+
${rejected}
286+
287+
Example Question:
288+
Why is sand yellow?
289+
Example Response:
290+
${rejected}
291+
292+
Example Question:
293+
Can you help configure my cluster to automatically scale?
294+
Example Response:
295+
${allowed}
296+
297+
Question:
298+
${message}
299+
Response:
300+
"""
301+
DEFAULT_INVALID_QUESTION_RESPONSE: Final[str] = """
302+
Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift,
303+
please ask me a question related to OpenShift.
304+
"""
305+
266306
# Placeholder slug used in responses when the server substituted its own
267307
# system prompt for the client's instructions. Avoids leaking the actual
268308
# server prompt back to the client.

src/models/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2307,6 +2307,24 @@ class SkillsConfiguration(ConfigurationBase):
23072307
)
23082308

23092309

2310+
class QuestionValidityConfig(ConfigurationBase):
2311+
"""Configuration for the question validity guardrail."""
2312+
2313+
model_id: str = Field(
2314+
..., title="Model id", description="The model_id to use for the guard"
2315+
)
2316+
model_prompt: str = Field(
2317+
default=constants.DEFAULT_MODEL_PROMPT,
2318+
title="Model prompt",
2319+
description="The default prompt sent to the LLM used to validate the Users' question.",
2320+
)
2321+
invalid_question_response: str = Field(
2322+
default=constants.DEFAULT_INVALID_QUESTION_RESPONSE,
2323+
title="Invalid question response",
2324+
description="The default response when the Users' question is determined to be invalid.",
2325+
)
2326+
2327+
23102328
class Configuration(ConfigurationBase):
23112329
"""Global service configuration."""
23122330

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Pluggable capabilities for pydantic-ai agents in Lightspeed.
2+
3+
Provides safety, guardrail, and policy capabilities that hook into
4+
pydantic-ai's AbstractCapability lifecycle to enforce constraints
5+
before, during, or after agent runs.
6+
"""
7+
8+
from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity
9+
10+
__all__ = ["QuestionValidity"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Question validity capability for agent input validation."""
2+
3+
from pydantic_ai_lightspeed.capabilities.question_validity._capability import (
4+
QuestionValidity,
5+
)
6+
7+
__all__ = ["QuestionValidity"]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Question validity capability for filtering off-topic user queries.
2+
3+
This module implements a guardrail that classifies user questions as
4+
Kubernetes/OpenShift-related or not (It can be customized to any
5+
topic as well), using an LLM-based check before the main agent
6+
processes the request. Invalid questions are rejected with a
7+
predefined response, bypassing the primary agent entirely.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from collections.abc import Sequence
13+
from dataclasses import dataclass, field
14+
from string import Template
15+
16+
from pydantic_ai import AgentRunResult, RunContext
17+
from pydantic_ai._agent_graph import GraphAgentState
18+
from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler
19+
from pydantic_ai.direct import model_request
20+
from pydantic_ai.messages import ModelRequest, TextContent, UserContent
21+
from pydantic_ai.models import Model
22+
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
23+
24+
from client import AsyncLlamaStackClientHolder
25+
from log import get_logger
26+
from models.config import (
27+
QuestionValidityConfig,
28+
)
29+
from pydantic_ai_lightspeed.llamastack import LlamaStackResponsesModel
30+
from utils.pydantic_ai import llama_stack_provider_from_client
31+
32+
logger = get_logger(__name__)
33+
34+
SUBJECT_REJECTED = "REJECTED"
35+
SUBJECT_ALLOWED = "ALLOWED"
36+
37+
38+
def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str:
39+
"""Extract and combine all text content into a string from a UserContent sequence.
40+
41+
Parameters:
42+
user_content: A sequence of user content items to extract text from.
43+
44+
Returns:
45+
A single string with all text content joined by newlines.
46+
"""
47+
str_arr: list[str] = []
48+
for c in user_content:
49+
match c:
50+
case str() as s:
51+
str_arr.append(s)
52+
case TextContent(content=c):
53+
str_arr.append(c)
54+
55+
return "\n".join(str_arr)
56+
57+
58+
def _create_model_from_llama_stack_client(model_id: str) -> LlamaStackResponsesModel:
59+
"""Create a LlamaStackResponsesModel from the shared Llama Stack client.
60+
61+
Parameters:
62+
model_id: The model identifier to use for the responses model.
63+
64+
Returns:
65+
A configured LlamaStackResponsesModel instance.
66+
"""
67+
client = AsyncLlamaStackClientHolder().get_client()
68+
provider = llama_stack_provider_from_client(client)
69+
settings = OpenAIResponsesModelSettings(openai_store=False)
70+
return LlamaStackResponsesModel(model_id, provider=provider, settings=settings)
71+
72+
73+
@dataclass
74+
class QuestionValidity(AbstractCapability[None]):
75+
"""Block or modify user input based on a guardrail check.
76+
77+
The guard function receives the user prompt and returns True if safe.
78+
79+
Example:
80+
```python
81+
from pydantic_ai import Agent
82+
from pydantic_ai.models.openai import OpenAIResponsesModel
83+
84+
model = OpenAIResponsesModel("gpt-4o-mini")
85+
agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)])
86+
```
87+
"""
88+
89+
config: QuestionValidityConfig
90+
_model: Model = field(init=False)
91+
92+
def __post_init__(self) -> None:
93+
"""Initialize the model instance from the configured model ID."""
94+
self._model = _create_model_from_llama_stack_client(self.config.model_id)
95+
96+
def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
97+
"""Build the classification prompt from the user message.
98+
99+
Parameters:
100+
message: The user input as a string, sequence of user content, or None.
101+
102+
Returns:
103+
The rendered prompt string ready to send to the validity model.
104+
"""
105+
match message:
106+
case str() as s:
107+
_message = s
108+
case Sequence() as seq:
109+
_message = _extract_message_str_from_user_content(seq)
110+
case None:
111+
_message = ""
112+
113+
return Template(self.config.model_prompt).substitute(
114+
message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
115+
)
116+
117+
async def wrap_run(
118+
self, ctx: RunContext, *, handler: WrapRunHandler
119+
) -> AgentRunResult:
120+
"""Run the question validity check before delegating to the main agent.
121+
122+
Sends the user prompt to the validity model for classification.
123+
If the question is allowed, the handler proceeds normally.
124+
Otherwise, a rejection response is returned and the main agent
125+
is bypassed.
126+
127+
Parameters:
128+
ctx: The run context containing the user prompt and usage tracker.
129+
handler: The handler that invokes the main agent run.
130+
131+
Returns:
132+
The agent run result, either from the main agent or a rejection.
133+
"""
134+
prompt = self._build_prompt(ctx.prompt)
135+
136+
result = await model_request(
137+
model=self._model,
138+
messages=[ModelRequest.user_text_prompt(prompt)],
139+
)
140+
141+
# Include token usage from the question validity request
142+
ctx.usage.incr(result.usage)
143+
144+
if result.text is not None and result.text.strip() == SUBJECT_ALLOWED:
145+
return await handler() # proceed with the real run
146+
147+
# short-circuit: return the rejection message with shield usage tracked
148+
state = GraphAgentState(usage=ctx.usage)
149+
return AgentRunResult(
150+
output=self.config.invalid_question_response, _state=state
151+
)

src/utils/pydantic_ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434

3535

36-
def _llama_stack_provider_from_client(
36+
def llama_stack_provider_from_client(
3737
client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
3838
) -> LlamaStackProvider:
3939
"""Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``."""
@@ -133,7 +133,7 @@ def build_agent(
133133
``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same
134134
stack configuration as ``client.responses.create(**responses_params.model_dump())``.
135135
"""
136-
provider = _llama_stack_provider_from_client(client)
136+
provider = llama_stack_provider_from_client(client)
137137
settings = _model_settings_from_responses_params(responses_params)
138138

139139
model = LlamaStackResponsesModel(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed capabilities."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for question validity capability."""

0 commit comments

Comments
 (0)