Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions core/src/main/java/com/google/adk/sessions/VertexAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.net.UrlEscapers;
import com.google.genai.types.HttpOptions;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
Expand Down Expand Up @@ -111,11 +112,12 @@ Maybe<JsonNode> listSessions(String reasoningEngineId, String userId) {
.flatMapMaybe(VertexAiClient::getJsonResponse);
}

Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
return performApiRequest(
"GET",
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events",
"")
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId, @Nullable String filter) {
String path = "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events";
if (filter != null) {
path += "?filter=" + UrlEscapers.urlFormParameterEscaper().escape(filter);
}
return performApiRequest("GET", path, "")
.doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse))
.flatMapMaybe(VertexAiClient::getJsonResponse);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,14 @@ private ListSessionsResponse parseListSessionsResponse(

@Override
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
return listEventsInternal(appName, sessionId, /* filter= */ null);
}

private Single<ListEventsResponse> listEventsInternal(
String appName, String sessionId, @Nullable String filter) {
String reasoningEngineId = parseReasoningEngineId(appName);
return client
.listEvents(reasoningEngineId, sessionId)
.listEvents(reasoningEngineId, sessionId, filter)
.map(this::parseListEventsResponse)
.defaultIfEmpty(ListEventsResponse.builder().build());
}
Expand Down Expand Up @@ -212,7 +217,7 @@ public Maybe<Session> getSession(
new TypeReference<ConcurrentMap<String, Object>>() {}));
}

return listEvents(appName, userId, sessionId)
return listEventsInternal(appName, sessionId, afterTimestampFilter(config))
.map(
response -> {
Session.Builder sessionBuilder =
Expand All @@ -232,48 +237,41 @@ public Maybe<Session> getSession(
});
}

/**
* Builds the server-side events filter for {@code afterTimestamp}, mirroring the Python and Go
* implementations (inclusive {@code timestamp>=}). The filter is only applied when {@code
* numRecentEvents} is not set, matching the precedence in {@link #filterEvents}.
*/
private static @Nullable String afterTimestampFilter(Optional<GetSessionConfig> config) {
if (config.isPresent()
&& config.get().numRecentEvents().isEmpty()
&& config.get().afterTimestamp().isPresent()) {
return "timestamp>=\"" + config.get().afterTimestamp().get() + "\"";
}
return null;
}

private static List<Event> filterEvents(
List<Event> originalEvents, Optional<GetSessionConfig> config) {
// Preserve the full event stream that Vertex AI returns. Event timestamps are
// assigned client-side while updateTime is assigned server-side, so filtering
// on updateTime could silently drop the most recently appended event(s).
// afterTimestamp is filtered server-side (see afterTimestampFilter), so only
// numRecentEvents is applied here.
List<Event> events =
originalEvents.stream()
.sorted(Comparator.comparingLong(Event::timestamp))
.collect(toCollection(ArrayList::new));

if (config.isPresent()) {
if (config.get().numRecentEvents().isPresent()) {
int numRecentEvents = config.get().numRecentEvents().get();
if (events.size() > numRecentEvents) {
events = events.subList(events.size() - numRecentEvents, events.size());
}
} else if (config.get().afterTimestamp().isPresent()) {
long afterTimestampMillis = config.get().afterTimestamp().get().toEpochMilli();
events = events.subList(firstIndexAtOrAfter(events, afterTimestampMillis), events.size());
if (config.isPresent() && config.get().numRecentEvents().isPresent()) {
int numRecentEvents = config.get().numRecentEvents().get();
if (events.size() > numRecentEvents) {
events = events.subList(events.size() - numRecentEvents, events.size());
}
}
return events;
}

/**
* Returns the index of the first event whose timestamp is at or after {@code timestampMillis}, or
* the list size if there is none. {@code sortedEvents} must be sorted ascending by timestamp.
*/
private static int firstIndexAtOrAfter(List<Event> sortedEvents, long timestampMillis) {
int low = 0;
int high = sortedEvents.size();
while (low < high) {
int mid = (low + high) >>> 1;
if (sortedEvents.get(mid).timestamp() < timestampMillis) {
low = mid + 1;
} else {
high = mid;
}
}
return low;
}

@Override
public Completable deleteSession(String appName, String userId, String sessionId) {
String reasoningEngineId = parseReasoningEngineId(appName);
Expand Down
33 changes: 32 additions & 1 deletion core/src/test/java/com/google/adk/sessions/MockApiAnswer.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.events.Event;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -30,7 +33,8 @@ class MockApiAnswer implements Answer<ApiResponse> {
private static final Pattern APPEND_EVENT_REGEX =
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+):appendEvent$");
private static final Pattern EVENTS_REGEX =
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$");
Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\\?filter=(.*))?$");
private static final Pattern TIMESTAMP_FILTER_REGEX = Pattern.compile("timestamp>=\"(.*)\"");
private static final MediaType JSON_MEDIA_TYPE =
MediaType.parse("application/json; charset=utf-8");

Expand Down Expand Up @@ -200,8 +204,16 @@ private ApiResponse handleGetEvents(String path) throws Exception {
return null;
}
String sessionId = matcher.group(2);
// The client URL-escapes the filter value; decode it as the real server would.
String filter =
matcher.group(3) == null
? null
: URLDecoder.decode(matcher.group(3), StandardCharsets.UTF_8);
String eventData = eventMap.get(sessionId);
if (eventData != null) {
if (filter != null) {
eventData = applyTimestampFilter(eventData, filter);
}
return responseWithBody(
String.format(
"""
Expand All @@ -216,6 +228,25 @@ private ApiResponse handleGetEvents(String path) throws Exception {
}
}

/** Emulates the server-side inclusive {@code timestamp>=} filter on the events list. */
private static String applyTimestampFilter(String eventData, String filter) throws Exception {
Matcher filterMatcher = TIMESTAMP_FILTER_REGEX.matcher(filter);
if (!filterMatcher.matches()) {
return eventData;
}
Instant threshold = Instant.parse(filterMatcher.group(1));
List<Map<String, Object>> events =
mapper.readValue(eventData, new TypeReference<List<Map<String, Object>>>() {});
List<Map<String, Object>> kept = new ArrayList<>();
for (Map<String, Object> event : events) {
Instant timestamp = Instant.parse((String) event.get("timestamp"));
if (!timestamp.isBefore(threshold)) {
kept.add(event);
}
}
return mapper.writeValueAsString(kept);
}

private ApiResponse handleGetLro(String path) {
return responseWithBody(
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -29,6 +32,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -444,6 +448,31 @@ public void getSession_afterTimestampBetweenEvents_dropsEventsBeforeThreshold()
assertThat(session.events().stream().map(Event::id)).containsExactly("e3");
}

@Test
public void getSession_afterTimestampConfig_urlEscapesFilterInRequest() {
sessionMap.put("9", mockSessionJson("9", "2024-12-12T12:00:30.000000Z"));
eventMap.put("9", mockEventsJson(mockEventJson("e1", "2024-12-12T12:00:15.000000Z")));
GetSessionConfig config =
GetSessionConfig.builder()
.afterTimestamp(Instant.parse("2024-12-12T12:00:10.000000Z"))
.build();

Object unused =
vertexAiSessionService.getSession("123", "user", "9", Optional.of(config)).blockingGet();

ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
verify(mockApiClient, atLeastOnce()).request(eq("GET"), pathCaptor.capture(), eq(""));
String eventsPath =
pathCaptor.getAllValues().stream()
.filter(path -> path.contains("/events"))
.findFirst()
.orElseThrow(() -> new AssertionError("No list-events request was made"));
// The filter operator and quotes are URL-escaped (>= -> %3E%3D, " -> %22),
// not sent raw.
assertThat(eventsPath).contains("filter=timestamp%3E%3D%22");
assertThat(eventsPath).doesNotContain("timestamp>=");
}

@Test
public void getSession_numRecentEventsConfig_returnsMostRecentEvents() {
sessionMap.put("7", mockSessionJson("7", "2024-12-12T12:00:30.000000Z"));
Expand Down
Loading