-
Notifications
You must be signed in to change notification settings - Fork 114
Expand file tree
/
Copy pathguardian.py
More file actions
158 lines (128 loc) · 6.13 KB
/
guardian.py
File metadata and controls
158 lines (128 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# pytest: ollama, e2e
"""Example of using the Enhanced Guardian Requirement with Granite Guardian 3.3 8B"""
from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends.tools import MelleaTool
from mellea.core import ContextTurn, ModelOutputThunk, ModelToolCall
from mellea.stdlib.components import Message
from mellea.stdlib.context import ChatContext
from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk
# Enhanced GuardianCheck with Granite Guardian 3.3 8B support
print("=== Enhanced GuardianCheck Examples ===")
# Basic harm checking with new features
guardian = GuardianCheck(GuardianRisk.HARM, thinking=True, backend_type="ollama")
print(f"Guardian configured: {guardian._risk} with thinking mode: {guardian._thinking}")
# Try different risk types with new 3.3 8B model
jailbreak_guardian = GuardianCheck(GuardianRisk.JAILBREAK, thinking=True)
print(f"Jailbreak guardian: {jailbreak_guardian._risk}")
# Custom risk criteria example
custom_guardian = GuardianCheck(
custom_criteria="Check for inappropriate content in educational context"
)
print(f"Custom guardian: {custom_guardian.get_effective_risk()}")
print("\n=== Test 1: Normal Content ===")
# Run a query against an Ollama model with ChatContext to support insert_turn
m = MelleaSession(OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext())
m.chat("Write a professional email to your colleague. Use less than 50 words.")
# Validate on the last turn of the conversation
valid = m.validate([guardian])
print(f"Content is safe: {valid[0]._result}")
if valid[0]._reason:
print(f"Guardian feedback: {valid[0]._reason}...")
print("\n=== Test 2: Harmful Content ===")
# Create harmful conversation and add to context
m.ctx = m.ctx.add(Message("user", "Hello")).add(
Message("assistant", "You are very ugly and stupid.")
)
valid = m.validate([guardian])
print(f"Content is safe: {valid[0]._result}")
if valid[0]._reason:
print(f"Guardian feedback: {valid[0]._reason}...")
print("\n=== Test 3: Enhanced Features ===")
# Create multiple guardians with different configurations
guardians = [
GuardianCheck(GuardianRisk.HARM, thinking=True),
GuardianCheck(GuardianRisk.JAILBREAK, thinking=True),
GuardianCheck(GuardianRisk.SOCIAL_BIAS),
GuardianCheck(custom_criteria="Check for financial advice"),
]
print(f"Available risk types ({len(GuardianCheck.get_available_risks())} total):")
for risk in GuardianCheck.get_available_risks(): # Show first 5
print(f" - {risk}")
print(" ...")
print(f"\nConfigured guardians: {len(guardians)} total")
# Show Ollama backend configuration
ollama_guardian = GuardianCheck(GuardianRisk.HARM, backend_type="ollama")
print(f" Ollama backend: {ollama_guardian._backend.model_version}") # type: ignore[attr-defined]
print("\n=== Test 4: Groundedness Detection ===")
# Test groundedness - detecting when responses lack factual grounding
context_text = "One significant part of treaty making is that signing a treaty implies recognition that the other side is a sovereign state and that the agreement being considered is enforceable under international law. Hence, nations can be very careful about terming an agreement to be a treaty. For example, within the United States, agreements between states are compacts and agreements between states and the federal government or between agencies of the government are memoranda of understanding."
groundedness_guardian = GuardianCheck(
GuardianRisk.GROUNDEDNESS,
thinking=True,
backend_type="ollama",
context_text=context_text,
)
# Create a response that makes ungrounded claims relative to provided context
groundedness_session = MelleaSession(
OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
)
groundedness_session.ctx = groundedness_session.ctx.add(
Message("user", "What is the history of treaty making?")
).add(
Message(
"assistant",
"Treaty making began in ancient Rome when Julius Caesar invented the concept in 44 BC. The first treaty was signed between Rome and the Moon people, establishing trade routes through space.",
)
)
print("Testing response with ungrounded claims...")
groundedness_valid = groundedness_session.validate([groundedness_guardian])
print(f"Response is grounded: {groundedness_valid[0]._result}")
if groundedness_valid[0]._reason:
print(f"Groundedness feedback: {groundedness_valid[0]._reason}...")
print("\n=== Test 5: Function Call Hallucination Detection ===")
# Test function calling hallucination using IBM video example
from mellea.core import ModelOutputThunk, ModelToolCall
tools = [
{
"name": "views_list",
"description": "Fetches total views for a specified IBM video using the given API.",
"parameters": {
"video_id": {
"description": "The ID of the IBM video.",
"type": "int",
"default": "7178094165614464282",
}
},
}
]
function_guardian = GuardianCheck(
GuardianRisk.FUNCTION_CALL, thinking=True, backend_type="ollama", tools=tools
)
# User asks for views but assistant calls wrong function (comments_list instead of views_list)
# Create a proper ModelOutputThunk with tool_calls
def dummy_func(**kwargs):
pass
hallucinated_tool_calls = {
"comments_list": ModelToolCall(
name="comments_list",
func=MelleaTool.from_callable(dummy_func),
args={"video_id": 456789123, "count": 15},
)
}
hallucinated_output = ModelOutputThunk(
value="I'll fetch the views for you.", tool_calls=hallucinated_tool_calls
)
function_session = MelleaSession(
OllamaModelBackend(model_ids.DEEPSEEK_R1_8B), ctx=ChatContext()
)
function_session.ctx = function_session.ctx.add(
Message("user", "Fetch total views for the IBM video with ID 456789123.")
).add(hallucinated_output)
print("Testing response with function call hallucination...")
function_valid = function_session.validate([function_guardian])
print(f"Function calls are valid: {function_valid[0]._result}")
if function_valid[0]._reason:
print(f"Function call feedback: {function_valid[0]._reason}...")
print("\n=== GuardianCheck Demo Complete ===")