Skip to content

Commit baa21a0

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
Persist agent events to the session synchronously within each LLM step in BaseLlmFlow.run(), so the next step does not start before the previous step's events have been appended. Both BaseLlmFlow.run() and Runner skip the duplicate appendEvent for events already present in the session (by id), so events emitted by a transferred sub-agent are appended exactly once. PiperOrigin-RevId: 921989444
1 parent d608909 commit baa21a0

4 files changed

Lines changed: 217 additions & 12 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.adk.models.LlmRegistry;
3838
import com.google.adk.models.LlmRequest;
3939
import com.google.adk.models.LlmResponse;
40+
import com.google.adk.sessions.SessionUtils;
4041
import com.google.adk.telemetry.Tracing;
4142
import com.google.adk.tools.BaseTool;
4243
import com.google.adk.tools.BaseToolset;
@@ -503,7 +504,30 @@ public Flowable<Event> run(InvocationContext invocationContext) {
503504

504505
private Flowable<Event> run(
505506
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
506-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
507+
// Persist each event to the session synchronously within the step so that the next step does
508+
// not start before the previous step's events have been appended. Without this, the deferred
509+
// continuation (concatWith below) subscribes synchronously on runOneStep's upstream onComplete
510+
// signal, which can race with the downstream consumer's appendEvent calls in Runner.
511+
//
512+
// The Runner-side appendEvent still runs and deduplicates this event by id, so plugin
513+
// onEventCallback and non-LlmAgent paths are unaffected.
514+
//
515+
// Events emitted by a transferred sub-agent's nested BaseLlmFlow.run() have already been
516+
// appended by that nested flow, so skip them here to avoid duplicates. Deduplication is done
517+
// by event id against the session's existing events.
518+
Flowable<Event> currentStepEvents =
519+
runOneStep(spanContext, invocationContext)
520+
.concatMap(
521+
event -> {
522+
if (SessionUtils.isEventAlreadyAppended(invocationContext.session(), event)) {
523+
return Flowable.just(event);
524+
}
525+
return invocationContext
526+
.sessionService()
527+
.appendEvent(invocationContext.session(), event)
528+
.toFlowable();
529+
})
530+
.cache();
507531
if (stepsCompleted + 1 >= maxSteps) {
508532
logger.debug("Ending flow execution because max steps reached.");
509533
return currentStepEvents;

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import com.google.adk.sessions.InMemorySessionService;
3939
import com.google.adk.sessions.Session;
4040
import com.google.adk.sessions.SessionKey;
41+
import com.google.adk.sessions.SessionUtils;
4142
import com.google.adk.summarizer.EventsCompactionConfig;
4243
import com.google.adk.summarizer.LlmEventSummarizer;
4344
import com.google.adk.summarizer.SlidingWindowEventCompactor;
@@ -581,19 +582,25 @@ private Flowable<Event> runAgentWithUpdatedSession(
581582
.agent()
582583
.runAsync(contextWithUpdatedSession)
583584
.concatMap(
584-
agentEvent ->
585-
this.sessionService
586-
.appendEvent(updatedSession, agentEvent)
587-
.flatMap(
588-
registeredEvent -> {
589-
// TODO: remove this hack after deprecating runAsync with Session.
590-
copySessionStates(updatedSession, initialContext.session());
591-
return contextWithUpdatedSession
585+
agentEvent -> {
586+
// TODO: remove this hack after deprecating runAsync with Session.
587+
copySessionStates(updatedSession, initialContext.session());
588+
// BaseLlmFlow appends events synchronously to fix a race where the next LLM
589+
// step would otherwise start before the previous step's events were persisted.
590+
// Skip the duplicate append here so the event is not added twice.
591+
Single<Event> appendOrSkip =
592+
SessionUtils.isEventAlreadyAppended(updatedSession, agentEvent)
593+
? Single.just(agentEvent)
594+
: this.sessionService.appendEvent(updatedSession, agentEvent);
595+
return appendOrSkip
596+
.flatMap(
597+
registeredEvent ->
598+
contextWithUpdatedSession
592599
.pluginManager()
593600
.onEventCallback(contextWithUpdatedSession, registeredEvent)
594-
.defaultIfEmpty(registeredEvent);
595-
})
596-
.toFlowable());
601+
.defaultIfEmpty(registeredEvent))
602+
.toFlowable();
603+
});
597604

598605
// If beforeRunCallback returns content, emit it and skip agent
599606
Context capturedContext = Context.current();

core/src/main/java/com/google/adk/sessions/SessionUtils.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.sessions;
1818

19+
import com.google.adk.events.Event;
1920
import com.google.common.collect.ImmutableList;
2021
import com.google.genai.types.Blob;
2122
import com.google.genai.types.Content;
@@ -31,6 +32,32 @@ public final class SessionUtils {
3132

3233
public SessionUtils() {}
3334

35+
/**
36+
* Returns true if an event with the same id is already present in {@code session.events()}.
37+
*
38+
* <p>Used to deduplicate {@code appendEvent} calls when the same event flows through multiple
39+
* append points (e.g. {@code BaseLlmFlow.run} for a transferred sub-agent and the parent flow, or
40+
* {@code BaseLlmFlow.run} and {@code Runner}).
41+
*/
42+
public static boolean isEventAlreadyAppended(Session session, Event event) {
43+
String eventId = event.id();
44+
if (eventId == null) {
45+
return false;
46+
}
47+
List<Event> events = session.events();
48+
if (events == null || events.isEmpty()) {
49+
return false;
50+
}
51+
synchronized (events) {
52+
for (Event existing : events) {
53+
if (eventId.equals(existing.id())) {
54+
return true;
55+
}
56+
}
57+
}
58+
return false;
59+
}
60+
3461
/** Base64-encodes inline blobs in content. */
3562
public static Content encodeContent(Content content) {
3663
List<Part> encodedParts = new ArrayList<>();

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import static org.mockito.Mockito.when;
3939

4040
import com.google.adk.agents.BaseAgent;
41+
import com.google.adk.agents.Callbacks.AfterModelCallback;
4142
import com.google.adk.agents.InvocationContext;
4243
import com.google.adk.agents.LiveRequestQueue;
4344
import com.google.adk.agents.LlmAgent;
@@ -47,9 +48,14 @@
4748
import com.google.adk.artifacts.BaseArtifactService;
4849
import com.google.adk.events.Event;
4950
import com.google.adk.flows.llmflows.Functions;
51+
import com.google.adk.models.LlmRequest;
5052
import com.google.adk.models.LlmResponse;
5153
import com.google.adk.plugins.BasePlugin;
5254
import com.google.adk.sessions.BaseSessionService;
55+
import com.google.adk.sessions.GetSessionConfig;
56+
import com.google.adk.sessions.InMemorySessionService;
57+
import com.google.adk.sessions.ListEventsResponse;
58+
import com.google.adk.sessions.ListSessionsResponse;
5359
import com.google.adk.sessions.Session;
5460
import com.google.adk.sessions.SessionKey;
5561
import com.google.adk.summarizer.EventsCompactionConfig;
@@ -85,6 +91,7 @@
8591
import java.util.Optional;
8692
import java.util.UUID;
8793
import java.util.concurrent.ConcurrentHashMap;
94+
import java.util.concurrent.ConcurrentMap;
8895
import java.util.concurrent.atomic.AtomicInteger;
8996
import java.util.concurrent.atomic.AtomicReference;
9097
import org.junit.After;
@@ -860,6 +867,146 @@ public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exceptio
860867
subscriber2.assertValue(agentEvent);
861868
}
862869

870+
/**
871+
* A slow appendEvent must not let the next LLM step start with a stale session missing the
872+
* previous step's function-response event.
873+
*/
874+
@Test
875+
public void runAsync_slowAppendEvent_doesNotCauseStaleSessionInNextStep() throws Exception {
876+
TestLlm raceTestLlm =
877+
createTestLlm(
878+
createFunctionCallLlmResponse("call_1", echoTool.name(), ImmutableMap.of("arg", "v1")),
879+
createTextLlmResponse("done"));
880+
881+
LlmAgent agentForRace =
882+
createTestAgentBuilder(raceTestLlm).tools(ImmutableList.of(echoTool)).build();
883+
884+
BaseSessionService delayedSessionService =
885+
new AppendDelayingSessionService(new InMemorySessionService(), 50);
886+
887+
Runner runnerForRace =
888+
Runner.builder()
889+
.app(App.builder().name("test").rootAgent(agentForRace).build())
890+
.sessionService(delayedSessionService)
891+
.build();
892+
Session raceSession =
893+
runnerForRace.sessionService().createSession("test", "user").blockingGet();
894+
895+
var unused =
896+
runnerForRace
897+
.runAsync("user", raceSession.id(), createContent("start"))
898+
.toList()
899+
.blockingGet();
900+
901+
ImmutableList<LlmRequest> requests = ImmutableList.copyOf(raceTestLlm.getRequests());
902+
assertThat(requests).hasSize(2);
903+
904+
// Second LLM request must see the function response from step 1.
905+
boolean foundToolResponse =
906+
requests.get(1).contents().stream()
907+
.flatMap(c -> c.parts().stream().flatMap(List::stream))
908+
.anyMatch(part -> part.functionResponse().isPresent());
909+
assertThat(foundToolResponse).isTrue();
910+
}
911+
912+
/**
913+
* When an LlmAgent transfers control to a sub-LlmAgent, the sub-agent's events flow back up
914+
* through the parent's {@code BaseLlmFlow.run()} pipeline. Each event must be appended to the
915+
* session exactly once.
916+
*/
917+
@Test
918+
public void runAsync_transferToSubAgent_eventsAppendedOnce() throws Exception {
919+
LlmAgent subAgent =
920+
createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response")))
921+
.name("sub-agent")
922+
.build();
923+
924+
// Force a transfer to sub-agent using an afterModelCallback.
925+
AfterModelCallback transferCallback =
926+
(ctx, response) -> {
927+
ctx.eventActions().setTransferToAgent(subAgent.name());
928+
return Maybe.empty();
929+
};
930+
931+
TestLlm rootTestLlm = createTestLlm(createTextLlmResponse("initial"));
932+
LlmAgent rootAgent =
933+
createTestAgentBuilder(rootTestLlm)
934+
.subAgents(subAgent)
935+
.afterModelCallback(ImmutableList.of(transferCallback))
936+
.build();
937+
938+
Runner transferRunner =
939+
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
940+
Session transferSession =
941+
transferRunner.sessionService().createSession("test", "user").blockingGet();
942+
943+
var unused =
944+
transferRunner
945+
.runAsync("user", transferSession.id(), createContent("start"))
946+
.toList()
947+
.blockingGet();
948+
949+
Session finalSession =
950+
transferRunner
951+
.sessionService()
952+
.getSession(
953+
transferSession.appName(),
954+
transferSession.userId(),
955+
transferSession.id(),
956+
Optional.empty())
957+
.blockingGet();
958+
959+
// Each event id should appear at most once in the session.
960+
List<String> eventIds = finalSession.events().stream().map(Event::id).toList();
961+
assertThat(eventIds).containsNoDuplicates();
962+
}
963+
964+
/** {@link BaseSessionService} that delays {@link #appendEvent} to surface ordering bugs. */
965+
private static final class AppendDelayingSessionService implements BaseSessionService {
966+
private final BaseSessionService delegate;
967+
private final long appendDelayMs;
968+
969+
AppendDelayingSessionService(BaseSessionService delegate, long appendDelayMs) {
970+
this.delegate = delegate;
971+
this.appendDelayMs = appendDelayMs;
972+
}
973+
974+
// Delegates to the underlying BaseSessionService createSession overload, which is itself
975+
// deprecated; suppressed because the wrapper must preserve the same signature.
976+
@SuppressWarnings("deprecation")
977+
@Override
978+
public Single<Session> createSession(
979+
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
980+
return delegate.createSession(appName, userId, state, sessionId);
981+
}
982+
983+
@Override
984+
public Maybe<Session> getSession(
985+
String appName, String userId, String sessionId, Optional<GetSessionConfig> config) {
986+
return delegate.getSession(appName, userId, sessionId, config);
987+
}
988+
989+
@Override
990+
public Single<ListSessionsResponse> listSessions(String appName, String userId) {
991+
return delegate.listSessions(appName, userId);
992+
}
993+
994+
@Override
995+
public Completable deleteSession(String appName, String userId, String sessionId) {
996+
return delegate.deleteSession(appName, userId, sessionId);
997+
}
998+
999+
@Override
1000+
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
1001+
return delegate.listEvents(appName, userId, sessionId);
1002+
}
1003+
1004+
@Override
1005+
public Single<Event> appendEvent(Session session, Event event) {
1006+
return delegate.appendEvent(session, event).delay(appendDelayMs, MILLISECONDS);
1007+
}
1008+
}
1009+
8631010
@Test
8641011
public void runAsync_withSessionKey_success() {
8651012
var events =

0 commit comments

Comments
 (0)