Skip to content

Commit d627fa3

Browse files
kvmiloscopybara-github
authored andcommitted
fix: stop dropping the latest event(s) in VertexAiSessionService.getSession
PiperOrigin-RevId: 931027486
1 parent fb9274e commit d627fa3

2 files changed

Lines changed: 137 additions & 20 deletions

File tree

core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -225,24 +225,21 @@ public Maybe<Session> getSession(
225225
if (events.isEmpty()) {
226226
return sessionBuilder.build();
227227
}
228-
events = filterEvents(events, updateTimestamp, config);
228+
events = filterEvents(events, config);
229229
return sessionBuilder.events(events).build();
230230
})
231231
.toMaybe();
232232
});
233233
}
234234

235235
private static List<Event> filterEvents(
236-
List<Event> originalEvents,
237-
@Nullable Instant updateTimestamp,
238-
Optional<GetSessionConfig> config) {
236+
List<Event> originalEvents, Optional<GetSessionConfig> config) {
237+
// Preserve the full event stream that Vertex AI returns. Event timestamps are
238+
// assigned client-side while updateTime is assigned server-side, so filtering
239+
// on updateTime could silently drop the most recently appended event(s).
239240
List<Event> events =
240241
originalEvents.stream()
241-
.filter(
242-
event ->
243-
updateTimestamp == null
244-
|| Instant.ofEpochMilli(event.timestamp()).isBefore(updateTimestamp))
245-
.sorted(Comparator.comparing(Event::timestamp))
242+
.sorted(Comparator.comparingLong(Event::timestamp))
246243
.collect(toCollection(ArrayList::new));
247244

248245
if (config.isPresent()) {
@@ -252,22 +249,31 @@ private static List<Event> filterEvents(
252249
events = events.subList(events.size() - numRecentEvents, events.size());
253250
}
254251
} else if (config.get().afterTimestamp().isPresent()) {
255-
Instant afterTimestamp = config.get().afterTimestamp().get();
256-
int i = events.size() - 1;
257-
while (i >= 0) {
258-
if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) {
259-
break;
260-
}
261-
i -= 1;
262-
}
263-
if (i >= 0) {
264-
events = events.subList(i, events.size());
265-
}
252+
long afterTimestampMillis = config.get().afterTimestamp().get().toEpochMilli();
253+
events = events.subList(firstIndexAtOrAfter(events, afterTimestampMillis), events.size());
266254
}
267255
}
268256
return events;
269257
}
270258

259+
/**
260+
* Returns the index of the first event whose timestamp is at or after {@code timestampMillis}, or
261+
* the list size if there is none. {@code sortedEvents} must be sorted ascending by timestamp.
262+
*/
263+
private static int firstIndexAtOrAfter(List<Event> sortedEvents, long timestampMillis) {
264+
int low = 0;
265+
int high = sortedEvents.size();
266+
while (low < high) {
267+
int mid = (low + high) >>> 1;
268+
if (sortedEvents.get(mid).timestamp() < timestampMillis) {
269+
low = mid + 1;
270+
} else {
271+
high = mid;
272+
}
273+
}
274+
return low;
275+
}
276+
271277
@Override
272278
public Completable deleteSession(String appName, String userId, String sessionId) {
273279
String reasoningEngineId = parseReasoningEngineId(appName);

core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,4 +379,115 @@ public void appendEvent_withStateRemoved_updatesSessionState() {
379379
assertThat(updatedSession.state()).containsExactly("key1", "value1");
380380
assertThat(updatedSession.state()).doesNotContainKey("key2");
381381
}
382+
383+
@Test
384+
public void getSession_eventTimestampAfterUpdateTime_doesNotDropEvent() {
385+
// Regression test: event timestamps are assigned client-side while the
386+
// session updateTime is assigned server-side, so clock skew can make the
387+
// latest event newer than updateTime. Such events must not be dropped by
388+
// getSession().
389+
sessionMap.put("5", mockSessionJson("5", "2024-12-12T12:12:12.000000Z"));
390+
eventMap.put(
391+
"5",
392+
mockEventsJson(
393+
mockEventJson("before", "2024-12-12T12:12:11.000000Z"),
394+
mockEventJson("after", "2024-12-12T12:12:12.500000Z")));
395+
396+
Session session =
397+
vertexAiSessionService.getSession("123", "user", "5", Optional.empty()).blockingGet();
398+
399+
assertThat(session.events().stream().map(Event::id))
400+
.containsExactly("before", "after")
401+
.inOrder();
402+
}
403+
404+
@Test
405+
public void getSession_afterTimestampConfig_keepsEventsAtOrAfterThreshold() {
406+
sessionMap.put("6", mockSessionJson("6", "2024-12-12T12:00:30.000000Z"));
407+
eventMap.put(
408+
"6",
409+
mockEventsJson(
410+
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
411+
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
412+
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
413+
GetSessionConfig config =
414+
GetSessionConfig.builder()
415+
.afterTimestamp(Instant.parse("2024-12-12T12:00:10.000000Z"))
416+
.build();
417+
418+
Session session =
419+
vertexAiSessionService.getSession("123", "user", "6", Optional.of(config)).blockingGet();
420+
421+
// The threshold is inclusive: e2 (== afterTimestamp) and e3 are kept, e1 is
422+
// dropped.
423+
assertThat(session.events().stream().map(Event::id)).containsExactly("e2", "e3").inOrder();
424+
}
425+
426+
@Test
427+
public void getSession_afterTimestampBetweenEvents_dropsEventsBeforeThreshold() {
428+
sessionMap.put("8", mockSessionJson("8", "2024-12-12T12:00:30.000000Z"));
429+
eventMap.put(
430+
"8",
431+
mockEventsJson(
432+
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
433+
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
434+
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
435+
GetSessionConfig config =
436+
GetSessionConfig.builder()
437+
.afterTimestamp(Instant.parse("2024-12-12T12:00:12.000000Z"))
438+
.build();
439+
440+
Session session =
441+
vertexAiSessionService.getSession("123", "user", "8", Optional.of(config)).blockingGet();
442+
443+
// afterTimestamp falls strictly between e2 and e3, so only e3 is kept.
444+
assertThat(session.events().stream().map(Event::id)).containsExactly("e3");
445+
}
446+
447+
@Test
448+
public void getSession_numRecentEventsConfig_returnsMostRecentEvents() {
449+
sessionMap.put("7", mockSessionJson("7", "2024-12-12T12:00:30.000000Z"));
450+
eventMap.put(
451+
"7",
452+
mockEventsJson(
453+
mockEventJson("e1", "2024-12-12T12:00:05.000000Z"),
454+
mockEventJson("e2", "2024-12-12T12:00:10.000000Z"),
455+
mockEventJson("e3", "2024-12-12T12:00:15.000000Z")));
456+
GetSessionConfig config = GetSessionConfig.builder().numRecentEvents(2).build();
457+
458+
Session session =
459+
vertexAiSessionService.getSession("123", "user", "7", Optional.of(config)).blockingGet();
460+
461+
assertThat(session.events().stream().map(Event::id)).containsExactly("e2", "e3").inOrder();
462+
}
463+
464+
private static String mockSessionJson(String sessionId, String updateTime) {
465+
return String.format(
466+
"""
467+
{
468+
"name" : "reasoningEngines/123/sessions/%s",
469+
"userId" : "user",
470+
"updateTime" : "%s"
471+
}\
472+
""",
473+
sessionId, updateTime);
474+
}
475+
476+
private static String mockEventJson(String eventId, String timestamp) {
477+
return String.format(
478+
"""
479+
{
480+
"name" : "reasoningEngines/123/sessions/x/events/%s",
481+
"invocationId" : "%s",
482+
"author" : "agent",
483+
"timestamp" : "%s",
484+
"content" : { "role" : "model", "parts" : [ { "text" : "%s" } ] }
485+
}\
486+
""",
487+
eventId, eventId, timestamp, eventId);
488+
}
489+
490+
private static String mockEventsJson(String... events) {
491+
return "[" + String.join(",", events) + "]";
492+
}
382493
}

0 commit comments

Comments
 (0)