Skip to content

Commit e12baa2

Browse files
kvmiloscopybara-github
authored andcommitted
perf: filter session events server-side by afterTimestamp in VertexAiSessionService.getSession
PiperOrigin-RevId: 932376696
1 parent 987ef4e commit e12baa2

4 files changed

Lines changed: 95 additions & 35 deletions

File tree

core/src/main/java/com/google/adk/sessions/VertexAiClient.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.google.auth.oauth2.GoogleCredentials;
99
import com.google.common.base.Splitter;
1010
import com.google.common.collect.Iterables;
11+
import com.google.common.net.UrlEscapers;
1112
import com.google.genai.types.HttpOptions;
1213
import io.reactivex.rxjava3.core.Completable;
1314
import io.reactivex.rxjava3.core.Maybe;
@@ -111,11 +112,12 @@ Maybe<JsonNode> listSessions(String reasoningEngineId, String userId) {
111112
.flatMapMaybe(VertexAiClient::getJsonResponse);
112113
}
113114

114-
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
115-
return performApiRequest(
116-
"GET",
117-
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events",
118-
"")
115+
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId, @Nullable String filter) {
116+
String path = "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events";
117+
if (filter != null) {
118+
path += "?filter=" + UrlEscapers.urlFormParameterEscaper().escape(filter);
119+
}
120+
return performApiRequest("GET", path, "")
119121
.doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse))
120122
.flatMapMaybe(VertexAiClient::getJsonResponse);
121123
}

core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,14 @@ private ListSessionsResponse parseListSessionsResponse(
164164

165165
@Override
166166
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
167+
return listEventsInternal(appName, sessionId, /* filter= */ null);
168+
}
169+
170+
private Single<ListEventsResponse> listEventsInternal(
171+
String appName, String sessionId, @Nullable String filter) {
167172
String reasoningEngineId = parseReasoningEngineId(appName);
168173
return client
169-
.listEvents(reasoningEngineId, sessionId)
174+
.listEvents(reasoningEngineId, sessionId, filter)
170175
.map(this::parseListEventsResponse)
171176
.defaultIfEmpty(ListEventsResponse.builder().build());
172177
}
@@ -212,7 +217,7 @@ public Maybe<Session> getSession(
212217
new TypeReference<ConcurrentMap<String, Object>>() {}));
213218
}
214219

215-
return listEvents(appName, userId, sessionId)
220+
return listEventsInternal(appName, sessionId, afterTimestampFilter(config))
216221
.map(
217222
response -> {
218223
Session.Builder sessionBuilder =
@@ -232,48 +237,41 @@ public Maybe<Session> getSession(
232237
});
233238
}
234239

240+
/**
241+
* Builds the server-side events filter for {@code afterTimestamp}, mirroring the Python and Go
242+
* implementations (inclusive {@code timestamp>=}). The filter is only applied when {@code
243+
* numRecentEvents} is not set, matching the precedence in {@link #filterEvents}.
244+
*/
245+
private static @Nullable String afterTimestampFilter(Optional<GetSessionConfig> config) {
246+
if (config.isPresent()
247+
&& config.get().numRecentEvents().isEmpty()
248+
&& config.get().afterTimestamp().isPresent()) {
249+
return "timestamp>=\"" + config.get().afterTimestamp().get() + "\"";
250+
}
251+
return null;
252+
}
253+
235254
private static List<Event> filterEvents(
236255
List<Event> originalEvents, Optional<GetSessionConfig> config) {
237256
// Preserve the full event stream that Vertex AI returns. Event timestamps are
238257
// assigned client-side while updateTime is assigned server-side, so filtering
239258
// on updateTime could silently drop the most recently appended event(s).
259+
// afterTimestamp is filtered server-side (see afterTimestampFilter), so only
260+
// numRecentEvents is applied here.
240261
List<Event> events =
241262
originalEvents.stream()
242263
.sorted(Comparator.comparingLong(Event::timestamp))
243264
.collect(toCollection(ArrayList::new));
244265

245-
if (config.isPresent()) {
246-
if (config.get().numRecentEvents().isPresent()) {
247-
int numRecentEvents = config.get().numRecentEvents().get();
248-
if (events.size() > numRecentEvents) {
249-
events = events.subList(events.size() - numRecentEvents, events.size());
250-
}
251-
} else if (config.get().afterTimestamp().isPresent()) {
252-
long afterTimestampMillis = config.get().afterTimestamp().get().toEpochMilli();
253-
events = events.subList(firstIndexAtOrAfter(events, afterTimestampMillis), events.size());
266+
if (config.isPresent() && config.get().numRecentEvents().isPresent()) {
267+
int numRecentEvents = config.get().numRecentEvents().get();
268+
if (events.size() > numRecentEvents) {
269+
events = events.subList(events.size() - numRecentEvents, events.size());
254270
}
255271
}
256272
return events;
257273
}
258274

259-
/**
260-
* Returns the index of the first event whose timestamp is at or after {@code timestampMillis}, or
261-
* the list size if there is none. {@code sortedEvents} must be sorted ascending by timestamp.
262-
*/
263-
private static int firstIndexAtOrAfter(List<Event> sortedEvents, long timestampMillis) {
264-
int low = 0;
265-
int high = sortedEvents.size();
266-
while (low < high) {
267-
int mid = (low + high) >>> 1;
268-
if (sortedEvents.get(mid).timestamp() < timestampMillis) {
269-
low = mid + 1;
270-
} else {
271-
high = mid;
272-
}
273-
}
274-
return low;
275-
}
276-
277275
@Override
278276
public Completable deleteSession(String appName, String userId, String sessionId) {
279277
String reasoningEngineId = parseReasoningEngineId(appName);

core/src/test/java/com/google/adk/sessions/MockApiAnswer.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import com.fasterxml.jackson.databind.ObjectMapper;
55
import com.google.adk.JsonBaseModel;
66
import com.google.adk.events.Event;
7+
import java.net.URLDecoder;
8+
import java.nio.charset.StandardCharsets;
9+
import java.time.Instant;
710
import java.util.ArrayList;
811
import java.util.HashMap;
912
import java.util.List;
@@ -30,7 +33,8 @@ class MockApiAnswer implements Answer<ApiResponse> {
3033
private static final Pattern APPEND_EVENT_REGEX =
3134
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+):appendEvent$");
3235
private static final Pattern EVENTS_REGEX =
33-
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$");
36+
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\\?filter=(.*))?$");
37+
private static final Pattern TIMESTAMP_FILTER_REGEX = Pattern.compile("timestamp>=\"(.*)\"");
3438
private static final MediaType JSON_MEDIA_TYPE =
3539
MediaType.parse("application/json; charset=utf-8");
3640

@@ -200,8 +204,16 @@ private ApiResponse handleGetEvents(String path) throws Exception {
200204
return null;
201205
}
202206
String sessionId = matcher.group(2);
207+
// The client URL-escapes the filter value; decode it as the real server would.
208+
String filter =
209+
matcher.group(3) == null
210+
? null
211+
: URLDecoder.decode(matcher.group(3), StandardCharsets.UTF_8);
203212
String eventData = eventMap.get(sessionId);
204213
if (eventData != null) {
214+
if (filter != null) {
215+
eventData = applyTimestampFilter(eventData, filter);
216+
}
205217
return responseWithBody(
206218
String.format(
207219
"""
@@ -216,6 +228,25 @@ private ApiResponse handleGetEvents(String path) throws Exception {
216228
}
217229
}
218230

231+
/** Emulates the server-side inclusive {@code timestamp>=} filter on the events list. */
232+
private static String applyTimestampFilter(String eventData, String filter) throws Exception {
233+
Matcher filterMatcher = TIMESTAMP_FILTER_REGEX.matcher(filter);
234+
if (!filterMatcher.matches()) {
235+
return eventData;
236+
}
237+
Instant threshold = Instant.parse(filterMatcher.group(1));
238+
List<Map<String, Object>> events =
239+
mapper.readValue(eventData, new TypeReference<List<Map<String, Object>>>() {});
240+
List<Map<String, Object>> kept = new ArrayList<>();
241+
for (Map<String, Object> event : events) {
242+
Instant timestamp = Instant.parse((String) event.get("timestamp"));
243+
if (!timestamp.isBefore(threshold)) {
244+
kept.add(event);
245+
}
246+
}
247+
return mapper.writeValueAsString(kept);
248+
}
249+
219250
private ApiResponse handleGetLro(String path) {
220251
return responseWithBody(
221252
String.format(

core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import static com.google.common.truth.Truth.assertThat;
55
import static org.junit.Assert.assertThrows;
66
import static org.mockito.ArgumentMatchers.anyString;
7+
import static org.mockito.ArgumentMatchers.eq;
8+
import static org.mockito.Mockito.atLeastOnce;
9+
import static org.mockito.Mockito.verify;
710
import static org.mockito.Mockito.when;
811

912
import com.fasterxml.jackson.core.type.TypeReference;
@@ -29,6 +32,7 @@
2932
import org.junit.Test;
3033
import org.junit.runner.RunWith;
3134
import org.junit.runners.JUnit4;
35+
import org.mockito.ArgumentCaptor;
3236
import org.mockito.Mock;
3337
import org.mockito.MockitoAnnotations;
3438

@@ -444,6 +448,31 @@ public void getSession_afterTimestampBetweenEvents_dropsEventsBeforeThreshold()
444448
assertThat(session.events().stream().map(Event::id)).containsExactly("e3");
445449
}
446450

451+
@Test
452+
public void getSession_afterTimestampConfig_urlEscapesFilterInRequest() {
453+
sessionMap.put("9", mockSessionJson("9", "2024-12-12T12:00:30.000000Z"));
454+
eventMap.put("9", mockEventsJson(mockEventJson("e1", "2024-12-12T12:00:15.000000Z")));
455+
GetSessionConfig config =
456+
GetSessionConfig.builder()
457+
.afterTimestamp(Instant.parse("2024-12-12T12:00:10.000000Z"))
458+
.build();
459+
460+
Object unused =
461+
vertexAiSessionService.getSession("123", "user", "9", Optional.of(config)).blockingGet();
462+
463+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
464+
verify(mockApiClient, atLeastOnce()).request(eq("GET"), pathCaptor.capture(), eq(""));
465+
String eventsPath =
466+
pathCaptor.getAllValues().stream()
467+
.filter(path -> path.contains("/events"))
468+
.findFirst()
469+
.orElseThrow(() -> new AssertionError("No list-events request was made"));
470+
// The filter operator and quotes are URL-escaped (>= -> %3E%3D, " -> %22),
471+
// not sent raw.
472+
assertThat(eventsPath).contains("filter=timestamp%3E%3D%22");
473+
assertThat(eventsPath).doesNotContain("timestamp>=");
474+
}
475+
447476
@Test
448477
public void getSession_numRecentEventsConfig_returnsMostRecentEvents() {
449478
sessionMap.put("7", mockSessionJson("7", "2024-12-12T12:00:30.000000Z"));

0 commit comments

Comments
 (0)