Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,21 @@ public Maybe<Session> getSession(
if (events.isEmpty()) {
return sessionBuilder.build();
}
events = filterEvents(events, updateTimestamp, config);
events = filterEvents(events, config);
return sessionBuilder.events(events).build();
})
.toMaybe();
});
}

private static List<Event> filterEvents(
List<Event> originalEvents,
@Nullable Instant updateTimestamp,
Optional<GetSessionConfig> config) {
List<Event> originalEvents, Optional<GetSessionConfig> config) {
// Preserve the full event stream that Vertex AI returns. Event timestamps are
// assigned client-side while updateTime is assigned server-side, so filtering
// on updateTime could silently drop the most recently appended event(s).
List<Event> events =
originalEvents.stream()
.filter(
event ->
updateTimestamp == null
|| Instant.ofEpochMilli(event.timestamp()).isBefore(updateTimestamp))
.sorted(Comparator.comparing(Event::timestamp))
.sorted(Comparator.comparingLong(Event::timestamp))
.collect(toCollection(ArrayList::new));

if (config.isPresent()) {
Expand All @@ -252,22 +249,31 @@ private static List<Event> filterEvents(
events = events.subList(events.size() - numRecentEvents, events.size());
}
} else if (config.get().afterTimestamp().isPresent()) {
Instant afterTimestamp = config.get().afterTimestamp().get();
int i = events.size() - 1;
while (i >= 0) {
if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) {
break;
}
i -= 1;
}
if (i >= 0) {
events = events.subList(i, events.size());
}
long afterTimestampMillis = config.get().afterTimestamp().get().toEpochMilli();
events = events.subList(firstIndexAtOrAfter(events, afterTimestampMillis), events.size());
}
}
return events;
}

/**
* Returns the index of the first event whose timestamp is at or after {@code timestampMillis}, or
* the list size if there is none. {@code sortedEvents} must be sorted ascending by timestamp.
*/
private static int firstIndexAtOrAfter(List<Event> sortedEvents, long timestampMillis) {
int low = 0;
int high = sortedEvents.size();
while (low < high) {
int mid = (low + high) >>> 1;
if (sortedEvents.get(mid).timestamp() < timestampMillis) {
low = mid + 1;
} else {
high = mid;
}
}
return low;
}

@Override
public Completable deleteSession(String appName, String userId, String sessionId) {
String reasoningEngineId = parseReasoningEngineId(appName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,4 +379,115 @@ public void appendEvent_withStateRemoved_updatesSessionState() {
assertThat(updatedSession.state()).containsExactly("key1", "value1");
assertThat(updatedSession.state()).doesNotContainKey("key2");
}

@Test
public void getSession_eventTimestampAfterUpdateTime_doesNotDropEvent() {
// Regression test: event timestamps are assigned client-side while the
// session updateTime is assigned server-side, so clock skew can make the
// latest event newer than updateTime. Such events must not be dropped by
// getSession().
sessionMap.put("5", mockSessionJson("5", "2024-12-12T12:12:12.000000Z"));
eventMap.put(
"5",
mockEventsJson(
mockEventJson("before", "2024-12-12T12:12:11.000000Z"),
mockEventJson("after", "2024-12-12T12:12:12.500000Z")));

Session session =
vertexAiSessionService.getSession("123", "user", "5", Optional.empty()).blockingGet();

assertThat(session.events().stream().map(Event::id))
.containsExactly("before", "after")
.inOrder();
}

@Test
public void getSession_afterTimestampConfig_keepsEventsAtOrAfterThreshold() {
sessionMap.put("6", mockSessionJson("6", "2024-12-12T12:00:30.000000Z"));
eventMap.put(
"6",
mockEventsJson(
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
GetSessionConfig config =
GetSessionConfig.builder()
.afterTimestamp(Instant.parse("2024-12-12T12:00:10.000000Z"))
.build();

Session session =
vertexAiSessionService.getSession("123", "user", "6", Optional.of(config)).blockingGet();

// The threshold is inclusive: e2 (== afterTimestamp) and e3 are kept, e1 is
// dropped.
assertThat(session.events().stream().map(Event::id)).containsExactly("e2", "e3").inOrder();
}

@Test
public void getSession_afterTimestampBetweenEvents_dropsEventsBeforeThreshold() {
sessionMap.put("8", mockSessionJson("8", "2024-12-12T12:00:30.000000Z"));
eventMap.put(
"8",
mockEventsJson(
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
GetSessionConfig config =
GetSessionConfig.builder()
.afterTimestamp(Instant.parse("2024-12-12T12:00:12.000000Z"))
.build();

Session session =
vertexAiSessionService.getSession("123", "user", "8", Optional.of(config)).blockingGet();

// afterTimestamp falls strictly between e2 and e3, so only e3 is kept.
assertThat(session.events().stream().map(Event::id)).containsExactly("e3");
}

@Test
public void getSession_numRecentEventsConfig_returnsMostRecentEvents() {
sessionMap.put("7", mockSessionJson("7", "2024-12-12T12:00:30.000000Z"));
eventMap.put(
"7",
mockEventsJson(
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(2).build();

Session session =
vertexAiSessionService.getSession("123", "user", "7", Optional.of(config)).blockingGet();

assertThat(session.events().stream().map(Event::id)).containsExactly("e2", "e3").inOrder();
}

private static String mockSessionJson(String sessionId, String updateTime) {
return String.format(
"""
{
"name" : "reasoningEngines/123/sessions/%s",
"userId" : "user",
"updateTime" : "%s"
}\
""",
sessionId, updateTime);
}

private static String mockEventJson(String eventId, String timestamp) {
return String.format(
"""
{
"name" : "reasoningEngines/123/sessions/x/events/%s",
"invocationId" : "%s",
"author" : "agent",
"timestamp" : "%s",
"content" : { "role" : "model", "parts" : [ { "text" : "%s" } ] }
}\
""",
eventId, eventId, timestamp, eventId);
}

private static String mockEventsJson(String... events) {
return "[" + String.join(",", events) + "]";
}
}
Loading