Skip to content

Commit bfaf0eb

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
Persist agent events to the session synchronously within each LLM step in BaseLlmFlow.run(), so the next step does not start before the previous step's events have been appended. Both BaseLlmFlow.run() and Runner skip the duplicate appendEvent for events already present in the session (by id), so events emitted by a transferred sub-agent are appended exactly once. PiperOrigin-RevId: 921989444
1 parent e6fe9aa commit bfaf0eb

4 files changed

Lines changed: 217 additions & 12 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.adk.models.LlmRegistry;
3838
import com.google.adk.models.LlmRequest;
3939
import com.google.adk.models.LlmResponse;
40+
import com.google.adk.sessions.SessionUtils;
4041
import com.google.adk.telemetry.Tracing;
4142
import com.google.adk.tools.BaseTool;
4243
import com.google.adk.tools.BaseToolset;
@@ -503,7 +504,30 @@ public Flowable<Event> run(InvocationContext invocationContext) {
503504

504505
private Flowable<Event> run(
505506
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
506-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
507+
// Persist each event to the session synchronously within the step so that the next step does
508+
// not start before the previous step's events have been appended. Without this, the deferred
509+
// continuation (concatWith below) subscribes synchronously on runOneStep's upstream onComplete
510+
// signal, which can race with the downstream consumer's appendEvent calls in Runner.
511+
//
512+
// The Runner-side appendEvent still runs and deduplicates this event by id, so plugin
513+
// onEventCallback and non-LlmAgent paths are unaffected.
514+
//
515+
// Events emitted by a transferred sub-agent's nested BaseLlmFlow.run() have already been
516+
// appended by that nested flow, so skip them here to avoid duplicates. Deduplication is done
517+
// by event id against the session's existing events.
518+
Flowable<Event> currentStepEvents =
519+
runOneStep(spanContext, invocationContext)
520+
.concatMap(
521+
event -> {
522+
if (SessionUtils.isEventAlreadyAppended(invocationContext.session(), event)) {
523+
return Flowable.just(event);
524+
}
525+
return invocationContext
526+
.sessionService()
527+
.appendEvent(invocationContext.session(), event)
528+
.toFlowable();
529+
})
530+
.cache();
507531
if (stepsCompleted + 1 >= maxSteps) {
508532
logger.debug("Ending flow execution because max steps reached.");
509533
return currentStepEvents;

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.google.adk.sessions.InMemorySessionService;
3737
import com.google.adk.sessions.Session;
3838
import com.google.adk.sessions.SessionKey;
39+
import com.google.adk.sessions.SessionUtils;
3940
import com.google.adk.summarizer.EventsCompactionConfig;
4041
import com.google.adk.summarizer.LlmEventSummarizer;
4142
import com.google.adk.summarizer.SlidingWindowEventCompactor;
@@ -576,19 +577,25 @@ private Flowable<Event> runAgentWithUpdatedSession(
576577
.agent()
577578
.runAsync(contextWithUpdatedSession)
578579
.concatMap(
579-
agentEvent ->
580-
this.sessionService
581-
.appendEvent(updatedSession, agentEvent)
582-
.flatMap(
583-
registeredEvent -> {
584-
// TODO: remove this hack after deprecating runAsync with Session.
585-
copySessionStates(updatedSession, initialContext.session());
586-
return contextWithUpdatedSession
580+
agentEvent -> {
581+
// TODO: remove this hack after deprecating runAsync with Session.
582+
copySessionStates(updatedSession, initialContext.session());
583+
// BaseLlmFlow appends events synchronously to fix a race where the next LLM
584+
// step would otherwise start before the previous step's events were persisted.
585+
// Skip the duplicate append here so the event is not added twice.
586+
Single<Event> appendOrSkip =
587+
SessionUtils.isEventAlreadyAppended(updatedSession, agentEvent)
588+
? Single.just(agentEvent)
589+
: this.sessionService.appendEvent(updatedSession, agentEvent);
590+
return appendOrSkip
591+
.flatMap(
592+
registeredEvent ->
593+
contextWithUpdatedSession
587594
.pluginManager()
588595
.onEventCallback(contextWithUpdatedSession, registeredEvent)
589-
.defaultIfEmpty(registeredEvent);
590-
})
591-
.toFlowable());
596+
.defaultIfEmpty(registeredEvent))
597+
.toFlowable();
598+
});
592599

593600
// If beforeRunCallback returns content, emit it and skip agent
594601
Context capturedContext = Context.current();

core/src/main/java/com/google/adk/sessions/SessionUtils.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.sessions;
1818

19+
import com.google.adk.events.Event;
1920
import com.google.common.collect.ImmutableList;
2021
import com.google.genai.types.Blob;
2122
import com.google.genai.types.Content;
@@ -31,6 +32,32 @@ public final class SessionUtils {
3132

3233
public SessionUtils() {}
3334

35+
/**
36+
* Returns true if an event with the same id is already present in {@code session.events()}.
37+
*
38+
* <p>Used to deduplicate {@code appendEvent} calls when the same event flows through multiple
39+
* append points (e.g. {@code BaseLlmFlow.run} for a transferred sub-agent and the parent flow, or
40+
* {@code BaseLlmFlow.run} and {@code Runner}).
41+
*/
42+
public static boolean isEventAlreadyAppended(Session session, Event event) {
43+
String eventId = event.id();
44+
if (eventId == null) {
45+
return false;
46+
}
47+
List<Event> events = session.events();
48+
if (events == null || events.isEmpty()) {
49+
return false;
50+
}
51+
synchronized (events) {
52+
for (Event existing : events) {
53+
if (eventId.equals(existing.id())) {
54+
return true;
55+
}
56+
}
57+
}
58+
return false;
59+
}
60+
3461
/** Base64-encodes inline blobs in content. */
3562
public static Content encodeContent(Content content) {
3663
List<Part> encodedParts = new ArrayList<>();

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import static org.mockito.Mockito.when;
3939

4040
import com.google.adk.agents.BaseAgent;
41+
import com.google.adk.agents.Callbacks.AfterModelCallback;
4142
import com.google.adk.agents.InvocationContext;
4243
import com.google.adk.agents.LiveRequestQueue;
4344
import com.google.adk.agents.LlmAgent;
@@ -46,9 +47,14 @@
4647
import com.google.adk.artifacts.BaseArtifactService;
4748
import com.google.adk.events.Event;
4849
import com.google.adk.flows.llmflows.Functions;
50+
import com.google.adk.models.LlmRequest;
4951
import com.google.adk.models.LlmResponse;
5052
import com.google.adk.plugins.BasePlugin;
5153
import com.google.adk.sessions.BaseSessionService;
54+
import com.google.adk.sessions.GetSessionConfig;
55+
import com.google.adk.sessions.InMemorySessionService;
56+
import com.google.adk.sessions.ListEventsResponse;
57+
import com.google.adk.sessions.ListSessionsResponse;
5258
import com.google.adk.sessions.Session;
5359
import com.google.adk.sessions.SessionKey;
5460
import com.google.adk.summarizer.EventsCompactionConfig;
@@ -84,6 +90,7 @@
8490
import java.util.Optional;
8591
import java.util.UUID;
8692
import java.util.concurrent.ConcurrentHashMap;
93+
import java.util.concurrent.ConcurrentMap;
8794
import java.util.concurrent.atomic.AtomicInteger;
8895
import java.util.concurrent.atomic.AtomicReference;
8996
import org.junit.After;
@@ -859,6 +866,146 @@ public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exceptio
859866
subscriber2.assertValue(agentEvent);
860867
}
861868

869+
/**
870+
* A slow appendEvent must not let the next LLM step start with a stale session missing the
871+
* previous step's function-response event.
872+
*/
873+
@Test
874+
public void runAsync_slowAppendEvent_doesNotCauseStaleSessionInNextStep() throws Exception {
875+
TestLlm raceTestLlm =
876+
createTestLlm(
877+
createFunctionCallLlmResponse("call_1", echoTool.name(), ImmutableMap.of("arg", "v1")),
878+
createTextLlmResponse("done"));
879+
880+
LlmAgent agentForRace =
881+
createTestAgentBuilder(raceTestLlm).tools(ImmutableList.of(echoTool)).build();
882+
883+
BaseSessionService delayedSessionService =
884+
new AppendDelayingSessionService(new InMemorySessionService(), 50);
885+
886+
Runner runnerForRace =
887+
Runner.builder()
888+
.app(App.builder().name("test").rootAgent(agentForRace).build())
889+
.sessionService(delayedSessionService)
890+
.build();
891+
Session raceSession =
892+
runnerForRace.sessionService().createSession("test", "user").blockingGet();
893+
894+
var unused =
895+
runnerForRace
896+
.runAsync("user", raceSession.id(), createContent("start"))
897+
.toList()
898+
.blockingGet();
899+
900+
ImmutableList<LlmRequest> requests = ImmutableList.copyOf(raceTestLlm.getRequests());
901+
assertThat(requests).hasSize(2);
902+
903+
// Second LLM request must see the function response from step 1.
904+
boolean foundToolResponse =
905+
requests.get(1).contents().stream()
906+
.flatMap(c -> c.parts().stream().flatMap(List::stream))
907+
.anyMatch(part -> part.functionResponse().isPresent());
908+
assertThat(foundToolResponse).isTrue();
909+
}
910+
911+
/**
912+
* When an LlmAgent transfers control to a sub-LlmAgent, the sub-agent's events flow back up
913+
* through the parent's {@code BaseLlmFlow.run()} pipeline. Each event must be appended to the
914+
* session exactly once.
915+
*/
916+
@Test
917+
public void runAsync_transferToSubAgent_eventsAppendedOnce() throws Exception {
918+
LlmAgent subAgent =
919+
createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response")))
920+
.name("sub-agent")
921+
.build();
922+
923+
// Force a transfer to sub-agent using an afterModelCallback.
924+
AfterModelCallback transferCallback =
925+
(ctx, response) -> {
926+
ctx.eventActions().setTransferToAgent(subAgent.name());
927+
return Maybe.empty();
928+
};
929+
930+
TestLlm rootTestLlm = createTestLlm(createTextLlmResponse("initial"));
931+
LlmAgent rootAgent =
932+
createTestAgentBuilder(rootTestLlm)
933+
.subAgents(subAgent)
934+
.afterModelCallback(ImmutableList.of(transferCallback))
935+
.build();
936+
937+
Runner transferRunner =
938+
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
939+
Session transferSession =
940+
transferRunner.sessionService().createSession("test", "user").blockingGet();
941+
942+
var unused =
943+
transferRunner
944+
.runAsync("user", transferSession.id(), createContent("start"))
945+
.toList()
946+
.blockingGet();
947+
948+
Session finalSession =
949+
transferRunner
950+
.sessionService()
951+
.getSession(
952+
transferSession.appName(),
953+
transferSession.userId(),
954+
transferSession.id(),
955+
Optional.empty())
956+
.blockingGet();
957+
958+
// Each event id should appear at most once in the session.
959+
List<String> eventIds = finalSession.events().stream().map(Event::id).toList();
960+
assertThat(eventIds).containsNoDuplicates();
961+
}
962+
963+
/** {@link BaseSessionService} that delays {@link #appendEvent} to surface ordering bugs. */
964+
private static final class AppendDelayingSessionService implements BaseSessionService {
965+
private final BaseSessionService delegate;
966+
private final long appendDelayMs;
967+
968+
AppendDelayingSessionService(BaseSessionService delegate, long appendDelayMs) {
969+
this.delegate = delegate;
970+
this.appendDelayMs = appendDelayMs;
971+
}
972+
973+
// Delegates to the underlying BaseSessionService createSession overload, which is itself
974+
// deprecated; suppressed because the wrapper must preserve the same signature.
975+
@SuppressWarnings("deprecation")
976+
@Override
977+
public Single<Session> createSession(
978+
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
979+
return delegate.createSession(appName, userId, state, sessionId);
980+
}
981+
982+
@Override
983+
public Maybe<Session> getSession(
984+
String appName, String userId, String sessionId, Optional<GetSessionConfig> config) {
985+
return delegate.getSession(appName, userId, sessionId, config);
986+
}
987+
988+
@Override
989+
public Single<ListSessionsResponse> listSessions(String appName, String userId) {
990+
return delegate.listSessions(appName, userId);
991+
}
992+
993+
@Override
994+
public Completable deleteSession(String appName, String userId, String sessionId) {
995+
return delegate.deleteSession(appName, userId, sessionId);
996+
}
997+
998+
@Override
999+
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
1000+
return delegate.listEvents(appName, userId, sessionId);
1001+
}
1002+
1003+
@Override
1004+
public Single<Event> appendEvent(Session session, Event event) {
1005+
return delegate.appendEvent(session, event).delay(appendDelayMs, MILLISECONDS);
1006+
}
1007+
}
1008+
8621009
@Test
8631010
public void runAsync_withSessionKey_success() {
8641011
var events =

0 commit comments

Comments
 (0)