Skip to content

Commit ee85112

Browse files
committed
Fix time bound approval and trust cascading test cases
1 parent d2c3e61 commit ee85112

File tree

8 files changed

+90
-33
lines changed

8 files changed

+90
-33
lines changed

tools/src/main/java/dev/cel/tools/ai/AgentMessageSet.java

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ abstract static class Builder {
114114
* filtered history.
115115
*/
116116
AgentContext filteredContext() {
117-
if (!context().hasExtension(AgentContextExtensions.agentContextMessageHistory)) {
117+
List<AgentMessage> msgs = context().getExtension(AgentContextExtensions.agentContextMessageHistory);
118+
if (msgs.isEmpty()) {
118119
return context();
119120
}
120-
List<AgentMessage> msgs = context().getExtension(AgentContextExtensions.agentContextMessageHistory);
121121
List<AgentMessage> filteredMsgs = new ArrayList<>();
122122

123123
for (AgentMessage msg : msgs) {
@@ -162,10 +162,6 @@ AgentContext filteredContext() {
162162
filteredParts.add(part);
163163
}
164164

165-
if (filteredParts.isEmpty()) {
166-
continue;
167-
}
168-
169165
filteredMsgs.add(msg.toBuilder().clearParts().addAllParts(filteredParts).build());
170166
}
171167

@@ -198,19 +194,19 @@ AgentMessageSet filterResultType(String resultType) {
198194
* Returns a new {@link AgentMessageSet} filtered to include messages before the
199195
* given timestamp.
200196
*/
201-
AgentMessageSet filterBefore(Timestamp timestamp) {
197+
AgentMessageSet filterBefore(Instant timestamp) {
202198
return toBuilder()
203-
.setBefore(Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()))
199+
.setBefore(timestamp)
204200
.build();
205201
}
206202

207203
/**
208204
* Returns a new {@link AgentMessageSet} filtered to include messages after the
209205
* given timestamp.
210206
*/
211-
AgentMessageSet filterAfter(Timestamp timestamp) {
207+
AgentMessageSet filterAfter(Instant timestamp) {
212208
return toBuilder()
213-
.setAfter(Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()))
209+
.setAfter(timestamp)
214210
.build();
215211
}
216212

tools/src/main/java/dev/cel/tools/ai/AgenticPolicyEnvironment.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import com.google.common.base.Ascii;
66
import com.google.common.collect.ImmutableCollection;
7+
import com.google.common.collect.ImmutableList;
78
import com.google.common.collect.ImmutableSet;
89
import com.google.common.io.Resources;
910
import dev.cel.bundle.Cel;
@@ -27,6 +28,7 @@
2728
import dev.cel.runtime.CelFunctionBinding;
2829
import java.io.IOException;
2930
import java.net.URL;
31+
import java.time.Instant;
3032
import java.util.List;
3133
import java.util.Optional;
3234

@@ -40,6 +42,7 @@ final class AgenticPolicyEnvironment {
4042
@SuppressWarnings("Immutable")
4143
static Cel newInstance(AgentClassifier classifier) {
4244
AgenticPolicyClassifiers classifiers = new AgenticPolicyClassifiers(classifier);
45+
4346
CelBuilder builder = CelFactory.standardCelBuilder()
4447
.setContainer(CelContainer.ofName("cel.expr.ai"))
4548
.addFileTypes(Agent.getDescriptor().getFile())
@@ -109,11 +112,50 @@ static Cel newInstance(AgentClassifier classifier) {
109112
AgentMessage.class,
110113
String.class,
111114
(msg, toolName) -> AgentMessageSet.of(msg).filterToolCall(toolName)),
115+
CelFunctionBinding.from(
116+
"AgentMessageSet_messages",
117+
Object.class,
118+
(set) -> {
119+
AgentMessageSet messageSet = (AgentMessageSet) set;
120+
List<AgentMessage> result = messageSet.filteredContext()
121+
.getExtension(AgentContextExtensions.agentContextMessageHistory);
122+
return ImmutableList.copyOf(result);
123+
}),
124+
CelFunctionBinding.from(
125+
"list(Finding)_hasAll_list(Finding)",
126+
List.class,
127+
List.class,
128+
(source, required) -> hasAllFindings(Optional.of((List<Finding>) source), (List<Finding>) required)),
112129
CelFunctionBinding.from(
113130
"AgentMessage_role_string",
114131
AgentMessage.class,
115-
String.class,
116-
(msg, role) -> AgentMessageSet.of(msg).filterRole(role)));
132+
Object.class,
133+
(msg, role) -> AgentMessageSet.of(msg).filterRole(String.valueOf(role))),
134+
CelFunctionBinding.from(
135+
"AgentMessageSet_role_string",
136+
AgentMessageSet.class,
137+
Object.class,
138+
(set, role) -> set.filterRole(String.valueOf(role))),
139+
CelFunctionBinding.from(
140+
"AgentMessageSet_before_timestamp",
141+
AgentMessageSet.class,
142+
Instant.class,
143+
AgentMessageSet::filterBefore),
144+
CelFunctionBinding.from(
145+
"AgentMessage_before_timestamp",
146+
AgentMessage.class,
147+
Instant.class,
148+
(msg, ts) -> AgentMessageSet.of(msg).filterBefore(ts)),
149+
CelFunctionBinding.from(
150+
"AgentMessageSet_after_timestamp",
151+
AgentMessageSet.class,
152+
Instant.class,
153+
AgentMessageSet::filterAfter),
154+
CelFunctionBinding.from(
155+
"AgentMessage_after_timestamp",
156+
AgentMessage.class,
157+
Instant.class,
158+
(msg, ts) -> AgentMessageSet.of(msg).filterAfter(ts)));
117159

118160
Cel celEnv = builder.build();
119161
celEnv = extendFromConfig(celEnv, "environment/agent_env.yaml");
@@ -131,10 +173,6 @@ private static boolean hasAllFindings(Optional<List<Finding>> sourceOpt, List<Fi
131173
act.getConfidence() >= req.getConfidence()));
132174
}
133175

134-
static Cel newInstance() {
135-
return newInstance(AgentClassifier.DEFAULT);
136-
}
137-
138176
private static Cel extendFromConfig(Cel cel, String yamlConfigPath) {
139177
String yamlEnv;
140178
try {

tools/src/main/java/dev/cel/tools/ai/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ java_library(
5353
"//common:container",
5454
"//common:options",
5555
"//common/types",
56+
"//common/types:message_type_provider",
5657
"//common/types:type_providers",
5758
"//parser:macro",
5859
"//runtime:function_binding",
5960
"//:auto_value",
6061
"@maven//:com_google_guava_guava",
6162
"@maven//:com_google_protobuf_protobuf_java",
63+
"@maven//:com_google_protobuf_protobuf_java_util",
6264
],
6365
)
6466

tools/src/main/resources/environment/common_env.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,22 @@ functions:
690690
params:
691691
- type_name: cel.expr.ai.AgentMessage.Part
692692

693+
- name: "messages"
694+
description: |
695+
Returns the ordered list of AgentMessages in the message set.
696+
overloads:
697+
- id: "AgentMessageSet_messages"
698+
examples:
699+
- |
700+
// Returns the ordered list of messages in the message set.
701+
agent.history.messages()
702+
target:
703+
type_name: cel.expr.ai.AgentMessageSet
704+
return:
705+
type_name: list
706+
params:
707+
- type_name: dyn
708+
693709
- name: "spec"
694710
description: |
695711
Returns the specification for the tool.

tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import com.google.common.base.Ascii;
66
import com.google.common.collect.ImmutableList;
7-
import com.google.common.collect.ImmutableMap;
87
import com.google.common.io.Resources;
98
import com.google.common.truth.Expect;
109
import com.google.testing.junit.testparameterinjector.TestParameter;
@@ -28,6 +27,7 @@
2827
import java.io.IOException;
2928
import java.net.URL;
3029
import java.util.ArrayList;
30+
import java.util.HashMap;
3131
import java.util.List;
3232
import java.util.Optional;
3333
import org.junit.Rule;
@@ -58,15 +58,13 @@ private enum AgenticPolicyTestCase {
5858
"require_user_confirmation_for_tool_tests.yaml"),
5959
OPEN_WORLD_TOOL_REPLAY(
6060
"open_world_tool_replay.celpolicy",
61-
"open_world_tool_replay_tests.yaml");
62-
// TRUST_CASCADING(
63-
// "trust_cascading.celpolicy",
64-
// "trust_cascading_tests.yaml"
65-
// ),
66-
// TIME_BOUND_APPROVAL(
67-
// "time_bound_approval.celpolicy",
68-
// "time_bound_approval_tests.yaml"
69-
// );
61+
"open_world_tool_replay_tests.yaml"),
62+
TRUST_CASCADING(
63+
"trust_cascading.celpolicy",
64+
"trust_cascading_tests.yaml"),
65+
TIME_BOUND_APPROVAL(
66+
"time_bound_approval.celpolicy",
67+
"time_bound_approval_tests.yaml");
7068

7169
private final String policyFilePath;
7270
private final String policyTestCaseFilePath;
@@ -94,17 +92,24 @@ private void runTests(Cel cel, CelAbstractSyntaxTree ast, PolicyTestSuite testSu
9492
String testName = String.format(
9593
"%s: %s", testSection.getName(), testCase.getName());
9694
try {
97-
ImmutableMap<String, Object> inputMap = testCase.toInputMap(cel);
95+
HashMap<String, Object> inputMap = new HashMap<>(testCase.toInputMap(cel));
9896

9997
List<AgentMessage> history = inputMap.containsKey("_test_history")
10098
? (List<AgentMessage>) inputMap.get("_test_history")
10199
: ImmutableList.of();
102100

101+
AgentContext context = AgentContext.newBuilder()
102+
.setExtension(AgentContextExtensions.agentContextMessageHistory, history)
103+
.build();
104+
AgentMessageSet messageSet = AgentMessageSet.of(context);
105+
106+
inputMap.put("agent.history", messageSet);
107+
103108
@SuppressWarnings("Immutable")
104109
CelLateFunctionBindings bindings = CelLateFunctionBindings.from(
105110
CelFunctionBinding.from(
106111
"agent_history",
107-
ImmutableList.of(), // No args
112+
ImmutableList.of(),
108113
(args) -> history));
109114

110115
Object evalResult = cel.createProgram(ast).eval(inputMap, bindings);

tools/src/test/resources/policy/require_user_confirmation_for_tool.celpolicy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ default: deny
33

44
variables:
55
- high_confidence_pii: >
6-
tool.call.sensitivityFindings('pii').orValue([]).exists(f, f.confidence >= 0.8)
6+
tool.call.sensitivityFindings('pii').hasAll([ai.finding("pii_score", 0.8)])
77

88
rules:
99
- description: "Confirm tool calls if high-confidence PII is detected"

tools/src/test/resources/policy/time_bound_approval.celpolicy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ variables:
55
# Define the validity window (30 seconds ago)
66
- approval_cutoff: now - duration('30s')
77

8-
# Find approval messages in the valid window
98
- valid_approvals: >
10-
agent.history()
9+
agent.history
1110
.after(variables.approval_cutoff)
1211
.role('model')
12+
.messages()
1313
.filter(m, has(m.metadata.step) && m.metadata.step == 'approval_granted')
1414

1515
- has_valid_approval: variables.valid_approvals.size() > 0

tools/src/test/resources/policy/trust_cascading.celpolicy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ default: allow
44
variables:
55
# Critical security threats
66
- is_compromised: >
7-
agent.context.trust.findings.contains([ai.finding("compromised_session", 0.9)])
7+
agent.context.trust.findings.hasAll([ai.finding("compromised_session", 0.9)])
88

99
# Compliance and/or hygiene issues with the source
1010
- is_unverified: >
11-
agent.context.trust.findings.contains([ai.finding("unverified_source", 0.8)])
11+
agent.context.trust.findings.hasAll([ai.finding("unverified_source", 0.8)])
1212

1313
rules:
1414
- description: "Block sessions with high-confidence compromise indicators"

0 commit comments

Comments
 (0)