Skip to content

Commit d2c3e61

Browse files
committed
Agent message set, classifiers etc.
1 parent bb10b79 commit d2c3e61

File tree

9 files changed

+699
-101
lines changed

9 files changed

+699
-101
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package dev.cel.tools.ai;
2+
3+
import dev.cel.expr.ai.Finding;
4+
import java.util.List;
5+
import java.util.Optional;
6+
7+
/**
8+
* Interface for providing content classifiers.
9+
*/
10+
public interface AgentClassifier {
11+
/**
12+
* Classifies the given input and returns a list of findings.
13+
*
14+
* @param input the input object (e.g., AgentContext, AgentMessage, ToolCall)
15+
* @param label the classification label to match (or "*" for all)
16+
*/
17+
Optional<List<Finding>> classify(Object input, String label);
18+
19+
/** A default classifier that returns no findings. */
20+
AgentClassifier DEFAULT = (input, label) -> Optional.empty();
21+
}
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package dev.cel.tools.ai;
2+
3+
import com.google.auto.value.AutoValue;
4+
import com.google.protobuf.Timestamp;
5+
import dev.cel.expr.ai.AgentContext;
6+
import dev.cel.expr.ai.AgentContextExtensions;
7+
import dev.cel.expr.ai.AgentMessage;
8+
import dev.cel.expr.ai.AgentMessage.Part;
9+
import java.time.Instant;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.Optional;
13+
14+
/**
15+
* AgentMessageSet value which represents a filtered set of agent messages.
16+
*/
17+
@AutoValue
18+
abstract class AgentMessageSet {
19+
20+
/**
21+
* Underlying {@link AgentContext} containing the message history.
22+
*/
23+
abstract AgentContext context();
24+
25+
/** Returns the role to filter by, if present. */
26+
abstract Optional<String> role();
27+
28+
/** Returns the tool call name to filter by, if present. */
29+
abstract Optional<String> toolCallName();
30+
31+
/** Returns the result type (MIME type) to filter by, if present. */
32+
abstract Optional<String> resultType();
33+
34+
/**
35+
* Returns the exclusive upper bound timestamp for filtering messages, if
36+
* present.
37+
*/
38+
abstract Optional<Instant> before();
39+
40+
/**
41+
* Returns the exclusive lower bound timestamp for filtering messages, if
42+
* present.
43+
*/
44+
abstract Optional<Instant> after();
45+
46+
/** Returns true if only keys (prompts) should be included, false otherwise. */
47+
abstract boolean prompts();
48+
49+
/**
50+
* Creates a new {@link AgentMessageSet} from the given {@link AgentContext}.
51+
*/
52+
static AgentMessageSet of(AgentContext context) {
53+
return builder().setContext(context).setPrompts(false).build();
54+
}
55+
56+
/**
57+
* Creates a new {@link AgentMessageSet} containing a single
58+
* {@link AgentMessage}.
59+
*
60+
* <p>
61+
* This convenience method wraps the message in a new {@link AgentContext}.
62+
*/
63+
static AgentMessageSet of(AgentMessage message) {
64+
AgentContext.Builder contextBuilder = AgentContext.newBuilder();
65+
contextBuilder.addExtension(AgentContextExtensions.agentContextMessageHistory, message);
66+
return of(contextBuilder.build());
67+
}
68+
69+
/** Returns a new {@link Builder} for {@link AgentMessageSet}. */
70+
static Builder builder() {
71+
return new AutoValue_AgentMessageSet.Builder();
72+
}
73+
74+
/**
75+
* Returns a new {@link Builder} initialized with the values of this instance.
76+
*/
77+
abstract Builder toBuilder();
78+
79+
/** Builder for {@link AgentMessageSet}. */
80+
@AutoValue.Builder
81+
abstract static class Builder {
82+
/** Sets the {@link AgentContext}. */
83+
abstract Builder setContext(AgentContext context);
84+
85+
/** Sets the role filter. */
86+
abstract Builder setRole(String role);
87+
88+
/** Sets the tool call name filter. */
89+
abstract Builder setToolCallName(String toolCallName);
90+
91+
/** Sets the result type filter. */
92+
abstract Builder setResultType(String resultType);
93+
94+
/** Sets the before timestamp filter. */
95+
abstract Builder setBefore(Instant before);
96+
97+
/** Sets the after timestamp filter. */
98+
abstract Builder setAfter(Instant after);
99+
100+
/** Sets whether to include prompts only. */
101+
abstract Builder setPrompts(boolean prompts);
102+
103+
/** Builds the {@link AgentMessageSet}. */
104+
abstract AgentMessageSet build();
105+
}
106+
107+
/**
108+
* Returns the filtered messages as an {@link AgentContext} proto.
109+
*
110+
* <p>
111+
* This method applies all configured filters (role, time, tool call, etc.) to
112+
* the messages in
113+
* the underlying context and returns a new {@link AgentContext} with the
114+
* filtered history.
115+
*/
116+
AgentContext filteredContext() {
117+
if (!context().hasExtension(AgentContextExtensions.agentContextMessageHistory)) {
118+
return context();
119+
}
120+
List<AgentMessage> msgs = context().getExtension(AgentContextExtensions.agentContextMessageHistory);
121+
List<AgentMessage> filteredMsgs = new ArrayList<>();
122+
123+
for (AgentMessage msg : msgs) {
124+
if (role().isPresent() && !msg.getRole().equals(role().get())) {
125+
continue;
126+
}
127+
Timestamp msgTime = msg.getTime();
128+
Instant time = Instant.ofEpochSecond(msgTime.getSeconds(), msgTime.getNanos());
129+
130+
if (after().isPresent() && time.isBefore(after().get())) {
131+
continue;
132+
}
133+
if (before().isPresent() && time.isAfter(before().get())) {
134+
continue;
135+
}
136+
137+
List<Part> filteredParts = new ArrayList<>();
138+
for (Part part : msg.getPartsList()) {
139+
if (prompts() && !part.hasPrompt()) {
140+
continue;
141+
}
142+
if (toolCallName().isPresent()) {
143+
if (!part.hasToolCall()) {
144+
continue;
145+
}
146+
if (!part.getToolCall().getName().equals(toolCallName().get())) {
147+
continue;
148+
}
149+
}
150+
if (resultType().isPresent()) {
151+
if (part.hasToolCall() && part.getToolCall().hasResult()) {
152+
if (part.getToolCall().getResult().getMimeType().equals(resultType().get())) {
153+
filteredParts.add(part);
154+
}
155+
} else if (part.hasAttachment()) {
156+
if (part.getAttachment().getMimeType().equals(resultType().get())) {
157+
filteredParts.add(part);
158+
}
159+
}
160+
continue;
161+
}
162+
filteredParts.add(part);
163+
}
164+
165+
if (filteredParts.isEmpty()) {
166+
continue;
167+
}
168+
169+
filteredMsgs.add(msg.toBuilder().clearParts().addAllParts(filteredParts).build());
170+
}
171+
172+
return context().toBuilder()
173+
.setExtension(AgentContextExtensions.agentContextMessageHistory, filteredMsgs)
174+
.build();
175+
}
176+
177+
/** Returns a new {@link AgentMessageSet} filtered by the given role. */
178+
AgentMessageSet filterRole(String role) {
179+
return toBuilder().setRole(role).build();
180+
}
181+
182+
/**
183+
* Returns a new {@link AgentMessageSet} filtered by the given tool call name.
184+
*/
185+
AgentMessageSet filterToolCall(String toolCallName) {
186+
return toBuilder().setToolCallName(toolCallName).build();
187+
}
188+
189+
/**
190+
* Returns a new {@link AgentMessageSet} filtered by the given result type (MIME
191+
* type).
192+
*/
193+
AgentMessageSet filterResultType(String resultType) {
194+
return toBuilder().setResultType(resultType).build();
195+
}
196+
197+
/**
198+
* Returns a new {@link AgentMessageSet} filtered to include messages before the
199+
* given timestamp.
200+
*/
201+
AgentMessageSet filterBefore(Timestamp timestamp) {
202+
return toBuilder()
203+
.setBefore(Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()))
204+
.build();
205+
}
206+
207+
/**
208+
* Returns a new {@link AgentMessageSet} filtered to include messages after the
209+
* given timestamp.
210+
*/
211+
AgentMessageSet filterAfter(Timestamp timestamp) {
212+
return toBuilder()
213+
.setAfter(Instant.ofEpochSecond(timestamp.getSeconds(), timestamp.getNanos()))
214+
.build();
215+
}
216+
217+
/**
218+
* Returns a new {@link AgentMessageSet} filtered to include only prompts (keys)
219+
* if true.
220+
*/
221+
AgentMessageSet filterPrompts(boolean prompts) {
222+
return toBuilder().setPrompts(prompts).build();
223+
}
224+
}

0 commit comments

Comments
 (0)