Skip to content

Commit f5536cb

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: wait for the Runner to persist a step's events before the ADK flow's next step (sequential-tool-execution race)
`BaseLlmFlow.run()` builds each step's request from the session, but the `Runner` persists events asynchronously downstream of the flow. A slow `appendEvent` could let the next step start from a stale session missing the prior step's events (e.g. a tool's function response), making the model re-call the tool or hallucinate its result. The `Runner` stays the sole `appendEvent` caller and the flow waits: it calls `PersistBarrier.markPersisted(id)` after each append, and the flow calls `PersistBarrier.awaitPersisted(stepEvents)` between steps. The barrier is a reactive per-event signal in the shared `InvocationContext.callbackContextData` and never blocks a thread; `Contents` is unchanged. `PersistBarrier.enable()`, called by the `Runner`, keeps `awaitPersisted` a no-op when the flow runs without a `Runner`. The barrier self-cleans: a persisted id is recorded in a lightweight set, while a `CompletableSubject` is kept only for an id awaited before it is persisted and dropped once persisted -- so the subject map does not grow with invocation length (an id can be awaited at several flow levels, e.g. across an agent transfer). It is thread-safe (concurrent maps plus a register-then-recheck handshake), since `markPersisted` may run off-thread when an async `appendEvent` completes. PiperOrigin-RevId: 921989444
1 parent 9a06dd3 commit f5536cb

5 files changed

Lines changed: 652 additions & 1 deletion

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,10 @@ private Flowable<Event> run(
523523
return Flowable.empty();
524524
} else {
525525
logger.debug("Continuing to next step of the flow.");
526-
return run(spanContext, invocationContext, stepsCompleted + 1);
526+
// Wait until the Runner has persisted this step's events so the next step's
527+
// request is not built from a stale session (see PersistBarrier).
528+
return PersistBarrier.awaitPersisted(invocationContext, eventList)
529+
.andThen(run(spanContext, invocationContext, stepsCompleted + 1));
527530
}
528531
}));
529532
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.flows.llmflows;
18+
19+
import com.google.adk.agents.InvocationContext;
20+
import com.google.adk.events.Event;
21+
import com.google.common.annotations.VisibleForTesting;
22+
import io.reactivex.rxjava3.core.Completable;
23+
import io.reactivex.rxjava3.subjects.CompletableSubject;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Set;
27+
import java.util.concurrent.ConcurrentHashMap;
28+
29+
/**
30+
* Lets {@link BaseLlmFlow}'s multi-step loop wait until the {@code Runner} -- the sole event
31+
* persister -- has appended the current step's events, so the next step's request (built from
32+
* {@code session.events()} by {@link Contents}) is not assembled from a stale session. The {@code
33+
* Runner} calls {@link #markPersisted} after each append; the flow calls {@link #awaitPersisted}
34+
* between steps. State lives in the per-invocation {@link InvocationContext#callbackContextData()}
35+
* map, shared across the agent tree.
36+
*
37+
* <p>An event may be awaited at several flow levels (an agent transfer concatenates a sub-agent's
38+
* events into the parent's step) but is persisted once, so persisted ids are recorded in a {@link
39+
* Set} and a {@link CompletableSubject} is kept only for an id awaited before it is persisted, then
40+
* dropped -- the subject map self-drains instead of growing with the invocation.
41+
*
42+
* <p>Thread-safe: {@code markPersisted} may run off-thread (async {@code appendEvent}) concurrently
43+
* with {@code awaitPersisted}; the concurrent maps plus the register-then-recheck in {@link
44+
* #awaitOne} close that race without locking.
45+
*/
46+
public final class PersistBarrier {
47+
48+
private static final String ENABLED_KEY = "com.google.adk.flows.llmflows.persistBarrier.enabled";
49+
private static final String PERSISTED_KEY =
50+
"com.google.adk.flows.llmflows.persistBarrier.persisted";
51+
private static final String PENDING_KEY = "com.google.adk.flows.llmflows.persistBarrier.pending";
52+
53+
private PersistBarrier() {}
54+
55+
/**
56+
* Marks that a {@code Runner} is driving this invocation and will {@link #markPersisted} each
57+
* appended event. Otherwise (flow run directly, e.g. unit tests) {@link #awaitPersisted} is a
58+
* no-op, avoiding a deadlock waiting for a signal that never comes.
59+
*/
60+
public static void enable(InvocationContext context) {
61+
context.callbackContextData().put(ENABLED_KEY, true);
62+
}
63+
64+
/**
65+
* Completes once every event in {@code events} has been {@link #markPersisted}, or immediately if
66+
* the barrier was never {@link #enable}d. Already-persisted events complete immediately, so the
67+
* order of {@code awaitPersisted}/{@code markPersisted} does not matter.
68+
*/
69+
public static Completable awaitPersisted(InvocationContext context, List<Event> events) {
70+
if (!Boolean.TRUE.equals(context.callbackContextData().get(ENABLED_KEY))) {
71+
return Completable.complete();
72+
}
73+
Completable result = Completable.complete();
74+
for (Event event : events) {
75+
String eventId = event.id();
76+
if (eventId != null) {
77+
result = result.andThen(awaitOne(context, eventId));
78+
}
79+
}
80+
return result;
81+
}
82+
83+
private static Completable awaitOne(InvocationContext context, String eventId) {
84+
Set<String> persisted = persistedIds(context);
85+
if (persisted.contains(eventId)) {
86+
return Completable.complete();
87+
}
88+
CompletableSubject subject =
89+
pending(context).computeIfAbsent(eventId, unusedKey -> CompletableSubject.create());
90+
// Re-check after registering so a markPersisted that raced in between is not lost.
91+
if (persisted.contains(eventId)) {
92+
pending(context).remove(eventId);
93+
return Completable.complete();
94+
}
95+
return subject;
96+
}
97+
98+
/** Signals that the {@code Runner} has finished persisting the event with the given id. */
99+
public static void markPersisted(InvocationContext context, String eventId) {
100+
if (eventId == null) {
101+
return;
102+
}
103+
persistedIds(context).add(eventId);
104+
CompletableSubject subject = pending(context).remove(eventId);
105+
if (subject != null) {
106+
subject.onComplete();
107+
}
108+
}
109+
110+
/** Awaited-but-not-yet-persisted events; drains to 0 once a step's events are persisted. */
111+
@VisibleForTesting
112+
static int pendingCount(InvocationContext context) {
113+
return pending(context).size();
114+
}
115+
116+
@SuppressWarnings("unchecked")
117+
private static Set<String> persistedIds(InvocationContext context) {
118+
return (Set<String>)
119+
context
120+
.callbackContextData()
121+
.computeIfAbsent(PERSISTED_KEY, unusedKey -> ConcurrentHashMap.<String>newKeySet());
122+
}
123+
124+
@SuppressWarnings("unchecked")
125+
private static Map<String, CompletableSubject> pending(InvocationContext context) {
126+
return (Map<String, CompletableSubject>)
127+
context
128+
.callbackContextData()
129+
.computeIfAbsent(
130+
PENDING_KEY, unusedKey -> new ConcurrentHashMap<String, CompletableSubject>());
131+
}
132+
}

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.adk.artifacts.InMemoryArtifactService;
3131
import com.google.adk.events.Event;
3232
import com.google.adk.events.EventActions;
33+
import com.google.adk.flows.llmflows.PersistBarrier;
3334
import com.google.adk.memory.BaseMemoryService;
3435
import com.google.adk.models.Model;
3536
import com.google.adk.plugins.Plugin;
@@ -575,6 +576,9 @@ private Flowable<Event> runAgentWithUpdatedSession(
575576
.content(content)
576577
.build());
577578

579+
// Let BaseLlmFlow block each step until this Runner has persisted the prior step's events.
580+
PersistBarrier.enable(contextWithUpdatedSession);
581+
578582
// Agent execution
579583
Flowable<Event> agentEvents =
580584
contextWithUpdatedSession
@@ -584,6 +588,12 @@ private Flowable<Event> runAgentWithUpdatedSession(
584588
agentEvent ->
585589
this.sessionService
586590
.appendEvent(updatedSession, agentEvent)
591+
// Signal persistence so BaseLlmFlow's loop can release the next step; the
592+
// Runner stays the sole appendEvent caller (see PersistBarrier).
593+
.doOnSuccess(
594+
unusedEvent ->
595+
PersistBarrier.markPersisted(
596+
contextWithUpdatedSession, agentEvent.id()))
587597
.flatMap(
588598
registeredEvent -> {
589599
// TODO: remove this hack after deprecating runAsync with Session.
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.flows.llmflows;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.mockito.Mockito.mock;
21+
22+
import com.google.adk.agents.BaseAgent;
23+
import com.google.adk.agents.InvocationContext;
24+
import com.google.adk.events.Event;
25+
import com.google.adk.sessions.BaseSessionService;
26+
import com.google.adk.sessions.Session;
27+
import com.google.common.collect.ImmutableList;
28+
import io.reactivex.rxjava3.observers.TestObserver;
29+
import java.util.ArrayList;
30+
import java.util.Collections;
31+
import java.util.List;
32+
import java.util.concurrent.CountDownLatch;
33+
import org.junit.Before;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.JUnit4;
37+
38+
@RunWith(JUnit4.class)
39+
public final class PersistBarrierTest {
40+
41+
private InvocationContext context;
42+
43+
@Before
44+
public void setUp() {
45+
context =
46+
InvocationContext.builder()
47+
.sessionService(mock(BaseSessionService.class))
48+
.invocationId("inv-1")
49+
.agent(mock(BaseAgent.class))
50+
.session(Session.builder("s").build())
51+
.build();
52+
}
53+
54+
private static Event event(String id) {
55+
return Event.builder().id(id).author("agent").build();
56+
}
57+
58+
@Test
59+
public void awaitBeforeMark_completesOnMark_andDrainsPending() {
60+
PersistBarrier.enable(context);
61+
62+
TestObserver<Void> observer =
63+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
64+
65+
observer.assertNotComplete();
66+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(1);
67+
68+
PersistBarrier.markPersisted(context, "e1");
69+
70+
observer.assertComplete();
71+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
72+
}
73+
74+
@Test
75+
public void markBeforeAwait_completesImmediately_noPending() {
76+
PersistBarrier.enable(context);
77+
78+
PersistBarrier.markPersisted(context, "e1");
79+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
80+
81+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
82+
}
83+
84+
@Test
85+
public void sameEventAwaitedTwice_secondAwaitStillCompletes_andNothingLingers() {
86+
// Mirrors an agent transfer: a sub-agent event is awaited by both the sub-agent and parent
87+
// flows but persisted once; the second await must still complete.
88+
PersistBarrier.enable(context);
89+
90+
TestObserver<Void> subLevel =
91+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
92+
PersistBarrier.markPersisted(context, "e1");
93+
subLevel.assertComplete();
94+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
95+
96+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
97+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
98+
}
99+
100+
@Test
101+
public void multiEventStep_completesOnlyAfterAllMarked() {
102+
PersistBarrier.enable(context);
103+
104+
TestObserver<Void> observer =
105+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"), event("e2"))).test();
106+
107+
PersistBarrier.markPersisted(context, "e1");
108+
observer.assertNotComplete();
109+
110+
PersistBarrier.markPersisted(context, "e2");
111+
observer.assertComplete();
112+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
113+
}
114+
115+
@Test
116+
public void notEnabled_awaitIsNoOp() {
117+
// No enable(): flow runs without a Runner, so await must not block forever.
118+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
119+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
120+
}
121+
122+
@Test
123+
public void concurrentAwaitAndMark_allComplete_andDrain() throws Exception {
124+
// awaitPersisted (flow thread) and markPersisted (async appendEvent thread) race on each id;
125+
// none may be stranded and every subject must be dropped.
126+
PersistBarrier.enable(context);
127+
int eventCount = 1000;
128+
List<String> ids = new ArrayList<>();
129+
for (int i = 0; i < eventCount; i++) {
130+
ids.add("e" + i);
131+
}
132+
List<TestObserver<Void>> observers = Collections.synchronizedList(new ArrayList<>());
133+
CountDownLatch start = new CountDownLatch(1);
134+
135+
Thread awaiter =
136+
new Thread(
137+
() -> {
138+
awaitQuietly(start);
139+
for (String id : ids) {
140+
observers.add(
141+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event(id))).test());
142+
}
143+
});
144+
Thread marker =
145+
new Thread(
146+
() -> {
147+
awaitQuietly(start);
148+
for (String id : ids) {
149+
PersistBarrier.markPersisted(context, id);
150+
}
151+
});
152+
153+
awaiter.start();
154+
marker.start();
155+
start.countDown();
156+
awaiter.join();
157+
marker.join();
158+
159+
for (TestObserver<Void> observer : observers) {
160+
observer.assertComplete();
161+
}
162+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
163+
}
164+
165+
private static void awaitQuietly(CountDownLatch latch) {
166+
try {
167+
latch.await();
168+
} catch (InterruptedException e) {
169+
Thread.currentThread().interrupt();
170+
throw new AssertionError(e);
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)