Skip to content

Commit eff8f8d

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
BaseLlmFlow.run() now appends each event synchronously inside the per-step concatMap, so the next runOneStep sees the previous step's events. Without this, a tool relying on prior events (e.g. a BeforeToolCallback producing a function response) could see stale history and re-call the tool or hallucinate its result. PiperOrigin-RevId: 921989444
1 parent 29d3203 commit eff8f8d

6 files changed

Lines changed: 601 additions & 16 deletions

File tree

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,9 @@ private static boolean isThought(Part part) {
654654

655655
@Override
656656
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
657-
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
657+
// maybeSaveOutputToState runs as a pre-persist finalizer so the outputKey stateDelta is
658+
// part of the persisted append performed inside BaseLlmFlow.run.
659+
return llmFlow.run(invocationContext, this::maybeSaveOutputToState);
658660
}
659661

660662
@Override

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.adk.models.LlmRegistry;
3838
import com.google.adk.models.LlmRequest;
3939
import com.google.adk.models.LlmResponse;
40+
import com.google.adk.sessions.SessionUtils;
4041
import com.google.adk.telemetry.Tracing;
4142
import com.google.adk.tools.BaseTool;
4243
import com.google.adk.tools.BaseToolset;
@@ -60,8 +61,10 @@
6061
import java.util.List;
6162
import java.util.Optional;
6263
import java.util.Set;
64+
import java.util.concurrent.ConcurrentHashMap;
6365
import java.util.concurrent.atomic.AtomicReference;
6466
import java.util.function.BiFunction;
67+
import java.util.function.Consumer;
6568
import org.slf4j.Logger;
6669
import org.slf4j.LoggerFactory;
6770

@@ -498,12 +501,54 @@ private Flowable<Event> runOneStep(Context spanContext, InvocationContext contex
498501
*/
499502
@Override
500503
public Flowable<Event> run(InvocationContext invocationContext) {
501-
return run(Context.current(), invocationContext, 0);
504+
return run(invocationContext, event -> {});
505+
}
506+
507+
/**
508+
* Same as {@link #run(InvocationContext)} but invokes {@code eventPreFinalize} on each event
509+
* authored by this agent immediately before it is appended, so any mutation (e.g. {@code
510+
* stateDelta} updates from {@link LlmAgent}'s {@code outputKey}) is part of the persisted append.
511+
*/
512+
public Flowable<Event> run(
513+
InvocationContext invocationContext, Consumer<Event> eventPreFinalize) {
514+
return run(Context.current(), invocationContext, 0, eventPreFinalize);
502515
}
503516

504517
private Flowable<Event> run(
505-
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
506-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
518+
Context spanContext,
519+
InvocationContext invocationContext,
520+
int stepsCompleted,
521+
Consumer<Event> eventPreFinalize) {
522+
// Append each event synchronously so the next runOneStep sees prior events (avoids racing the
523+
// Runner's append). Record each appended id so the Runner skips re-appending it; skip ids
524+
// already recorded (e.g. bubbling up from a sub-agent's flow). Emit the original event, not the
525+
// service's return (which may be a mock sentinel).
526+
String thisAgentName = invocationContext.agent().name();
527+
Flowable<Event> currentStepEvents =
528+
runOneStep(spanContext, invocationContext)
529+
.concatMap(
530+
event -> {
531+
String eventId = event.id();
532+
if (eventId != null
533+
&& inFlightAppendedEventIds(invocationContext).contains(eventId)) {
534+
return Flowable.just(event);
535+
}
536+
if (thisAgentName != null && thisAgentName.equals(event.author())) {
537+
eventPreFinalize.accept(event);
538+
}
539+
return SessionUtils.safeAppendEvent(
540+
invocationContext.sessionService(), invocationContext.session(), event)
541+
.ignoreElement()
542+
.andThen(
543+
Flowable.fromCallable(
544+
() -> {
545+
if (eventId != null) {
546+
inFlightAppendedEventIds(invocationContext).add(eventId);
547+
}
548+
return event;
549+
}));
550+
})
551+
.cache();
507552
if (stepsCompleted + 1 >= maxSteps) {
508553
logger.debug("Ending flow execution because max steps reached.");
509554
return currentStepEvents;
@@ -523,11 +568,30 @@ private Flowable<Event> run(
523568
return Flowable.empty();
524569
} else {
525570
logger.debug("Continuing to next step of the flow.");
526-
return run(spanContext, invocationContext, stepsCompleted + 1);
571+
return run(
572+
spanContext, invocationContext, stepsCompleted + 1, eventPreFinalize);
527573
}
528574
}));
529575
}
530576

577+
private static final String IN_FLIGHT_APPENDED_EVENT_IDS_KEY =
578+
"com.google.adk.internal.inFlightAppendedEventIds";
579+
580+
/**
581+
* Returns the transient, per-invocation set of event ids appended by the flow but not yet
582+
* consumed by the Runner, lazily creating it. Ids are added here on append and removed by the
583+
* Runner on consume, so this is hand-off state -- not a record of all persisted events.
584+
*/
585+
@SuppressWarnings("unchecked")
586+
private static Set<String> inFlightAppendedEventIds(InvocationContext invocationContext) {
587+
return (Set<String>)
588+
invocationContext
589+
.callbackContextData()
590+
.computeIfAbsent(
591+
IN_FLIGHT_APPENDED_EVENT_IDS_KEY,
592+
unusedKey -> ConcurrentHashMap.<String>newKeySet());
593+
}
594+
531595
/**
532596
* Executes the LLM flow in streaming mode.
533597
*

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import com.google.adk.sessions.InMemorySessionService;
3939
import com.google.adk.sessions.Session;
4040
import com.google.adk.sessions.SessionKey;
41+
import com.google.adk.sessions.SessionUtils;
4142
import com.google.adk.summarizer.EventsCompactionConfig;
4243
import com.google.adk.summarizer.LlmEventSummarizer;
4344
import com.google.adk.summarizer.SlidingWindowEventCompactor;
@@ -70,6 +71,7 @@
7071
import java.util.Map;
7172
import java.util.Objects;
7273
import java.util.Optional;
74+
import java.util.Set;
7375
import java.util.concurrent.ConcurrentHashMap;
7476
import java.util.concurrent.ConcurrentMap;
7577
import org.jspecify.annotations.Nullable;
@@ -581,19 +583,32 @@ private Flowable<Event> runAgentWithUpdatedSession(
581583
.agent()
582584
.runAsync(contextWithUpdatedSession)
583585
.concatMap(
584-
agentEvent ->
585-
this.sessionService
586-
.appendEvent(updatedSession, agentEvent)
587-
.flatMap(
588-
registeredEvent -> {
589-
// TODO: remove this hack after deprecating runAsync with Session.
590-
copySessionStates(updatedSession, initialContext.session());
591-
return contextWithUpdatedSession
586+
agentEvent -> {
587+
// TODO: remove this hack after deprecating runAsync with Session.
588+
copySessionStates(updatedSession, initialContext.session());
589+
// Skip events already appended by BaseLlmFlow.run (id recorded in the shared
590+
// set),
591+
// removing the id to keep the set bounded. Append everything else (agent-callback
592+
// events, non-LlmAgent leaves) here, so each event is appended exactly once.
593+
String agentEventId = agentEvent.id();
594+
boolean alreadyPersisted =
595+
agentEventId != null
596+
&& inFlightAppendedEventIds(contextWithUpdatedSession)
597+
.remove(agentEventId);
598+
Single<Event> appendResult =
599+
alreadyPersisted
600+
? Single.just(agentEvent)
601+
: SessionUtils.safeAppendEvent(
602+
this.sessionService, updatedSession, agentEvent);
603+
return appendResult
604+
.flatMap(
605+
registeredEvent ->
606+
contextWithUpdatedSession
592607
.pluginManager()
593608
.onEventCallback(contextWithUpdatedSession, registeredEvent)
594-
.defaultIfEmpty(registeredEvent);
595-
})
596-
.toFlowable());
609+
.defaultIfEmpty(registeredEvent))
610+
.toFlowable();
611+
});
597612

598613
// If beforeRunCallback returns content, emit it and skip agent
599614
Context capturedContext = Context.current();
@@ -619,6 +634,24 @@ private void copySessionStates(Session source, Session target) {
619634
target.state().putAll(source.state());
620635
}
621636

637+
private static final String IN_FLIGHT_APPENDED_EVENT_IDS_KEY =
638+
"com.google.adk.internal.inFlightAppendedEventIds";
639+
640+
/**
641+
* Returns the transient, per-invocation set of event ids appended by the flow but not yet
642+
* consumed by the Runner, lazily creating it. Ids are added by the flow on append and removed
643+
* here on consume, so this is hand-off state -- not a record of all persisted events.
644+
*/
645+
@SuppressWarnings("unchecked")
646+
private static Set<String> inFlightAppendedEventIds(InvocationContext invocationContext) {
647+
return (Set<String>)
648+
invocationContext
649+
.callbackContextData()
650+
.computeIfAbsent(
651+
IN_FLIGHT_APPENDED_EVENT_IDS_KEY,
652+
unusedKey -> ConcurrentHashMap.<String>newKeySet());
653+
}
654+
622655
/**
623656
* Creates an {@link InvocationContext} for a live (streaming) run.
624657
*

core/src/main/java/com/google/adk/sessions/SessionUtils.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
package com.google.adk.sessions;
1818

19+
import com.google.adk.events.Event;
1920
import com.google.common.collect.ImmutableList;
2021
import com.google.genai.types.Blob;
2122
import com.google.genai.types.Content;
2223
import com.google.genai.types.Part;
24+
import io.reactivex.rxjava3.core.Single;
2325
import java.util.ArrayList;
2426
import java.util.Base64;
2527
import java.util.List;
@@ -31,6 +33,27 @@ public final class SessionUtils {
3133

3234
public SessionUtils() {}
3335

36+
/**
37+
* Appends {@code event} via {@code service}, or just to {@code session.events()} when the session
38+
* is partial (no {@code appName}). The partial-session bypass exists for unit tests that build
39+
* {@code Session.builder(id).build()} and bypass {@code Runner}; most production services
40+
* (including {@link InMemorySessionService}) {@code requireNonNull(appName)}. Production callers
41+
* always pass fully-formed sessions and hit the unchanged {@code service.appendEvent} path.
42+
*/
43+
public static Single<Event> safeAppendEvent(
44+
BaseSessionService service, Session session, Event event) {
45+
if (session.appName() == null) {
46+
List<Event> events = session.events();
47+
if (events != null) {
48+
synchronized (events) {
49+
events.add(event);
50+
}
51+
}
52+
return Single.just(event);
53+
}
54+
return service.appendEvent(session, event);
55+
}
56+
3457
/** Base64-encodes inline blobs in content. */
3558
public static Content encodeContent(Content content) {
3659
List<Part> encodedParts = new ArrayList<>();

core/src/test/java/com/google/adk/agents/LlmAgentTest.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939
import com.google.adk.models.LlmRequest;
4040
import com.google.adk.models.LlmResponse;
4141
import com.google.adk.models.Model;
42+
import com.google.adk.sessions.BaseSessionService;
43+
import com.google.adk.sessions.GetSessionConfig;
4244
import com.google.adk.sessions.InMemorySessionService;
45+
import com.google.adk.sessions.ListEventsResponse;
46+
import com.google.adk.sessions.ListSessionsResponse;
4347
import com.google.adk.sessions.Session;
4448
import com.google.adk.telemetry.Tracing;
4549
import com.google.adk.testing.TestLlm;
@@ -58,11 +62,14 @@
5862
import io.opentelemetry.api.trace.Tracer;
5963
import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule;
6064
import io.opentelemetry.sdk.trace.data.SpanData;
65+
import io.reactivex.rxjava3.core.Completable;
6166
import io.reactivex.rxjava3.core.Flowable;
6267
import io.reactivex.rxjava3.core.Maybe;
6368
import io.reactivex.rxjava3.core.Single;
6469
import java.util.List;
70+
import java.util.Optional;
6571
import java.util.concurrent.ConcurrentHashMap;
72+
import java.util.concurrent.ConcurrentMap;
6673
import java.util.concurrent.atomic.AtomicBoolean;
6774
import org.junit.After;
6875
import org.junit.Before;
@@ -184,6 +191,113 @@ public void testRun_withoutOutputKey_doesNotSaveState() {
184191
assertThat(events.get(0).actions().stateDelta()).isEmpty();
185192
}
186193

194+
/**
195+
* Partial-session bypass: tests that build {@code Session.builder(id).build()} and call {@code
196+
* agent.runAsync} directly (bypassing Runner) must not trip the {@code requireNonNull(appName)}
197+
* in {@code InMemorySessionService.appendEvent}. The event is tracked in {@code session.events()}
198+
* so subsequent steps see prior events. Surfaced originally by orcas {@code LlmAgentActionTest},
199+
* dataworkeragent and asterix small tests.
200+
*/
201+
@Test
202+
public void testRun_partialSessionWithoutAppName_doesNotThrow() {
203+
Content modelContent = Content.fromParts(Part.fromText("Agent Response"));
204+
TestLlm testLlm = createTestLlm(createLlmResponse(modelContent));
205+
LlmAgent agent = createTestAgentBuilder(testLlm).build();
206+
207+
Session partialSession = Session.builder("session-id").build();
208+
InvocationContext invocationContext =
209+
InvocationContext.builder()
210+
.sessionService(new InMemorySessionService())
211+
.agent(agent)
212+
.session(partialSession)
213+
.invocationId("invocation-id")
214+
.runConfig(RunConfig.builder().build())
215+
.userContent(Content.fromParts(Part.fromText("hello")))
216+
.build();
217+
218+
List<Event> events = agent.runAsync(invocationContext).toList().blockingGet();
219+
220+
assertThat(events).hasSize(1);
221+
assertThat(events.get(0).content()).hasValue(modelContent);
222+
// Event tracked in session.events() so subsequent steps see prior events.
223+
assertThat(partialSession.events()).hasSize(1);
224+
assertThat(partialSession.events().get(0).id()).isEqualTo(events.get(0).id());
225+
}
226+
227+
/**
228+
* Mirrors ads-ux researchagent {@code GenerateReportFromSourcesActionTest}: an action drives
229+
* {@code agent.runAsync(...).collect { ... }} directly (bypassing Runner) with a {@link
230+
* BaseSessionService} stub whose {@code appendEvent} returns a sentinel empty Event. The original
231+
* LLM-derived event must flow downstream with its content intact, not be swapped for the
232+
* service's return value.
233+
*/
234+
@Test
235+
public void testRun_appendEventReturnsSentinel_originalEventFlowsDownstream() {
236+
Content modelContent = Content.fromParts(Part.fromText("generated report content"));
237+
TestLlm testLlm = createTestLlm(createLlmResponse(modelContent));
238+
LlmAgent agent = createTestAgentBuilder(testLlm).build();
239+
240+
BaseSessionService sentinelReturningSessionService = new SentinelReturningSessionService();
241+
// appName set so safeAppendEvent calls the service (we want to verify its return is ignored).
242+
Session sessionWithAppName =
243+
Session.builder("session-id").appName("test").userId("user").build();
244+
InvocationContext invocationContext =
245+
InvocationContext.builder()
246+
.sessionService(sentinelReturningSessionService)
247+
.agent(agent)
248+
.session(sessionWithAppName)
249+
.invocationId("invocation-id")
250+
.runConfig(RunConfig.builder().build())
251+
.userContent(Content.fromParts(Part.fromText("hello")))
252+
.build();
253+
254+
List<Event> events = agent.runAsync(invocationContext).toList().blockingGet();
255+
256+
assertThat(events).hasSize(1);
257+
// Must be the original LLM-derived event, not the sentinel returned by appendEvent.
258+
assertThat(events.get(0).content()).hasValue(modelContent);
259+
}
260+
261+
/**
262+
* Stub returning an empty Event from {@code appendEvent}, mirroring the shape used by ads-ux
263+
* researchagent and nbu paisa tests ({@code
264+
* Mockito.when(...).thenReturn(Event.builder().build())}).
265+
*/
266+
private static final class SentinelReturningSessionService implements BaseSessionService {
267+
@Override
268+
public Single<Session> createSession(
269+
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
270+
return Single.just(Session.builder("session-id").build());
271+
}
272+
273+
@Override
274+
public Maybe<Session> getSession(
275+
String appName, String userId, String sessionId, Optional<GetSessionConfig> config) {
276+
return Maybe.just(Session.builder(sessionId).build());
277+
}
278+
279+
@Override
280+
public Single<ListSessionsResponse> listSessions(String appName, String userId) {
281+
return Single.just(ListSessionsResponse.builder().build());
282+
}
283+
284+
@Override
285+
public Completable deleteSession(String appName, String userId, String sessionId) {
286+
return Completable.complete();
287+
}
288+
289+
@Override
290+
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
291+
return Single.just(ListEventsResponse.builder().build());
292+
}
293+
294+
@Override
295+
public Single<Event> appendEvent(Session session, Event event) {
296+
// Sentinel return value, mirroring downstream test mocks.
297+
return Single.just(Event.builder().build());
298+
}
299+
}
300+
187301
@Test
188302
public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() {
189303
ImmutableMap<String, Object> echoArgs = ImmutableMap.of("arg", "value");

0 commit comments

Comments
 (0)