Skip to content

Commit 654f03d

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
Ensure that events are appended to the session and processed sequentially before proceeding to the next step in BaseLlmFlow. PiperOrigin-RevId: 896524076
1 parent 78766c1 commit 654f03d

3 files changed

Lines changed: 160 additions & 19 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,14 +461,31 @@ public Flowable<Event> run(InvocationContext invocationContext) {
461461

462462
private Flowable<Event> run(
463463
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
464-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
464+
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext);
465+
466+
Flowable<Event> processedEvents =
467+
currentStepEvents
468+
.concatMap(
469+
event ->
470+
invocationContext
471+
.sessionService()
472+
.appendEvent(invocationContext.session(), event)
473+
.flatMap(
474+
registeredEvent ->
475+
invocationContext
476+
.pluginManager()
477+
.onEventCallback(invocationContext, registeredEvent)
478+
.defaultIfEmpty(registeredEvent))
479+
.toFlowable())
480+
.cache();
481+
465482
if (stepsCompleted + 1 >= maxSteps) {
466483
logger.debug("Ending flow execution because max steps reached.");
467-
return currentStepEvents;
484+
return processedEvents;
468485
}
469486

470-
return currentStepEvents.concatWith(
471-
currentStepEvents
487+
return processedEvents.concatWith(
488+
processedEvents
472489
.toList()
473490
.flatMapPublisher(
474491
eventList -> {

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -570,19 +570,29 @@ private Flowable<Event> runAgentWithUpdatedSession(
570570
.agent()
571571
.runAsync(contextWithUpdatedSession)
572572
.concatMap(
573-
agentEvent ->
574-
this.sessionService
575-
.appendEvent(updatedSession, agentEvent)
576-
.flatMap(
577-
registeredEvent -> {
578-
// TODO: remove this hack after deprecating runAsync with Session.
579-
copySessionStates(updatedSession, initialContext.session());
580-
return contextWithUpdatedSession
581-
.pluginManager()
582-
.onEventCallback(contextWithUpdatedSession, registeredEvent)
583-
.defaultIfEmpty(registeredEvent);
584-
})
585-
.toFlowable());
573+
agentEvent -> {
574+
// TODO: remove this hack after deprecating runAsync with Session.
575+
copySessionStates(updatedSession, initialContext.session());
576+
577+
// TODO: b/502182243 - Investigate if appendEvent should be made idempotent in
578+
// SessionService to avoid this check.
579+
if (agentEvent.id() != null
580+
&& updatedSession.events().stream()
581+
.anyMatch(e -> agentEvent.id().equals(e.id()))) {
582+
// Already appended (e.g. by BaseLlmFlow). Still apply the hack.
583+
return Flowable.just(agentEvent);
584+
}
585+
return this.sessionService
586+
.appendEvent(updatedSession, agentEvent)
587+
.flatMap(
588+
registeredEvent -> {
589+
return contextWithUpdatedSession
590+
.pluginManager()
591+
.onEventCallback(contextWithUpdatedSession, registeredEvent)
592+
.defaultIfEmpty(registeredEvent);
593+
})
594+
.toFlowable();
595+
});
586596

587597
// If beforeRunCallback returns content, emit it and skip agent
588598
Context capturedContext = Context.current();

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@
4646
import com.google.adk.artifacts.BaseArtifactService;
4747
import com.google.adk.events.Event;
4848
import com.google.adk.flows.llmflows.Functions;
49+
import com.google.adk.models.LlmRequest;
4950
import com.google.adk.models.LlmResponse;
5051
import com.google.adk.plugins.BasePlugin;
5152
import com.google.adk.sessions.BaseSessionService;
53+
import com.google.adk.sessions.GetSessionConfig;
54+
import com.google.adk.sessions.InMemorySessionService;
5255
import com.google.adk.sessions.Session;
5356
import com.google.adk.sessions.SessionKey;
5457
import com.google.adk.summarizer.EventsCompactionConfig;
@@ -588,12 +591,22 @@ public void onToolErrorCallback_error() {
588591
@Test
589592
public void onEventCallback_success() {
590593
when(plugin.onEventCallback(any(), any()))
591-
.thenReturn(Maybe.just(TestUtils.createEvent("form plugin")));
594+
.thenAnswer(
595+
invocation -> {
596+
Event event = invocation.getArgument(1);
597+
return Maybe.just(
598+
Event.builder()
599+
.id(event.id())
600+
.invocationId(event.invocationId())
601+
.author("model")
602+
.content(createContent("from plugin"))
603+
.build());
604+
});
592605

593606
List<Event> events =
594607
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
595608

596-
assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin");
609+
assertThat(simplifyEvents(events)).containsExactly("model: from plugin");
597610

598611
verify(plugin).onEventCallback(any(), any());
599612
}
@@ -1686,4 +1699,105 @@ public void runner_executesSaveArtifactFlow() {
16861699
// agent was run
16871700
assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm");
16881701
}
1702+
1703+
@Test
1704+
public void runAsync_ensuresSequentialConsistencyForTools() {
1705+
// Arrange
1706+
TestLlm testLlm =
1707+
createTestLlm(
1708+
createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")),
1709+
createTextLlmResponse("Final response"));
1710+
1711+
LlmAgent agent =
1712+
createTestAgentBuilder(testLlm)
1713+
.tools(
1714+
ImmutableList.of(
1715+
FunctionTool.create(RaceConditionTools.class, "tool1"),
1716+
FunctionTool.create(RaceConditionTools.class, "tool2")))
1717+
.build();
1718+
1719+
BaseSessionService delegate = new InMemorySessionService();
1720+
BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 100);
1721+
1722+
Runner runner =
1723+
Runner.builder()
1724+
.app(App.builder().name("test").rootAgent(agent).build())
1725+
.sessionService(delayedSessionService)
1726+
.build();
1727+
Session session = runner.sessionService().createSession("test", "user").blockingGet();
1728+
1729+
// Act
1730+
var unused =
1731+
runner
1732+
.runAsync("user", session.id(), Content.fromParts(Part.fromText("start")))
1733+
.toList()
1734+
.blockingGet();
1735+
1736+
// Assert
1737+
ImmutableList<LlmRequest> requests = ImmutableList.copyOf(testLlm.getRequests());
1738+
assertThat(requests).hasSize(2);
1739+
1740+
// Second request should contain the result of tool1
1741+
LlmRequest secondRequest = requests.get(1);
1742+
List<Content> history = secondRequest.contents();
1743+
1744+
boolean foundToolResponse =
1745+
history.stream()
1746+
.flatMap(content -> content.parts().stream().flatMap(List::stream))
1747+
.filter(part -> part.functionResponse().isPresent())
1748+
.map(part -> part.functionResponse().get())
1749+
.anyMatch(
1750+
response ->
1751+
response.name().orElse("").equals("tool1")
1752+
&& response
1753+
.response()
1754+
.orElse(null)
1755+
.equals(ImmutableMap.of("result", "result_value1")));
1756+
1757+
assertThat(foundToolResponse).isTrue();
1758+
}
1759+
1760+
private static BaseSessionService createDelayedSessionService(
1761+
BaseSessionService delegate, long delayMs) {
1762+
BaseSessionService delayedSessionService = mock(BaseSessionService.class);
1763+
when(delayedSessionService.createSession(anyString(), anyString(), any(), anyString()))
1764+
.thenAnswer(
1765+
inv ->
1766+
delegate.createSession(
1767+
inv.getArgument(0),
1768+
inv.getArgument(1),
1769+
inv.getArgument(2),
1770+
inv.getArgument(3)));
1771+
when(delayedSessionService.createSession(anyString(), anyString()))
1772+
.thenAnswer(
1773+
inv ->
1774+
delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1)));
1775+
when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any()))
1776+
.thenAnswer(
1777+
inv ->
1778+
delegate.getSession(
1779+
(String) inv.getArgument(0),
1780+
(String) inv.getArgument(1),
1781+
(String) inv.getArgument(2),
1782+
(Optional<GetSessionConfig>) inv.getArgument(3)));
1783+
when(delayedSessionService.appendEvent(any(), any()))
1784+
.thenAnswer(
1785+
inv ->
1786+
delegate
1787+
.appendEvent(inv.getArgument(0), inv.getArgument(1))
1788+
.delay(delayMs, MILLISECONDS));
1789+
return delayedSessionService;
1790+
}
1791+
1792+
public static class RaceConditionTools {
1793+
private RaceConditionTools() {}
1794+
1795+
public static ImmutableMap<String, Object> tool1(String arg) {
1796+
return ImmutableMap.of("result", "result_" + arg);
1797+
}
1798+
1799+
public static ImmutableMap<String, Object> tool2(String input) {
1800+
return ImmutableMap.of("status", "received_" + input);
1801+
}
1802+
}
16891803
}

0 commit comments

Comments
 (0)