Skip to content

Commit db2b52b

Browse files
fix: Support GuardrailResult return type in guardrail functions (#875)
- Modified task.py to accept both GuardrailResult and Tuple[bool, Any] return types - Updated _process_guardrail to handle GuardrailResult directly without conversion - Added comprehensive test cases for both return types - Updated documentation with GuardrailResult examples - Maintains backward compatibility with existing Tuple[bool, Any] guardrails Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent c971fa8 commit db2b52b

4 files changed

Lines changed: 382 additions & 19 deletions

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Fixed example demonstrating proper guardrail usage with PraisonAI Agents.
4+
This addresses the issues reported in issue #875.
5+
"""
6+
7+
from praisonaiagents import Agent, Task, GuardrailResult, PraisonAIAgents
8+
from typing import Tuple, Any
9+
import trafilatura
10+
11+
# Example 1: Using GuardrailResult return type (now supported!)
12+
def validate_length_guardrailresult(output) -> GuardrailResult:
13+
"""Ensure output is between 100-500 characters using GuardrailResult"""
14+
# Extract the raw text from the TaskOutput object
15+
text = output.raw if hasattr(output, 'raw') else str(output)
16+
length = len(text)
17+
18+
if 100 <= length <= 500:
19+
return GuardrailResult(
20+
success=True,
21+
result=output, # Pass through the original output
22+
error=""
23+
)
24+
else:
25+
return GuardrailResult(
26+
success=False,
27+
result=None,
28+
error=f"Output must be 100-500 chars, got {length}"
29+
)
30+
31+
# Example 2: Using Tuple[bool, Any] return type (original method)
32+
def validate_length_tuple(output) -> Tuple[bool, Any]:
33+
"""Ensure output is between 100-500 characters using tuple"""
34+
text = output.raw if hasattr(output, 'raw') else str(output)
35+
length = len(text)
36+
37+
if 100 <= length <= 500:
38+
return True, output
39+
else:
40+
return False, f"Output must be 100-500 chars, got {length}"
41+
42+
# Tool function
43+
def get_url_context(url):
44+
"""Fetch and extract content from a URL"""
45+
downloaded = trafilatura.fetch_url(url)
46+
if not downloaded:
47+
return "Sorry, I couldn't fetch the content from that URL."
48+
49+
extracted = trafilatura.extract(
50+
downloaded,
51+
include_comments=False,
52+
include_links=True,
53+
output_format='json',
54+
with_metadata=True,
55+
url=url
56+
)
57+
58+
if not extracted:
59+
return "Sorry, I couldn't extract readable content from that page."
60+
61+
return extracted # returns JSON string
62+
63+
# Create agent with FIXED tools parameter (must be a list!)
64+
agent = Agent(
65+
name="Content Summarizer",
66+
role="Content Analysis Expert",
67+
goal="Summarize web content concisely",
68+
instructions="You are a helpful assistant that summarizes web content",
69+
llm="gemini/gemini-2.5-flash-lite-preview-06-17",
70+
self_reflect=False,
71+
verbose=True,
72+
tools=[get_url_context] # FIX: tools must be a list, not a single function
73+
)
74+
75+
# Create task with GuardrailResult guardrail
76+
task_with_guardrailresult = Task(
77+
name="summarise article with GuardrailResult",
78+
description="get the context of this url: https://blog.google/technology/ai/dolphingemma/ and produce a summary below 500 characters",
79+
agent=agent,
80+
guardrail=validate_length_guardrailresult, # Using GuardrailResult
81+
expected_output="summary of the article below 500 characters",
82+
max_retries=3 # Will retry up to 3 times if guardrail fails
83+
)
84+
85+
# Alternative: Create task with tuple guardrail
86+
task_with_tuple = Task(
87+
name="summarise article with tuple",
88+
description="get the context of this url: https://blog.google/technology/ai/dolphingemma/ and produce a summary below 500 characters",
89+
agent=agent,
90+
guardrail=validate_length_tuple, # Using Tuple[bool, Any]
91+
expected_output="summary of the article below 500 characters",
92+
max_retries=3
93+
)
94+
95+
# Example with string-based LLM guardrail
96+
task_with_llm_guardrail = Task(
97+
name="summarise with LLM guardrail",
98+
description="get the context of this url: https://blog.google/technology/ai/dolphingemma/ and produce a summary",
99+
agent=agent,
100+
guardrail="Ensure the summary is professional, factual, and between 100-500 characters",
101+
expected_output="professional summary of the article"
102+
)
103+
104+
# Run with GuardrailResult example
105+
print("=== Running with GuardrailResult guardrail ===")
106+
agents_gr = PraisonAIAgents(
107+
agents=[agent],
108+
tasks=[task_with_guardrailresult]
109+
)
110+
111+
# Uncomment to run:
112+
# result_gr = agents_gr.start()
113+
114+
# Run with Tuple example
115+
print("\n=== Running with Tuple[bool, Any] guardrail ===")
116+
agents_tuple = PraisonAIAgents(
117+
agents=[agent],
118+
tasks=[task_with_tuple]
119+
)
120+
121+
# Uncomment to run:
122+
# result_tuple = agents_tuple.start()
123+
124+
# Run with LLM guardrail example
125+
print("\n=== Running with LLM-based guardrail ===")
126+
agents_llm = PraisonAIAgents(
127+
agents=[agent],
128+
tasks=[task_with_llm_guardrail]
129+
)
130+
131+
# Uncomment to run:
132+
# result_llm = agents_llm.start()
133+
134+
print("""
135+
Key fixes applied:
136+
1. GuardrailResult is now accepted as a valid return type annotation
137+
2. tools parameter must be a list: tools=[get_url_context] not tools=get_url_context
138+
3. Both GuardrailResult and Tuple[bool, Any] return types are supported
139+
4. String-based LLM guardrails are also supported
140+
141+
The guardrail will automatically retry (up to max_retries times) if validation fails.
142+
""")

src/praisonai-agents/CLAUDE.md

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,32 @@ task = Task(
139139
#### Task-Level Guardrails
140140
```python
141141
from typing import Tuple, Any
142+
from praisonaiagents import GuardrailResult
142143

143-
# Function-based guardrail
144-
def validate_output(task_output: TaskOutput) -> Tuple[bool, Any]:
145-
"""Custom validation function."""
144+
# Function-based guardrail (Option 1: Using GuardrailResult)
145+
def validate_output(task_output: TaskOutput) -> GuardrailResult:
146+
"""Custom validation function returning GuardrailResult."""
147+
if "error" in task_output.raw.lower():
148+
return GuardrailResult(
149+
success=False,
150+
result=None,
151+
error="Output contains errors"
152+
)
153+
if len(task_output.raw) < 10:
154+
return GuardrailResult(
155+
success=False,
156+
result=None,
157+
error="Output is too short"
158+
)
159+
return GuardrailResult(
160+
success=True,
161+
result=task_output,
162+
error=""
163+
)
164+
165+
# Function-based guardrail (Option 2: Using Tuple[bool, Any])
166+
def validate_output_tuple(task_output: TaskOutput) -> Tuple[bool, Any]:
167+
"""Custom validation function returning tuple."""
146168
if "error" in task_output.raw.lower():
147169
return False, "Output contains errors"
148170
if len(task_output.raw) < 10:
@@ -170,7 +192,27 @@ task = Task(
170192
#### Agent-Level Guardrails
171193
```python
172194
# Agent guardrails apply to ALL outputs from that agent
173-
def validate_professional_tone(task_output: TaskOutput) -> Tuple[bool, Any]:
195+
196+
# Option 1: Using GuardrailResult
197+
def validate_professional_tone(task_output: TaskOutput) -> GuardrailResult:
198+
"""Ensure professional tone in all agent responses."""
199+
content = task_output.raw.lower()
200+
casual_words = ['yo', 'dude', 'awesome', 'cool']
201+
for word in casual_words:
202+
if word in content:
203+
return GuardrailResult(
204+
success=False,
205+
result=None,
206+
error=f"Unprofessional language detected: {word}"
207+
)
208+
return GuardrailResult(
209+
success=True,
210+
result=task_output,
211+
error=""
212+
)
213+
214+
# Option 2: Using Tuple[bool, Any]
215+
def validate_professional_tone_tuple(task_output: TaskOutput) -> Tuple[bool, Any]:
174216
"""Ensure professional tone in all agent responses."""
175217
content = task_output.raw.lower()
176218
casual_words = ['yo', 'dude', 'awesome', 'cool']

src/praisonai-agents/praisonaiagents/task/task.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,21 +172,33 @@ def _setup_guardrail(self):
172172
# Check return annotation if present
173173
return_annotation = sig.return_annotation
174174
if return_annotation != inspect.Signature.empty:
175-
return_annotation_args = get_args(return_annotation)
176-
if not (
177-
get_origin(return_annotation) is tuple
178-
and len(return_annotation_args) == 2
179-
and return_annotation_args[0] is bool
180-
and (
181-
return_annotation_args[1] is Any
182-
or return_annotation_args[1] is str
183-
or return_annotation_args[1] is TaskOutput
184-
or return_annotation_args[1] == Union[str, TaskOutput]
185-
)
175+
# Import GuardrailResult for checking
176+
from ..guardrails import GuardrailResult
177+
178+
# Check if it's a GuardrailResult type
179+
if return_annotation is GuardrailResult or (
180+
hasattr(return_annotation, '__name__') and
181+
return_annotation.__name__ == 'GuardrailResult'
186182
):
187-
raise ValueError(
188-
"If return type is annotated, it must be Tuple[bool, Any]"
189-
)
183+
# Valid GuardrailResult return type
184+
pass
185+
else:
186+
# Check for tuple return type
187+
return_annotation_args = get_args(return_annotation)
188+
if not (
189+
get_origin(return_annotation) is tuple
190+
and len(return_annotation_args) == 2
191+
and return_annotation_args[0] is bool
192+
and (
193+
return_annotation_args[1] is Any
194+
or return_annotation_args[1] is str
195+
or return_annotation_args[1] is TaskOutput
196+
or return_annotation_args[1] == Union[str, TaskOutput]
197+
)
198+
):
199+
raise ValueError(
200+
"If return type is annotated, it must be GuardrailResult or Tuple[bool, Any]"
201+
)
190202

191203
self._guardrail_fn = self.guardrail
192204
elif isinstance(self.guardrail, str):
@@ -447,7 +459,11 @@ def _process_guardrail(self, task_output: TaskOutput):
447459
# Call the guardrail function
448460
result = self._guardrail_fn(task_output)
449461

450-
# Convert the result to a GuardrailResult
462+
# Check if result is already a GuardrailResult
463+
if isinstance(result, GuardrailResult):
464+
return result
465+
466+
# Otherwise, convert the tuple result to a GuardrailResult
451467
return GuardrailResult.from_tuple(result)
452468

453469
except Exception as e:

0 commit comments

Comments
 (0)