-
Notifications
You must be signed in to change notification settings - Fork 91
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI #1913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| """Pluggable capabilities for pydantic-ai agents in Lightspeed. | ||
|
|
||
| Provides safety, guardrail, and policy capabilities that hook into | ||
| pydantic-ai's AbstractCapability lifecycle to enforce constraints | ||
| before, during, or after agent runs. | ||
| """ | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """Question validity capability for agent input validation.""" | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity._capacity import ( | ||
| QuestionValidity, | ||
| ) | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| """Question validity capability for filtering off-topic user queries. | ||
|
|
||
| This module implements a guardrail that classifies user questions as | ||
| Kubernetes/OpenShift-related or not (It can be customized to any | ||
| topic as well), using an LLM-based check before the main agent | ||
| processes the request. Invalid questions are rejected with a | ||
| predefined response, bypassing the primary agent entirely. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from copy import copy | ||
| from dataclasses import dataclass | ||
| from string import Template | ||
|
|
||
| from pydantic_ai import AgentRunResult, RunContext | ||
| from pydantic_ai._agent_graph import GraphAgentState | ||
| from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler | ||
| from pydantic_ai.direct import model_request | ||
| from pydantic_ai.messages import ModelRequest, TextContent, UserContent | ||
| from pydantic_ai.models import Model | ||
|
|
||
| from log import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
| DEFAULT_MODEL_PROMPT = """ | ||
| Instructions: | ||
| - You are a question classifying tool | ||
| - You are an expert in kubernetes and openshift | ||
| - Your job is to determine where or a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response. | ||
| - If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}. | ||
| - Do not explain your answer, just provide the one-word response. Do not give any other response. | ||
| - If the given question is an empty string, answer with the word ${rejected} | ||
|
|
||
|
|
||
| Example Question: | ||
| Why is the sky blue? | ||
| Example Response: | ||
| ${rejected} | ||
|
|
||
| Example Question: | ||
| Why is the grass green? | ||
| Example Response: | ||
| ${rejected} | ||
|
|
||
| Example Question: | ||
| Why is sand yellow? | ||
| Example Response: | ||
| ${rejected} | ||
|
|
||
| Example Question: | ||
| Can you help configure my cluster to automatically scale? | ||
| Example Response: | ||
| ${allowed} | ||
|
|
||
| Question: | ||
| ${message} | ||
| Response: | ||
| """ | ||
|
|
||
| DEFAULT_INVALID_QUESTION_RESPONSE = """ | ||
| Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift, | ||
| please ask me a question related to OpenShift. | ||
| """ | ||
|
|
||
| SUBJECT_REJECTED = "REJECTED" | ||
| SUBJECT_ALLOWED = "ALLOWED" | ||
|
|
||
|
|
||
| def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str: | ||
| """Extract and combine all text content into a string from an UserContent sequence""" | ||
| str_arr: list[str] = [] | ||
| for c in user_content: | ||
| match c: | ||
| case str() as s: | ||
| str_arr.append(s) | ||
| case TextContent(content=c): | ||
| str_arr.append(c) | ||
|
|
||
| return "\n".join(str_arr) | ||
|
|
||
|
|
||
| def _remove_conversation_from_settings(model: Model) -> Model: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. During my testing, if we share |
||
| """Return a Model with 'conversation' removed from extra_body settings. | ||
|
|
||
| Only creates a shallow copy if 'conversation' exists in extra_body; otherwise returns the original model unchanged. | ||
| """ | ||
| if settings := model.settings: | ||
| if extra_body := settings.get("extra_body"): | ||
| if isinstance(extra_body, dict) and "conversation" in extra_body: | ||
| _extra_body = { | ||
| k: v for k, v in extra_body.items() if k != "conversation" | ||
| } | ||
| _settings = copy(settings) | ||
| _settings["extra_body"] = _extra_body | ||
| _model = copy(model) | ||
| _model._settings = _settings | ||
| return _model | ||
| return model | ||
|
|
||
|
|
||
| @dataclass | ||
| class QuestionValidity(AbstractCapability): | ||
| """Block or modify user input based on a guardrail check. | ||
|
|
||
| The guard function receives the user prompt and returns True if safe. | ||
|
|
||
| Example: | ||
| ```python | ||
| from pydantic_ai import Agent | ||
| from pydantic_ai.models.openai import OpenAIResponsesModel | ||
|
|
||
| model = OpenAIResponsesModel("gpt-4o-mini") | ||
| agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)]) | ||
| ``` | ||
| """ | ||
|
|
||
| model: Model | ||
| """The model to use for the question validity check.""" | ||
|
|
||
| model_prompt: str = DEFAULT_MODEL_PROMPT | ||
| """The prompt to use for the question validity check.""" | ||
|
|
||
| invalid_question_response: str = DEFAULT_INVALID_QUESTION_RESPONSE | ||
| """The response to use when the question is determined to be invalid.""" | ||
|
|
||
| def __post_init__(self) -> None: | ||
| self.model = _remove_conversation_from_settings(self.model) | ||
|
|
||
| def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str: | ||
| match message: | ||
| case str() as s: | ||
| _message = s | ||
| case Sequence() as seq: | ||
| _message = _extract_message_str_from_user_content(seq) | ||
| case None: | ||
| _message = "" | ||
|
|
||
| return Template(self.model_prompt).substitute( | ||
| message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED | ||
| ) | ||
|
|
||
| async def wrap_run( | ||
| self, ctx: RunContext, *, handler: WrapRunHandler | ||
| ) -> AgentRunResult: | ||
| prompt = self._build_prompt(ctx.prompt) | ||
|
|
||
| result = await model_request( | ||
| model=self.model, | ||
| messages=[ModelRequest.user_text_prompt(prompt)], | ||
| ) | ||
|
|
||
| # Include token usage from the question validity request | ||
| ctx.usage.incr(result.usage) | ||
|
|
||
| if result.text == SUBJECT_ALLOWED: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify guardrail tests currently cover normalization behavior.
rg -n -C2 'SUBJECT_ALLOWED|result\.text|wrap_run' tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.pyRepository: lightspeed-core/lightspeed-stack Length of output: 4763 🏁 Script executed: #!/bin/bash
set -euo pipefail
# Show the exact comparison at the referenced line range
python - <<'PY'
from pathlib import Path
path = Path("src/pydantic_ai_lightspeed/capabilities/question_validity/_capacity.py")
lines = path.read_text().splitlines()
start = 130
end = 175
for i in range(start, min(end, len(lines))):
print(f"{i+1:4d}: {lines[i]}")
PY
# Check tests for normalization/whitespace/casing expectations
rg -n "SUBJECT_ALLOWED|SUBJECT_REJECTED|strip\(|upper\(|lower\(|whitespace|casefold|normalize" \
tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.pyRepository: lightspeed-core/lightspeed-stack Length of output: 2123 🏁 Script executed: #!/bin/bash
set -euo pipefail
# 1) Inspect the wrap_run tests area
python - <<'PY'
from pathlib import Path
path = Path("tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.py")
lines = path.read_text().splitlines()
# Print a focused window around the wrap_run tests
start = 240
end = 560
for i in range(start, min(end, len(lines))):
print(f"{i+1:4d}: {lines[i]}")
PY
# 2) Search for any normalization on result.text in this capability
rg -n "result\.text|strip\(|upper\(|lower\(|casefold\(|normalize" \
src/pydantic_ai_lightspeed/capabilities/question_validity/_capacity.py src/pydantic_ai_lightspeed/capabilities/question_validity -S
# 3) Search the tests for any variants (e.g., whitespace/case) around SUBJECT_ALLOWED/REJECTED
rg -n "TextPart\(content=\s*['\"]\s*(SUBJECT_ALLOWED|SUBJECT_REJECTED)" tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.py
rg -n "SUBJECT_ALLOWED\s*(\)|\]|,)|SUBJECT_REJECTED\s*(\)|\]|,)" tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.pyRepository: lightspeed-core/lightspeed-stack Length of output: 12090 🏁 Script executed: #!/bin/bash
set -euo pipefail
# Locate model_request implementation (within this capability if present)
rg -n "def model_request|async def model_request|model_request\s*\(" -S src/pydantic_ai_lightspeed/capabilities/question_validity
# Show the top of _capacity.py to understand how result.text is produced/used
python - <<'PY'
from pathlib import Path
path = Path("src/pydantic_ai_lightspeed/capabilities/question_validity/_capacity.py")
lines = path.read_text().splitlines()
for i in range(1, min(220, len(lines))+1):
if i<=40 or (140<=i<=180):
print(f"{i:4d}: {lines[i-1]}")
PY
# Search for other SUBJECT_ALLOWED comparisons
rg -n "result\.text\s*==\s*SUBJECT_ALLOWED|SUBJECT_ALLOWED\s*in\s*result\.text|SUBJECT_ALLOWED" \
src/pydantic_ai_lightspeed/capabilities/question_validity -S
# Search tests for any normalization or casing/whitespace cases
rg -n "SUBJECT_ALLOWED.*(strip|upper|lower|casefold)|SUBJECT_REJECTED.*(strip|upper|lower|casefold)" \
tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capacity.pyRepository: lightspeed-core/lightspeed-stack Length of output: 3233 Normalize guardrail classifier output before ALLOWED/REJECTED comparison.
Suggested fix- if result.text == SUBJECT_ALLOWED:
+ classification = result.text.strip().upper()
+ if classification == SUBJECT_ALLOWED:
return await handler() # proceed with the real run
else:
# short-circuit: return the rejection message with shield usage tracked
state = GraphAgentState(usage=ctx.usage)
return AgentRunResult(output=self.invalid_question_response, _state=state)🤖 Prompt for AI Agents |
||
| return await handler() # proceed with the real run | ||
| else: | ||
| # short-circuit: return the rejection message with shield usage tracked | ||
| state = GraphAgentState(usage=ctx.usage) | ||
| return AgentRunResult(output=self.invalid_question_response, _state=state) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for pydantic_ai_lightspeed capabilities.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for question validity capability.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Add missing function docstrings and complete annotations.
Multiple functions in this module are missing descriptive docstrings, and
__post_init__(Line 123) is missing an explicit-> Nonereturn annotation.As per coding guidelines,
src/**/*.py: “All functions must have complete type annotations for parameters and return types … and include descriptive docstrings” and “Follow Google Python docstring conventions with required sections”.Also applies to: 83-83, 123-123, 126-126, 139-141
🤖 Prompt for AI Agents
Source: Coding guidelines