Skip to content

Commit 2e52d02

Browse files
damianmomotgooglecopybara-github
authored andcommitted
fix: route HITL confirmation back to originating sub-agent in workflow agents
When an LlmAgent that uses a tool requiring Human-in-the-Loop confirmation is wrapped inside a non-LlmAgent workflow agent (e.g. SequentialAgent, ParallelAgent, LoopAgent), the runner used to fall back to the root agent on confirmation resumption. This caused 'VerifyException: Tool not found' because the root agent does not have the sub-agent's tools registered. Runner.findAgentToRun now first checks whether the last event is a function response and, if so, routes it back to the agent that emitted the matching function call (looked up by id), regardless of whether that agent's parent chain is fully transferable. This mirrors the Python ADK behaviour in Runner._find_agent_to_run via find_matching_function_call. PiperOrigin-RevId: 921354317
1 parent e6fe9aa commit 2e52d02

2 files changed

Lines changed: 148 additions & 5 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package com.google.adk.runner;
1818

19+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
20+
1921
import com.google.adk.agents.ActiveStreamingTool;
2022
import com.google.adk.agents.BaseAgent;
2123
import com.google.adk.agents.ContextCacheConfig;
@@ -45,6 +47,8 @@
4547
import com.google.adk.utils.CollectionUtils;
4648
import com.google.common.base.Preconditions;
4749
import com.google.common.collect.ImmutableList;
50+
import com.google.common.collect.ImmutableSet;
51+
import com.google.common.collect.Iterables;
4852
import com.google.common.collect.MapMaker;
4953
import com.google.errorprone.annotations.CanIgnoreReturnValue;
5054
import com.google.genai.types.AudioTranscriptionConfig;
@@ -64,6 +68,7 @@
6468
import java.util.Collections;
6569
import java.util.List;
6670
import java.util.Map;
71+
import java.util.Objects;
6772
import java.util.Optional;
6873
import java.util.concurrent.ConcurrentHashMap;
6974
import java.util.concurrent.ConcurrentMap;
@@ -772,12 +777,15 @@ private boolean isTransferableAcrossAgentTree(BaseAgent agentToRun) {
772777
return true;
773778
}
774779

775-
/**
776-
* Returns the agent that should handle the next request based on session history.
777-
*
778-
* @return agent to run.
779-
*/
780+
/** Returns the agent that should handle the next request based on session history. */
780781
private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
782+
// Route function responses back to the originating function-call author so HITL tool
783+
// confirmations resume the sub-agent even through non-LlmAgent ancestors.
784+
Optional<BaseAgent> functionCallAuthor = findFunctionCallAuthor(session, rootAgent);
785+
if (functionCallAuthor.isPresent()) {
786+
return functionCallAuthor.get();
787+
}
788+
781789
List<Event> events = new ArrayList<>(session.events());
782790
Collections.reverse(events);
783791

@@ -808,6 +816,39 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
808816
return rootAgent;
809817
}
810818

819+
/**
820+
* If the last event is a function response, returns the agent that emitted the matching function
821+
* call (by id), or empty if no match is found in the agent tree.
822+
*/
823+
private static Optional<BaseAgent> findFunctionCallAuthor(Session session, BaseAgent rootAgent) {
824+
List<Event> events = session.events();
825+
if (events.isEmpty()) {
826+
return Optional.empty();
827+
}
828+
ImmutableSet<String> functionResponseIds =
829+
Iterables.getLast(events).functionResponses().stream()
830+
.map(fr -> fr.id().orElse(null))
831+
.filter(Objects::nonNull)
832+
.collect(toImmutableSet());
833+
834+
// Iterate in reverse to prefer the most recent matching call, mirroring Python ADK's
835+
// find_event_by_function_call_id. Function call IDs are unique in normal flows, so this
836+
// is defense-in-depth and not covered by mutation testing.
837+
List<Event> precedingEvents = new ArrayList<>(events.subList(0, events.size() - 1));
838+
Collections.reverse(precedingEvents);
839+
for (Event event : precedingEvents) {
840+
boolean matches =
841+
event.functionCalls().stream()
842+
.map(fc -> fc.id().orElse(null))
843+
.filter(Objects::nonNull)
844+
.anyMatch(functionResponseIds::contains);
845+
if (matches && event.author() != null) {
846+
return rootAgent.findAgent(event.author());
847+
}
848+
}
849+
return Optional.empty();
850+
}
851+
811852
private void addActiveStreamingTools(InvocationContext invocationContext, List<BaseTool> tools) {
812853
tools.stream()
813854
.filter(FunctionTool.class::isInstance)

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import com.google.adk.agents.LiveRequestQueue;
4343
import com.google.adk.agents.LlmAgent;
4444
import com.google.adk.agents.RunConfig;
45+
import com.google.adk.agents.SequentialAgent;
4546
import com.google.adk.apps.App;
4647
import com.google.adk.artifacts.BaseArtifactService;
4748
import com.google.adk.events.Event;
@@ -1604,6 +1605,107 @@ public void runAsync_withToolConfirmation() {
16041605
.inOrder();
16051606
}
16061607

1608+
// HITL tool confirmation must resume the originating sub-agent even when wrapped inside a
1609+
// non-LlmAgent workflow agent (e.g. SequentialAgent).
1610+
@Test
1611+
public void runAsync_withToolConfirmation_inSequentialAgentSubAgent_resumesSubAgent() {
1612+
TestLlm childTestLlm =
1613+
createTestLlm(
1614+
createFunctionCallLlmResponse(
1615+
"tool_call_id", "echoTool", ImmutableMap.of("message", "hello")),
1616+
createTextLlmResponse("Response after observing tool needs confirmation."),
1617+
createTextLlmResponse("Response after user confirmed."));
1618+
LlmAgent childAgent =
1619+
createTestAgentBuilder(childTestLlm)
1620+
.name("child_agent")
1621+
.tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true))
1622+
.build();
1623+
SequentialAgent workflowAgent =
1624+
SequentialAgent.builder()
1625+
.name("workflow_agent")
1626+
.subAgents(ImmutableList.of(childAgent))
1627+
.build();
1628+
// Root transfers to workflow_agent to mirror the bug report's control flow.
1629+
TestLlm rootTestLlm =
1630+
createTestLlm(
1631+
createLlmResponse(
1632+
Content.fromParts(
1633+
Part.fromFunctionCall(
1634+
"transfer_to_agent", ImmutableMap.of("agent_name", "workflow_agent")))));
1635+
LlmAgent rootAgent =
1636+
createTestAgentBuilder(rootTestLlm)
1637+
.name("root_agent")
1638+
.subAgents(ImmutableList.of(workflowAgent))
1639+
.build();
1640+
Runner runner =
1641+
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
1642+
Session session = runner.sessionService().createSession("test", "user").blockingGet();
1643+
1644+
List<Event> eventsBeforeConfirmation =
1645+
runner
1646+
.runAsync("user", session.id(), Content.fromParts(Part.fromText("from user")))
1647+
.toList()
1648+
.blockingGet();
1649+
FunctionCall askUserConfirmationFunctionCall =
1650+
Iterables.getOnlyElement(
1651+
eventsBeforeConfirmation.stream()
1652+
.map(Functions::getAskUserConfirmationFunctionCalls)
1653+
.filter(functionCalls -> !functionCalls.isEmpty())
1654+
.findFirst()
1655+
.get());
1656+
List<Event> eventsAfterConfirmation =
1657+
runner
1658+
.runAsync(
1659+
"user",
1660+
session.id(),
1661+
Content.fromParts(
1662+
Part.builder()
1663+
.functionResponse(
1664+
FunctionResponse.builder()
1665+
.id(askUserConfirmationFunctionCall.id().get())
1666+
.name(askUserConfirmationFunctionCall.name().get())
1667+
.response(ImmutableMap.of("confirmed", true)))
1668+
.build()))
1669+
.toList()
1670+
.blockingGet();
1671+
1672+
// The originating child agent (not the root agent) must execute the tool.
1673+
assertThat(simplifyEvents(eventsAfterConfirmation))
1674+
.containsExactly(
1675+
"child_agent: FunctionResponse(name=echoTool, response={message=hello})",
1676+
"child_agent: Response after user confirmed.")
1677+
.inOrder();
1678+
}
1679+
1680+
// Orphan function responses (id not matching any prior call) should fall back to the root agent.
1681+
@Test
1682+
public void runAsync_withFunctionResponseNotMatchingAnyCall_fallsBackToRootAgent() {
1683+
TestLlm rootLlm = createTestLlm(createTextLlmResponse("after function response"));
1684+
LlmAgent rootAgent = createTestAgentBuilder(rootLlm).name("root_agent").build();
1685+
Runner runner =
1686+
Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build();
1687+
Session session = runner.sessionService().createSession("test", "user").blockingGet();
1688+
1689+
// Function response with id that does not match any prior function call.
1690+
List<Event> events =
1691+
runner
1692+
.runAsync(
1693+
"user",
1694+
session.id(),
1695+
Content.fromParts(
1696+
Part.builder()
1697+
.functionResponse(
1698+
FunctionResponse.builder()
1699+
.id("non_existent_id")
1700+
.name("orphanFn")
1701+
.response(ImmutableMap.of("x", 1)))
1702+
.build()))
1703+
.toList()
1704+
.blockingGet();
1705+
1706+
assertThat(simplifyEvents(events)).containsExactly("root_agent: after function response");
1707+
}
1708+
16071709
@Test
16081710
public void close_closesPluginsAndCodeExecutors() {
16091711
BasePlugin plugin = mockPlugin("close_test_plugin");

0 commit comments

Comments
 (0)