Skip to content

Commit 1e004b6

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: wait for the Runner to persist a step's events before the ADK flow's next step (sequential-tool-execution race)
`BaseLlmFlow.run()` builds each step's request from the session, but the `Runner` persists events asynchronously downstream of the flow. A slow `appendEvent` could let the next step start from a stale session missing the prior step's events (e.g. a tool's function response), making the model re-call the tool or hallucinate its result. The `Runner` stays the sole `appendEvent` caller and the flow waits: it calls `PersistBarrier.markPersisted(id)` after each append (or `markFailed(id, error)` if it fails), and the flow calls `PersistBarrier.awaitPersisted(stepEvents)` between steps. The barrier is a reactive per-event signal in the shared `InvocationContext.callbackContextData` and never blocks a thread; `Contents` is unchanged. `PersistBarrier.enable()`, called by the `Runner`, keeps `awaitPersisted` a no-op when the flow runs without a `Runner`. Each event id maps to a `CompletableSubject`: pending until its append finishes, then terminally completed or failed. The subject retains its terminal state, so `awaitPersisted`/`mark*` may happen in any order and a late await -- e.g. at a higher flow level across an agent transfer -- resolves immediately. If an append fails, the matching await fails rather than blocking forever. It is thread-safe and lock-free: `markPersisted`/`markFailed` may run off-thread when an async `appendEvent` completes, and `ConcurrentHashMap.computeIfAbsent` hands both sides the same subject. PiperOrigin-RevId: 921989444
1 parent 9a06dd3 commit 1e004b6

5 files changed

Lines changed: 697 additions & 1 deletion

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,10 @@ private Flowable<Event> run(
523523
return Flowable.empty();
524524
} else {
525525
logger.debug("Continuing to next step of the flow.");
526-
return run(spanContext, invocationContext, stepsCompleted + 1);
526+
// Wait until the Runner has persisted this step's events so the next step's
527+
// request is not built from a stale session (see PersistBarrier).
528+
return PersistBarrier.awaitPersisted(invocationContext, eventList)
529+
.andThen(run(spanContext, invocationContext, stepsCompleted + 1));
527530
}
528531
}));
529532
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.flows.llmflows;
18+
19+
import com.google.adk.agents.InvocationContext;
20+
import com.google.adk.events.Event;
21+
import com.google.common.annotations.VisibleForTesting;
22+
import io.reactivex.rxjava3.core.Completable;
23+
import io.reactivex.rxjava3.subjects.CompletableSubject;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.concurrent.ConcurrentHashMap;
27+
28+
/**
29+
* Lets {@link BaseLlmFlow}'s multi-step loop wait until the {@code Runner} -- the sole event
30+
* persister -- has appended the current step's events, so the next step's request (built from
31+
* {@code session.events()} by {@link Contents}) is not assembled from a stale session. The {@code
32+
* Runner} calls {@link #markPersisted} (or {@link #markFailed}) after each append; the flow calls
33+
* {@link #awaitPersisted} between steps. State lives in the per-invocation {@link
34+
* InvocationContext#callbackContextData()} map, shared across the agent tree.
35+
*
36+
* <p>Each event id maps to a {@link CompletableSubject}: pending until its append finishes, then
37+
* terminally completed or failed. The subject retains its terminal state, so {@code
38+
* awaitPersisted}/{@code mark*} may happen in any order and a late await -- e.g. at a higher flow
39+
* level across an agent transfer -- resolves immediately. If an append fails, the matching await
40+
* fails with that error rather than blocking forever.
41+
*
42+
* <p>Thread-safe and lock-free: {@code markPersisted}/{@code markFailed} may run off-thread (async
43+
* {@code appendEvent}) concurrently with {@code awaitPersisted}; {@link
44+
* java.util.concurrent.ConcurrentHashMap#computeIfAbsent} hands both sides the same subject, which
45+
* itself serializes its terminal signal against subscription.
46+
*/
47+
public final class PersistBarrier {
48+
49+
private static final String ENABLED_KEY = "com.google.adk.flows.llmflows.persistBarrier.enabled";
50+
private static final String BARRIERS_KEY =
51+
"com.google.adk.flows.llmflows.persistBarrier.barriers";
52+
53+
private PersistBarrier() {}
54+
55+
/**
56+
* Marks that a {@code Runner} is driving this invocation and will resolve each appended event.
57+
* Otherwise (flow run directly, e.g. unit tests) {@link #awaitPersisted} is a no-op, avoiding a
58+
* deadlock waiting for a signal that never comes.
59+
*/
60+
public static void enable(InvocationContext context) {
61+
context.callbackContextData().put(ENABLED_KEY, true);
62+
}
63+
64+
/**
65+
* Completes once every event in {@code events} has been {@link #markPersisted}, or fails if any
66+
* was {@link #markFailed}; completes immediately if the barrier was never {@link #enable}d.
67+
* Already-resolved events resolve immediately, so the order of {@code awaitPersisted}/{@code
68+
* mark*} does not matter.
69+
*/
70+
public static Completable awaitPersisted(InvocationContext context, List<Event> events) {
71+
Boolean enabled = (Boolean) context.callbackContextData().get(ENABLED_KEY);
72+
if (enabled == null || !enabled) {
73+
return Completable.complete();
74+
}
75+
Completable result = Completable.complete();
76+
for (Event event : events) {
77+
String eventId = event.id();
78+
if (eventId != null) {
79+
result = result.andThen(barrier(context, eventId));
80+
}
81+
}
82+
return result;
83+
}
84+
85+
/** Signals that the {@code Runner} persisted the event with the given id. */
86+
public static void markPersisted(InvocationContext context, String eventId) {
87+
if (eventId != null) {
88+
barrier(context, eventId).onComplete();
89+
}
90+
}
91+
92+
/**
93+
* Signals that persisting the event with the given id failed, so an await on it fails with {@code
94+
* error} instead of blocking forever.
95+
*/
96+
public static void markFailed(InvocationContext context, String eventId, Throwable error) {
97+
if (eventId != null) {
98+
barrier(context, eventId).onError(error);
99+
}
100+
}
101+
102+
/**
103+
* The per-event subject, created on first use. {@code computeIfAbsent} is atomic, so an awaiter
104+
* and a concurrent mark share one subject regardless of order.
105+
*/
106+
private static CompletableSubject barrier(InvocationContext context, String eventId) {
107+
return barriers(context).computeIfAbsent(eventId, unusedKey -> CompletableSubject.create());
108+
}
109+
110+
/** Awaited-but-unresolved events; drains to 0 once a step's events are persisted or failed. */
111+
@VisibleForTesting
112+
static int pendingCount(InvocationContext context) {
113+
int pending = 0;
114+
for (CompletableSubject barrier : barriers(context).values()) {
115+
if (!barrier.hasComplete() && !barrier.hasThrowable()) {
116+
pending++;
117+
}
118+
}
119+
return pending;
120+
}
121+
122+
// Safe: BARRIERS_KEY only ever holds the Map created here.
123+
@SuppressWarnings("unchecked")
124+
private static Map<String, CompletableSubject> barriers(InvocationContext context) {
125+
return (Map<String, CompletableSubject>)
126+
context
127+
.callbackContextData()
128+
.computeIfAbsent(
129+
BARRIERS_KEY, unusedKey -> new ConcurrentHashMap<String, CompletableSubject>());
130+
}
131+
}

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.adk.artifacts.InMemoryArtifactService;
3131
import com.google.adk.events.Event;
3232
import com.google.adk.events.EventActions;
33+
import com.google.adk.flows.llmflows.PersistBarrier;
3334
import com.google.adk.memory.BaseMemoryService;
3435
import com.google.adk.models.Model;
3536
import com.google.adk.plugins.Plugin;
@@ -575,6 +576,9 @@ private Flowable<Event> runAgentWithUpdatedSession(
575576
.content(content)
576577
.build());
577578

579+
// Let BaseLlmFlow block each step until this Runner has persisted the prior step's events.
580+
PersistBarrier.enable(contextWithUpdatedSession);
581+
578582
// Agent execution
579583
Flowable<Event> agentEvents =
580584
contextWithUpdatedSession
@@ -584,6 +588,16 @@ private Flowable<Event> runAgentWithUpdatedSession(
584588
agentEvent ->
585589
this.sessionService
586590
.appendEvent(updatedSession, agentEvent)
591+
// Release (or fail) BaseLlmFlow's wait for this step; the Runner stays the
592+
// sole appendEvent caller (see PersistBarrier).
593+
.doOnSuccess(
594+
unusedEvent ->
595+
PersistBarrier.markPersisted(
596+
contextWithUpdatedSession, agentEvent.id()))
597+
.doOnError(
598+
error ->
599+
PersistBarrier.markFailed(
600+
contextWithUpdatedSession, agentEvent.id(), error))
587601
.flatMap(
588602
registeredEvent -> {
589603
// TODO: remove this hack after deprecating runAsync with Session.
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.flows.llmflows;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.mockito.Mockito.mock;
21+
22+
import com.google.adk.agents.BaseAgent;
23+
import com.google.adk.agents.InvocationContext;
24+
import com.google.adk.events.Event;
25+
import com.google.adk.sessions.BaseSessionService;
26+
import com.google.adk.sessions.Session;
27+
import com.google.common.collect.ImmutableList;
28+
import io.reactivex.rxjava3.observers.TestObserver;
29+
import java.util.ArrayList;
30+
import java.util.Collections;
31+
import java.util.List;
32+
import java.util.concurrent.CountDownLatch;
33+
import org.junit.Before;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
import org.junit.runners.JUnit4;
37+
38+
@RunWith(JUnit4.class)
39+
public final class PersistBarrierTest {
40+
41+
private InvocationContext context;
42+
43+
@Before
44+
public void setUp() {
45+
context =
46+
InvocationContext.builder()
47+
.sessionService(mock(BaseSessionService.class))
48+
.invocationId("inv-1")
49+
.agent(mock(BaseAgent.class))
50+
.session(Session.builder("s").build())
51+
.build();
52+
}
53+
54+
private static Event event(String id) {
55+
return Event.builder().id(id).author("agent").build();
56+
}
57+
58+
@Test
59+
public void awaitBeforeMark_completesOnMark_andDrainsPending() {
60+
PersistBarrier.enable(context);
61+
62+
TestObserver<Void> observer =
63+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
64+
65+
observer.assertNotComplete();
66+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(1);
67+
68+
PersistBarrier.markPersisted(context, "e1");
69+
70+
observer.assertComplete();
71+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
72+
}
73+
74+
@Test
75+
public void markBeforeAwait_completesImmediately_noPending() {
76+
PersistBarrier.enable(context);
77+
78+
PersistBarrier.markPersisted(context, "e1");
79+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
80+
81+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
82+
}
83+
84+
@Test
85+
public void sameEventAwaitedTwice_secondAwaitStillCompletes_andNothingLingers() {
86+
// Mirrors an agent transfer: a sub-agent event is awaited by both the sub-agent and parent
87+
// flows but persisted once; the second await must still complete.
88+
PersistBarrier.enable(context);
89+
90+
TestObserver<Void> subLevel =
91+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
92+
PersistBarrier.markPersisted(context, "e1");
93+
subLevel.assertComplete();
94+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
95+
96+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
97+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
98+
}
99+
100+
@Test
101+
public void multiEventStep_completesOnlyAfterAllMarked() {
102+
PersistBarrier.enable(context);
103+
104+
TestObserver<Void> observer =
105+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"), event("e2"))).test();
106+
107+
PersistBarrier.markPersisted(context, "e1");
108+
observer.assertNotComplete();
109+
110+
PersistBarrier.markPersisted(context, "e2");
111+
observer.assertComplete();
112+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
113+
}
114+
115+
@Test
116+
public void markFailedBeforeAwait_awaitFails() {
117+
PersistBarrier.enable(context);
118+
RuntimeException error = new RuntimeException("append failed");
119+
120+
PersistBarrier.markFailed(context, "e1", error);
121+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertError(error);
122+
123+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
124+
}
125+
126+
@Test
127+
public void awaitBeforeMarkFailed_awaitFails() {
128+
PersistBarrier.enable(context);
129+
RuntimeException error = new RuntimeException("append failed");
130+
131+
TestObserver<Void> observer =
132+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
133+
observer.assertNotComplete();
134+
135+
PersistBarrier.markFailed(context, "e1", error);
136+
137+
observer.assertError(error);
138+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
139+
}
140+
141+
@Test
142+
public void stepWithOneFailedEvent_awaitFails() {
143+
// A step's await fails if any of its events fails to persist, so the next step does not run.
144+
PersistBarrier.enable(context);
145+
RuntimeException error = new RuntimeException("append failed");
146+
147+
TestObserver<Void> observer =
148+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"), event("e2"))).test();
149+
150+
PersistBarrier.markPersisted(context, "e1");
151+
PersistBarrier.markFailed(context, "e2", error);
152+
153+
observer.assertError(error);
154+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
155+
}
156+
157+
@Test
158+
public void notEnabled_awaitIsNoOp() {
159+
// No enable(): flow runs without a Runner, so await must not block forever.
160+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
161+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
162+
}
163+
164+
@Test
165+
public void concurrentAwaitAndMark_allComplete_andDrain() throws Exception {
166+
// awaitPersisted (flow thread) and markPersisted (async appendEvent thread) race on each id;
167+
// none may be stranded and every subject must be dropped.
168+
PersistBarrier.enable(context);
169+
int eventCount = 1000;
170+
List<String> ids = new ArrayList<>();
171+
for (int i = 0; i < eventCount; i++) {
172+
ids.add("e" + i);
173+
}
174+
List<TestObserver<Void>> observers = Collections.synchronizedList(new ArrayList<>());
175+
CountDownLatch start = new CountDownLatch(1);
176+
177+
Thread awaiter =
178+
new Thread(
179+
() -> {
180+
awaitQuietly(start);
181+
for (String id : ids) {
182+
observers.add(
183+
PersistBarrier.awaitPersisted(context, ImmutableList.of(event(id))).test());
184+
}
185+
});
186+
Thread marker =
187+
new Thread(
188+
() -> {
189+
awaitQuietly(start);
190+
for (String id : ids) {
191+
PersistBarrier.markPersisted(context, id);
192+
}
193+
});
194+
195+
awaiter.start();
196+
marker.start();
197+
start.countDown();
198+
awaiter.join();
199+
marker.join();
200+
201+
for (TestObserver<Void> observer : observers) {
202+
observer.assertComplete();
203+
}
204+
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
205+
}
206+
207+
private static void awaitQuietly(CountDownLatch latch) {
208+
try {
209+
latch.await();
210+
} catch (InterruptedException e) {
211+
Thread.currentThread().interrupt();
212+
throw new AssertionError(e);
213+
}
214+
}
215+
}

0 commit comments

Comments
 (0)