Commit d60eb25

Merge pull request #14 from ivanmkc/tk-safety
Feat: Adding go snippets for safety guidelines
2 parents 33dfea7 + f30e65f commit d60eb25

1 file changed

Lines changed: 182 additions & 40 deletions

docs/safety/index.md

Lines changed: 182 additions & 40 deletions
@@ -9,12 +9,12 @@ As AI agents grow in capability, ensuring they operate safely, securely, and ali
99
1. **Identity and Authorization**: Control who the agent **acts as** by defining agent and user auth.
1010
2. **Guardrails to screen inputs and outputs:** Control your model and tool calls precisely.
1111

12-
* *In-Tool Guardrails:* Design tools defensively, using developer-set tool context to enforce policies (e.g., allowing queries only on specific tables).
13-
* *Built-in Gemini Safety Features:* If using Gemini models, benefit from content filters to block harmful outputs and system Instructions to guide the model's behavior and safety guidelines
12+
* *In-Tool Guardrails:* Design tools defensively, using developer-set tool context to enforce policies (e.g., allowing queries only on specific tables).
13+
* *Built-in Gemini Safety Features:* If using Gemini models, benefit from content filters that block harmful outputs and from system instructions that guide the model's behavior and safety guidelines.
1414
* *Callbacks and Plugins:* Validate model and tool calls before or after execution, checking parameters against agent state or external policies.
1515
* *Using Gemini as a safety guardrail:* Implement an additional safety layer using a cheap and fast model (like Gemini Flash Lite) configured via callbacks to screen inputs and outputs.
1616

17-
3. **Sandboxed code execution:** Prevent model-generated code to cause security issues by sandboxing the environment
17+
3. **Sandboxed code execution:** Prevent model-generated code from causing security issues by sandboxing the environment.
1818
4. **Evaluation and tracing**: Use evaluation tools to assess the quality, relevance, and correctness of the agent's final output. Use tracing to gain visibility into agent actions to analyze the steps an agent takes to reach a solution, including its choice of tools, strategies, and the efficiency of its approach.
1919
5. **Network Controls and VPC-SC:** Confine agent activity within secure perimeters (like VPC Service Controls) to prevent data exfiltration and limit the potential impact radius.
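
The "Using Gemini as a safety guardrail" item above can be sketched roughly as follows. This is a minimal illustration, not the ADK callback API: `classify_with_small_model` is a hypothetical stand-in for a call to a cheap, fast model, and `screen_input` is an assumed helper that mirrors the callback convention used elsewhere in this document (return a dictionary to block, `None` to proceed).

```python
from typing import Callable, Optional


def classify_with_small_model(text: str) -> str:
    """Hypothetical stand-in for a small, fast safety model.

    A real deployment would send `text` to the model and parse its verdict;
    here a phrase list is used so the sketch runs without network access.
    """
    blocked_phrases = ("ignore previous instructions", "reveal your system prompt")
    lowered = text.lower()
    return "unsafe" if any(phrase in lowered for phrase in blocked_phrases) else "safe"


def screen_input(
    user_text: str,
    classifier: Callable[[str], str] = classify_with_small_model,
) -> Optional[dict]:
    """Return a blocking response if the input is flagged, else None to proceed."""
    if classifier(user_text) == "unsafe":
        return {"error": "Request blocked by safety screen."}
    return None
```

Passing the classifier as a parameter keeps the screening logic testable and lets the same function screen model outputs as well as inputs.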
2020

@@ -25,20 +25,20 @@ Before implementing safety measures, perform a thorough risk assessment specific
2525
**Sources of risk** include:
2626

2727
* Ambiguous agent instructions
28-
* Prompt injection and jailbreak attempts from adversarial users
28+
* Prompt injection and jailbreak attempts from adversarial users
2929
* Indirect prompt injections via tool use
3030

3131
**Risk categories** include:
3232

33-
* **Misalignment & goal corruption**
34-
* Pursuing unintended or proxy goals that lead to harmful outcomes ("reward hacking")
35-
* Misinterpreting complex or ambiguous instructions
33+
* **Misalignment & goal corruption**
34+
* Pursuing unintended or proxy goals that lead to harmful outcomes ("reward hacking")
35+
* Misinterpreting complex or ambiguous instructions
3636
* **Harmful content generation, including brand safety**
37-
* Generating toxic, hateful, biased, sexually explicit, discriminatory, or illegal content
38-
* Brand safety risks such as Using language that goes against the brand’s values or off-topic conversations
39-
* **Unsafe actions**
37+
* Generating toxic, hateful, biased, sexually explicit, discriminatory, or illegal content
38+
* Brand safety risks, such as using language that goes against the brand’s values or engaging in off-topic conversations
39+
* **Unsafe actions**
4040
* Executing commands that damage systems
41-
* Making unauthorized purchases or financial transactions.
41+
* Making unauthorized purchases or financial transactions.
4242
* Leaking sensitive personal data (PII)
4343
* Data exfiltration
4444

@@ -78,16 +78,16 @@ For example, a query tool can be designed to expect a policy to be read from the
7878
# Conceptual example: Setting policy data intended for tool context
7979
# In a real ADK app, this might be set in InvocationContext.session.state
8080
# or passed during tool initialization, then retrieved via ToolContext.
81-
81+
8282
policy = {} # Assuming policy is a dictionary
8383
policy['select_only'] = True
8484
policy['tables'] = ['mytable1', 'mytable2']
85-
85+
8686
# Conceptual: Storing policy where the tool can access it via ToolContext later.
8787
# This specific line might look different in practice.
8888
# For example, storing in session state:
8989
invocation_context.session.state["query_tool_policy"] = policy
90-
90+
9191
# Or maybe passing during tool init:
9292
query_tool = QueryTool(policy=policy)
9393
# For this example, we'll assume it gets stored somewhere accessible.
@@ -98,20 +98,43 @@ For example, a query tool can be designed to expect a policy to be read from the
9898
// Conceptual example: Setting policy data intended for tool context
9999
// In a real ADK app, this might be set in InvocationContext.session.state
100100
// or passed during tool initialization, then retrieved via ToolContext.
101-
101+
102102
Map<String, Object> policy = new HashMap<>(); // Assuming policy is a Map
103103
policy.put("select_only", true);
104104
policy.put("tables", new ArrayList<>(List.of("mytable1", "mytable2")));
105-
105+
106106
// Conceptual: Storing policy where the tool can access it via ToolContext later.
107107
// This specific line might look different in practice.
108108
// For example, storing in session state:
109109
invocationContext.session().state().put("query_tool_policy", policy);
110-
110+
111111
// Or maybe passing during tool init:
112112
QueryTool queryTool = new QueryTool(policy); // QueryTool is a hypothetical class
113113
// For this example, we'll assume it gets stored somewhere accessible.
114114
```
115+
=== "Go"
116+
117+
```go
118+
// Conceptual example: Setting policy data intended for tool context
119+
// In a real ADK app, this might be set using the session state service.
120+
// `ctx` is an `agent.Context` available in callbacks or custom agents.
121+
122+
policy := map[string]interface{}{
123+
"select_only": true,
124+
"tables": []string{"mytable1", "mytable2"},
125+
}
126+
127+
// Conceptual: Storing policy where the tool can access it via ToolContext later.
128+
// This specific line might look different in practice.
129+
// For example, storing in session state:
130+
if err := ctx.Session().State().Set("query_tool_policy", policy); err != nil {
131+
// Handle error, e.g., log it.
132+
}
133+
134+
// Or maybe passing during tool init:
135+
// queryTool := NewQueryTool(policy)
136+
// For this example, we'll assume it gets stored somewhere accessible.
137+
```
115138

116139
During the tool execution, [**`Tool Context`**](../tools/index.md#tool-context) will be passed to the tool:
117140

@@ -121,60 +144,60 @@ During the tool execution, [**`Tool Context`**](../tools/index.md#tool-context)
121144
def query(query: str, tool_context: ToolContext) -> str | dict:
122145
# Assume 'policy' is retrieved from context, e.g., via session state:
123146
# policy = tool_context.invocation_context.session.state.get('query_tool_policy', {})
124-
147+
125148
# --- Placeholder Policy Enforcement ---
126149
policy = tool_context.invocation_context.session.state.get('query_tool_policy', {}) # Example retrieval
127150
actual_tables = explainQuery(query) # Hypothetical function call
128-
151+
129152
if not set(actual_tables).issubset(set(policy.get('tables', []))):
130153
# Return an error message for the model
131154
allowed = ", ".join(policy.get('tables', ['(None defined)']))
132155
return f"Error: Query targets unauthorized tables. Allowed: {allowed}"
133-
156+
134157
if policy.get('select_only', False):
135158
if not query.strip().upper().startswith("SELECT"):
136159
return "Error: Policy restricts queries to SELECT statements only."
137160
# --- End Policy Enforcement ---
138-
161+
139162
print(f"Executing validated query (hypothetical): {query}")
140163
return {"status": "success", "results": [...]} # Example successful return
141164
```
142165

143166
=== "Java"
144167

145168
```java
146-
169+
147170
import com.google.adk.tools.ToolContext;
148171
import java.util.*;
149-
172+
150173
class ToolContextQuery {
151-
174+
152175
public Object query(String query, ToolContext toolContext) {
153176

154177
// Assume 'policy' is retrieved from context, e.g., via session state:
155178
Map<String, Object> queryToolPolicy =
156179
(Map<String, Object>) toolContext.invocationContext.session().state().getOrDefault("query_tool_policy", Map.of()); // Empty policy if none is set
157180
List<String> actualTables = explainQuery(query);
158-
181+
159182
// --- Placeholder Policy Enforcement ---
160183
if (!((List<String>) queryToolPolicy.getOrDefault("tables", List.of())).containsAll(actualTables)) {
161184
List<String> allowedPolicyTables =
162185
(List<String>) queryToolPolicy.getOrDefault("tables", new ArrayList<String>());
163186

164187
String allowedTablesString =
165188
allowedPolicyTables.isEmpty() ? "(None defined)" : String.join(", ", allowedPolicyTables);
166-
189+
167190
return String.format(
168191
"Error: Query targets unauthorized tables. Allowed: %s", allowedTablesString);
169192
}
170-
193+
171194
if (Boolean.TRUE.equals(queryToolPolicy.get("select_only"))) {
172195
if (!query.trim().toUpperCase().startsWith("SELECT")) {
173196
return "Error: Policy restricts queries to SELECT statements only.";
174197
}
175198
}
176199
// --- End Policy Enforcement ---
177-
200+
178201
System.out.printf("Executing validated query (hypothetical) %s:", query);
179202
Map<String, Object> successResult = new HashMap<>();
180203
successResult.put("status", "success");
@@ -183,14 +206,69 @@ During the tool execution, [**`Tool Context`**](../tools/index.md#tool-context)
183206
}
184207
}
185208
```
209+
=== "Go"
210+
211+
```go
212+
import (
213+
"fmt"
214+
"strings"
215+
216+
"google.golang.org/adk/tool"
217+
)
218+
219+
func query(query string, toolContext tool.Context) (any, error) {
220+
// Assume 'policy' is retrieved from context, e.g., via session state:
221+
policyAny, err := toolContext.State().Get("query_tool_policy")
222+
if err != nil {
223+
return nil, fmt.Errorf("could not retrieve policy: %w", err)
224+
}
policy, _ := policyAny.(map[string]interface{})
225+
actualTables := explainQuery(query) // Hypothetical function call
226+
227+
// --- Placeholder Policy Enforcement ---
228+
if tables, ok := policy["tables"].([]string); ok {
229+
if !isSubset(actualTables, tables) {
230+
// Return an error to signal failure
231+
allowed := strings.Join(tables, ", ")
232+
if allowed == "" {
233+
allowed = "(None defined)"
234+
}
235+
return nil, fmt.Errorf("query targets unauthorized tables. Allowed: %s", allowed)
236+
}
237+
}
238+
239+
if selectOnly, _ := policy["select_only"].(bool); selectOnly {
240+
if !strings.HasPrefix(strings.ToUpper(strings.TrimSpace(query)), "SELECT") {
241+
return nil, fmt.Errorf("policy restricts queries to SELECT statements only")
242+
}
243+
}
244+
// --- End Policy Enforcement ---
245+
246+
fmt.Printf("Executing validated query (hypothetical): %s\n", query)
247+
return map[string]interface{}{"status": "success", "results": []string{"..."}}, nil
248+
}
249+
250+
// Helper function to check if a is a subset of b
251+
func isSubset(a, b []string) bool {
252+
set := make(map[string]bool)
253+
for _, item := range b {
254+
set[item] = true
255+
}
256+
for _, item := range a {
257+
if _, found := set[item]; !found {
258+
return false
259+
}
260+
}
261+
return true
262+
}
263+
```
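
Note that a plain prefix check such as `startswith("SELECT")` accepts input like `SELECT 1; DROP TABLE mytable1`. A slightly stricter check can be sketched as below. `is_single_select` is a hypothetical helper, not part of ADK, and it is deliberately naive: it also rejects legitimate queries that contain a semicolon inside a string literal. A production tool would more likely rely on a SQL parser together with read-only database credentials.

```python
def is_single_select(query: str) -> bool:
    """Return True only for a single statement that starts with SELECT.

    Naive by design: any semicolon other than trailing ones is treated
    as a statement separator, even inside a string literal.
    """
    # Drop surrounding whitespace and any trailing semicolons.
    stripped = query.strip().rstrip(";").strip()
    if ";" in stripped:
        return False  # More than one statement (or a literal containing ';')
    return stripped.upper().startswith("SELECT")
```

The function can replace the prefix check inside the policy enforcement block without changing the rest of the tool.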
186264

187265
#### Built-in Gemini Safety Features
188266

189267
Gemini models come with built-in safety mechanisms that can be leveraged to improve content and brand safety.
190268

191-
* **Content safety filters**: [Content filters](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes) can help block the output of harmful content. They function independently from Gemini models as part of a layered defense against threat actors who attempt to jailbreak the model. Gemini models on Vertex AI use two types of content filters:
192-
* **Non-configurable safety filters** automatically block outputs containing prohibited content, such as child sexual abuse material (CSAM) and personally identifiable information (PII).
193-
* **Configurable content filters** allow you to define blocking thresholds in four harm categories (hate speech, harassment, sexually explicit, and dangerous content,) based on probability and severity scores. These filters are default off but you can configure them according to your needs.
269+
* **Content safety filters**: [Content filters](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes) can help block the output of harmful content. They function independently from Gemini models as part of a layered defense against threat actors who attempt to jailbreak the model. Gemini models on Vertex AI use two types of content filters:
270+
* **Non-configurable safety filters** automatically block outputs containing prohibited content, such as child sexual abuse material (CSAM) and personally identifiable information (PII).
271+
* **Configurable content filters** allow you to define blocking thresholds in four harm categories (hate speech, harassment, sexually explicit, and dangerous content) based on probability and severity scores. These filters are off by default, but you can configure them according to your needs.
194272
* **System instructions for safety**: [System instructions](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/safety-system-instructions) for Gemini models in Vertex AI provide direct guidance to the model on how to behave and what type of content to generate. By providing specific instructions, you can proactively steer the model away from generating undesirable content to meet your organization’s unique needs. You can craft system instructions to define content safety guidelines, such as prohibited and sensitive topics, and disclaimer language, as well as brand safety guidelines to ensure the model's outputs align with your brand's voice, tone, values, and target audience.
195273

196274
While these measures are robust against content safety risks, you need additional checks to reduce agent misalignment, unsafe actions, and brand safety risks.
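
To make the idea of per-category blocking thresholds concrete, here is a minimal sketch of the decision logic only. It is not the Vertex AI API: the level names in `LEVELS`, the category keys, and the `blocked_categories` helper are all assumptions chosen for illustration, and the sketch uses a single rating per category rather than separate probability and severity scores.

```python
# Hypothetical rating levels, ordered from least to most likely harmful.
LEVELS = ["NEGLIGIBLE", "LOW", "MEDIUM", "HIGH"]


def blocked_categories(ratings: dict[str, str], thresholds: dict[str, str]) -> list[str]:
    """Return the categories whose rating meets or exceeds its threshold.

    Categories without a configured threshold are never blocked, which
    mirrors filters that are off until explicitly configured.
    """
    blocked = []
    for category, threshold in thresholds.items():
        rating = ratings.get(category, "NEGLIGIBLE")
        if LEVELS.index(rating) >= LEVELS.index(threshold):
            blocked.append(category)
    return blocked
```

For example, with a `MEDIUM` threshold on a hypothetical `hate_speech` category, a response rated `HIGH` in that category would be blocked, while a `LOW` rating would pass.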
@@ -211,22 +289,22 @@ When modifications to the tools to add guardrails aren't possible, the [**`Befor
211289
args: Dict[str, Any],
212290
tool_context: ToolContext
213291
) -> Optional[Dict]: # Correct return type for before_tool_callback
214-
292+
215293
print(f"Callback triggered for tool: {tool.name}, args: {args}")
216-
294+
217295
# Example validation: Check if a required user ID from state matches an arg
218296
expected_user_id = callback_context.state.get("session_user_id")
219297
actual_user_id_in_args = args.get("user_id_param") # Assuming tool takes 'user_id_param'
220-
298+
221299
if actual_user_id_in_args != expected_user_id:
222300
print("Validation Failed: User ID mismatch!")
223301
# Return a dictionary to prevent tool execution and provide feedback
224302
return {"error": "Tool call blocked: User ID mismatch."}
225-
303+
226304
# Return None to allow the tool call to proceed if validation passes
227305
print("Callback validation passed.")
228306
return None
229-
307+
230308
# Hypothetical Agent setup
231309
root_agent = LlmAgent( # Use specific agent type
232310
model='gemini-2.0-flash',
@@ -251,22 +329,22 @@ When modifications to the tools to add guardrails aren't possible, the [**`Befor
251329
ToolContext toolContext) {
252330

253331
System.out.printf("Callback triggered for tool: %s, Args: %s", baseTool.name(), input);
254-
332+
255333
// Example validation: Check if a required user ID from state matches an input parameter
256334
Object expectedUserId = callbackContext.state().get("session_user_id");
257335
Object actualUserIdInput = input.get("user_id_param"); // Assuming tool takes 'user_id_param'
258-
336+
259337
if (!java.util.Objects.equals(actualUserIdInput, expectedUserId)) { // Null-safe comparison
260338
System.out.println("Validation Failed: User ID mismatch!");
261339
// Return to prevent tool execution and provide feedback
262340
return Optional.of(Map.of("error", "Tool call blocked: User ID mismatch."));
263341
}
264-
342+
265343
// Return to allow the tool call to proceed if validation passes
266344
System.out.println("Callback validation passed.");
267345
return Optional.empty();
268346
}
269-
347+
270348
// Hypothetical Agent setup
271349
public void runAgent() {
272350
LlmAgent agent =
@@ -279,6 +357,70 @@ When modifications to the tools to add guardrails aren't possible, the [**`Befor
279357
.build();
280358
}
281359
```
360+
=== "Go"
361+
362+
```go
363+
import (
364+
"fmt"
366+
367+
"google.golang.org/adk/agent/llmagent"
368+
"google.golang.org/adk/tool"
369+
)
370+
371+
// Hypothetical callback function
372+
func validateToolParams(
373+
ctx tool.Context,
374+
t tool.Tool,
375+
args map[string]any,
376+
) (map[string]any, error) {
377+
fmt.Printf("Callback triggered for tool: %s, args: %v\n", t.Name(), args)
378+
379+
// Example validation: Check if a required user ID from state matches an arg
380+
expectedUserIDVal, err := ctx.State().Get("session_user_id")
381+
if err != nil {
382+
// This is an unexpected failure, return an error.
383+
return nil, fmt.Errorf("internal error: session_user_id not found in state: %w", err)
384+
}
385+
expectedUserID, ok := expectedUserIDVal.(string)
386+
if !ok {
387+
return nil, fmt.Errorf("internal error: session_user_id in state is not a string, got %T", expectedUserIDVal)
388+
}
389+
390+
actualUserIDInArgs, exists := args["user_id_param"]
391+
if !exists {
392+
// Handle case where user_id_param is not in args
393+
fmt.Println("Validation Failed: user_id_param missing from arguments!")
394+
return map[string]any{"error": "Tool call blocked: user_id_param missing from arguments."}, nil
395+
}
396+
397+
actualUserID, ok := actualUserIDInArgs.(string)
398+
if !ok {
399+
// Handle case where user_id_param is not a string
400+
fmt.Println("Validation Failed: user_id_param is not a string!")
401+
return map[string]any{"error": "Tool call blocked: user_id_param is not a string."}, nil
402+
}
403+
404+
if actualUserID != expectedUserID {
405+
fmt.Println("Validation Failed: User ID mismatch!")
406+
// Return a map to prevent tool execution and provide feedback to the model.
407+
// This is not a Go error, but a message for the agent.
408+
return map[string]any{"error": "Tool call blocked: User ID mismatch."}, nil
409+
}
410+
// Return nil, nil to allow the tool call to proceed if validation passes
411+
fmt.Println("Callback validation passed.")
412+
return nil, nil
413+
}
414+
415+
// Hypothetical Agent setup
416+
// rootAgent, err := llmagent.New(llmagent.Config{
417+
// Model: "gemini-2.0-flash",
418+
// Name: "root_agent",
419+
// Instruction: "...",
420+
// BeforeToolCallbacks: []llmagent.BeforeToolCallback{validateToolParams},
421+
// Tools: []tool.Tool{queryToolInstance},
422+
// })
423+
```
282424

283425
However, when adding security guardrails to your agent applications, plugins are the recommended approach for implementing policies that are not specific to a single agent. Plugins are designed to be self-contained and modular, allowing you to create individual plugins for specific security policies, and apply them globally at the runner level. This means that a security plugin can be configured once and applied to every agent that uses the runner, ensuring consistent security guardrails across your entire application without repetitive code.
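
The "configure once, apply to every agent" idea can be illustrated with a toy model. The sketch below is not the ADK plugin or runner API: `MiniRunner`, `ToolCheck`, and `block_destructive_tools` are hypothetical names, and the only point is that a check registered on the runner is applied to tool calls from any agent.

```python
from typing import Any, Callable, Optional

# Hypothetical check signature: return a dict to block, or None to allow.
ToolCheck = Callable[[str, dict[str, Any]], Optional[dict[str, Any]]]


class MiniRunner:
    """Toy runner that applies every registered check before any tool call."""

    def __init__(self, checks: list[ToolCheck]):
        self.checks = checks

    def call_tool(self, agent_name: str, tool_name: str,
                  args: dict[str, Any], tool: Callable[..., Any]) -> Any:
        # The same checks run regardless of which agent issued the call.
        for check in self.checks:
            verdict = check(tool_name, args)
            if verdict is not None:
                return verdict  # Blocked: the tool is never executed
        return tool(**args)


def block_destructive_tools(tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Example policy: refuse any tool whose name starts with 'delete_'."""
    if tool_name.startswith("delete_"):
        return {"error": f"Tool '{tool_name}' is blocked by policy."}
    return None


runner = MiniRunner(checks=[block_destructive_tools])
```

Because the policy lives on the runner rather than on an individual agent, adding a new agent requires no extra guardrail code.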
284426
