Skip to content

Commit 03e77a7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(HITL): Let ADK resume after HITL approval is present
feat(HITL): Declining a proposal now correctly intercepts the run fix: Events for HITL are now emitted correctly fix: HITL endless loop when asking for approvals PiperOrigin-RevId: 839858592
1 parent 852ddf6 commit 03e77a7

8 files changed

Lines changed: 297 additions & 106 deletions

File tree

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,8 @@ public IncludeContents includeContents() {
736736
return includeContents;
737737
}
738738

739-
public List<BaseTool> tools() {
740-
return canonicalTools().toList().blockingGet();
739+
public Single<List<BaseTool>> tools() {
740+
return canonicalTools().toList();
741741
}
742742

743743
public List<Object> toolsUnion() {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
public final class Functions {
6565

6666
private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
67-
static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
67+
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
6868
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
6969

7070
/** Generates a unique ID for a function call. */
@@ -147,12 +147,22 @@ public static Maybe<Event> handleFunctionCalls(
147147
Function<FunctionCall, Maybe<Event>> functionCallMapper =
148148
functionCall -> {
149149
BaseTool tool = tools.get(functionCall.name().get());
150+
ToolConfirmation toolConfirmation = toolConfirmations.get(functionCall.id().orElse(null));
150151
ToolContext toolContext =
151152
ToolContext.builder(invocationContext)
152153
.functionCallId(functionCall.id().orElse(""))
153-
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
154+
.toolConfirmation(toolConfirmation)
154155
.build();
155156

157+
if (toolConfirmation != null && !toolConfirmation.confirmed()) {
158+
return Maybe.just(
159+
buildResponseEvent(
160+
tool,
161+
ImmutableMap.of("error", "User declined tool execution for " + tool.name()),
162+
toolContext,
163+
invocationContext));
164+
}
165+
156166
Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());
157167

158168
Maybe<Map<String, Object>> maybeFunctionResult =
@@ -241,6 +251,18 @@ public static Maybe<Event> handleFunctionCalls(
241251
*/
242252
public static Maybe<Event> handleFunctionCallsLive(
243253
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
254+
return handleFunctionCallsLive(invocationContext, functionCallEvent, tools, ImmutableMap.of());
255+
}
256+
257+
/**
258+
* Handles function calls in a live/streaming context with tool confirmations, supporting
259+
* background execution and stream termination.
260+
*/
261+
public static Maybe<Event> handleFunctionCallsLive(
262+
InvocationContext invocationContext,
263+
Event functionCallEvent,
264+
Map<String, BaseTool> tools,
265+
Map<String, ToolConfirmation> toolConfirmations) {
244266
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
245267

246268
for (FunctionCall functionCall : functionCalls) {
@@ -255,7 +277,9 @@ public static Maybe<Event> handleFunctionCallsLive(
255277
ToolContext toolContext =
256278
ToolContext.builder(invocationContext)
257279
.functionCallId(functionCall.id().orElse(""))
280+
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
258281
.build();
282+
259283
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
260284

261285
Maybe<Map<String, Object>> maybeFunctionResult =

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 110 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME;
2020
import static com.google.common.collect.ImmutableList.toImmutableList;
2121
import static com.google.common.collect.ImmutableMap.toImmutableMap;
22+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
2223

2324
import com.fasterxml.jackson.core.JsonProcessingException;
2425
import com.fasterxml.jackson.databind.ObjectMapper;
25-
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
26+
import com.google.adk.JsonBaseModel;
2627
import com.google.adk.agents.InvocationContext;
2728
import com.google.adk.agents.LlmAgent;
2829
import com.google.adk.events.Event;
@@ -31,86 +32,157 @@
3132
import com.google.adk.tools.ToolConfirmation;
3233
import com.google.common.collect.ImmutableList;
3334
import com.google.common.collect.ImmutableMap;
35+
import com.google.common.collect.ImmutableSet;
3436
import com.google.genai.types.Content;
3537
import com.google.genai.types.FunctionCall;
3638
import com.google.genai.types.FunctionResponse;
3739
import com.google.genai.types.Part;
3840
import io.reactivex.rxjava3.core.Maybe;
3941
import io.reactivex.rxjava3.core.Single;
4042
import java.util.Collection;
41-
import java.util.List;
43+
import java.util.HashMap;
4244
import java.util.Map;
4345
import java.util.Objects;
4446
import java.util.Optional;
4547
import org.slf4j.Logger;
4648
import org.slf4j.LoggerFactory;
4749

4850
/** Handles tool confirmation information to build the LLM request. */
51+
// TODO: b/469096654 - refactor loop counters into functional style.
4952
public class RequestConfirmationLlmRequestProcessor implements RequestProcessor {
5053
private static final Logger logger =
5154
LoggerFactory.getLogger(RequestConfirmationLlmRequestProcessor.class);
52-
private final ObjectMapper objectMapper;
53-
54-
public RequestConfirmationLlmRequestProcessor() {
55-
objectMapper = new ObjectMapper().registerModule(new Jdk8Module());
56-
}
55+
private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel.getMapper();
5756

5857
@Override
5958
public Single<RequestProcessor.RequestProcessingResult> processRequest(
6059
InvocationContext invocationContext, LlmRequest llmRequest) {
61-
List<Event> events = invocationContext.session().events();
60+
ImmutableList<Event> events = ImmutableList.copyOf(invocationContext.session().events());
6261
if (events.isEmpty()) {
6362
logger.info(
6463
"No events are present in the session. Skipping request confirmation processing.");
6564
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
6665
}
6766

68-
ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses =
69-
filterRequestConfirmationFunctionResponses(events);
67+
ImmutableMap<String, ToolConfirmation> responses = ImmutableMap.of();
68+
int confirmationEventIndex = -1;
69+
for (int i = events.size() - 1; i >= 0; i--) {
70+
Event event = events.get(i);
71+
if (!Objects.equals(event.author(), "user")) {
72+
continue;
73+
}
74+
if (event.functionResponses().isEmpty()) {
75+
continue;
76+
}
77+
responses =
78+
event.functionResponses().stream()
79+
.filter(functionResponse -> functionResponse.id().isPresent())
80+
.filter(
81+
functionResponse ->
82+
Objects.equals(
83+
functionResponse.name().orElse(null),
84+
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
85+
.map(this::maybeCreateToolConfirmationEntry)
86+
.flatMap(Optional::stream)
87+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
88+
confirmationEventIndex = i;
89+
break;
90+
}
91+
92+
// Make it final to enable access from lambda expressions.
93+
final ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses = responses;
94+
7095
if (requestConfirmationFunctionResponses.isEmpty()) {
7196
logger.info("No request confirmation function responses found.");
7297
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
7398
}
7499

75-
for (ImmutableList<FunctionCall> functionCalls :
76-
events.stream()
77-
.map(Event::functionCalls)
78-
.filter(fc -> !fc.isEmpty())
79-
.collect(toImmutableList())) {
100+
for (int i = events.size() - 2; i >= 0; i--) {
101+
Event event = events.get(i);
102+
if (event.functionCalls().isEmpty()) {
103+
continue;
104+
}
105+
106+
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
107+
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
108+
109+
event.functionCalls().stream()
110+
.filter(
111+
fc ->
112+
fc.id().isPresent()
113+
&& requestConfirmationFunctionResponses.containsKey(fc.id().get()))
114+
.forEach(
115+
fc ->
116+
getOriginalFunctionCall(fc)
117+
.ifPresent(
118+
ofc -> {
119+
toolsToResumeWithConfirmation.put(
120+
ofc.id().get(),
121+
requestConfirmationFunctionResponses.get(fc.id().get()));
122+
toolsToResumeWithArgs.put(ofc.id().get(), ofc);
123+
}));
124+
125+
if (toolsToResumeWithConfirmation.isEmpty()) {
126+
continue;
127+
}
128+
129+
// Remove the tools that have already been confirmed.
130+
ImmutableSet<String> alreadyConfirmedIds =
131+
events.subList(confirmationEventIndex + 1, events.size()).stream()
132+
.flatMap(e -> e.functionResponses().stream())
133+
.map(FunctionResponse::id)
134+
.flatMap(Optional::stream)
135+
.collect(toImmutableSet());
136+
toolsToResumeWithConfirmation.keySet().removeAll(alreadyConfirmedIds);
137+
toolsToResumeWithArgs.keySet().removeAll(alreadyConfirmedIds);
80138

81-
ImmutableMap<String, FunctionCall> toolsToResumeWithArgs =
82-
filterToolsToResumeWithArgs(functionCalls, requestConfirmationFunctionResponses);
83-
ImmutableMap<String, ToolConfirmation> toolsToResumeWithConfirmation =
84-
toolsToResumeWithArgs.keySet().stream()
85-
.filter(
86-
id ->
87-
events.stream()
88-
.flatMap(e -> e.functionResponses().stream())
89-
.anyMatch(fr -> Objects.equals(fr.id().orElse(null), id)))
90-
.collect(toImmutableMap(k -> k, requestConfirmationFunctionResponses::get));
91139
if (toolsToResumeWithConfirmation.isEmpty()) {
92-
logger.info("No tools to resume with confirmation.");
93140
continue;
94141
}
95142

96143
return assembleEvent(
97-
invocationContext, toolsToResumeWithArgs.values(), toolsToResumeWithConfirmation)
98-
.map(event -> RequestProcessingResult.create(llmRequest, ImmutableList.of(event)))
144+
invocationContext,
145+
toolsToResumeWithArgs.values(),
146+
ImmutableMap.copyOf(toolsToResumeWithConfirmation))
147+
.map(e -> RequestProcessingResult.create(llmRequest, ImmutableList.of(e)))
99148
.toSingle();
100149
}
101150

102151
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
103152
}
104153

154+
private Optional<FunctionCall> getOriginalFunctionCall(FunctionCall functionCall) {
155+
if (!functionCall.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall")) {
156+
return Optional.empty();
157+
}
158+
try {
159+
FunctionCall originalFunctionCall =
160+
OBJECT_MAPPER.convertValue(
161+
functionCall.args().get().get("originalFunctionCall"), FunctionCall.class);
162+
if (originalFunctionCall.id().isEmpty()) {
163+
return Optional.empty();
164+
}
165+
return Optional.of(originalFunctionCall);
166+
} catch (IllegalArgumentException e) {
167+
logger.warn("Failed to convert originalFunctionCall argument.", e);
168+
return Optional.empty();
169+
}
170+
}
171+
105172
private Maybe<Event> assembleEvent(
106173
InvocationContext invocationContext,
107174
Collection<FunctionCall> functionCalls,
108175
Map<String, ToolConfirmation> toolConfirmations) {
109-
ImmutableMap.Builder<String, BaseTool> toolsBuilder = ImmutableMap.builder();
176+
Single<ImmutableMap<String, BaseTool>> toolsMapSingle;
110177
if (invocationContext.agent() instanceof LlmAgent llmAgent) {
111-
for (BaseTool tool : llmAgent.tools()) {
112-
toolsBuilder.put(tool.name(), tool);
113-
}
178+
toolsMapSingle =
179+
llmAgent
180+
.tools()
181+
.map(
182+
toolList ->
183+
toolList.stream().collect(toImmutableMap(BaseTool::name, tool -> tool)));
184+
} else {
185+
toolsMapSingle = Single.just(ImmutableMap.of());
114186
}
115187

116188
var functionCallEvent =
@@ -124,23 +196,10 @@ private Maybe<Event> assembleEvent(
124196
.build())
125197
.build();
126198

127-
return Functions.handleFunctionCalls(
128-
invocationContext, functionCallEvent, toolsBuilder.buildOrThrow(), toolConfirmations);
129-
}
130-
131-
private ImmutableMap<String, ToolConfirmation> filterRequestConfirmationFunctionResponses(
132-
List<Event> events) {
133-
return events.stream()
134-
.filter(event -> Objects.equals(event.author(), "user"))
135-
.flatMap(event -> event.functionResponses().stream())
136-
.filter(functionResponse -> functionResponse.id().isPresent())
137-
.filter(
138-
functionResponse ->
139-
Objects.equals(
140-
functionResponse.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
141-
.map(this::maybeCreateToolConfirmationEntry)
142-
.flatMap(Optional::stream)
143-
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
199+
return toolsMapSingle.flatMapMaybe(
200+
toolsMap ->
201+
Functions.handleFunctionCalls(
202+
invocationContext, functionCallEvent, toolsMap, toolConfirmations));
144203
}
145204

146205
private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmationEntry(
@@ -150,36 +209,19 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio
150209
return Optional.of(
151210
Map.entry(
152211
functionResponse.id().get(),
153-
objectMapper.convertValue(responseMap, ToolConfirmation.class)));
212+
OBJECT_MAPPER.convertValue(responseMap, ToolConfirmation.class)));
154213
}
155214

156215
try {
157216
return Optional.of(
158217
Map.entry(
159218
functionResponse.id().get(),
160-
objectMapper.readValue(
219+
OBJECT_MAPPER.readValue(
161220
(String) responseMap.get("response"), ToolConfirmation.class)));
162221
} catch (JsonProcessingException e) {
163222
logger.error("Failed to parse tool confirmation response", e);
164223
}
165224

166225
return Optional.empty();
167226
}
168-
169-
private ImmutableMap<String, FunctionCall> filterToolsToResumeWithArgs(
170-
ImmutableList<FunctionCall> functionCalls,
171-
Map<String, ToolConfirmation> requestConfirmationFunctionResponses) {
172-
return functionCalls.stream()
173-
.filter(fc -> fc.id().isPresent())
174-
.filter(fc -> requestConfirmationFunctionResponses.containsKey(fc.id().get()))
175-
.filter(
176-
fc -> Objects.equals(fc.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
177-
.filter(fc -> fc.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall"))
178-
.collect(
179-
toImmutableMap(
180-
fc -> fc.id().get(),
181-
fc ->
182-
objectMapper.convertValue(
183-
fc.args().get().get("originalFunctionCall"), FunctionCall.class)));
184-
}
185227
}

core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class SingleFlow extends BaseLlmFlow {
3131
new Identity(),
3232
new Contents(),
3333
new Examples(),
34+
new RequestConfirmationLlmRequestProcessor(),
3435
CodeExecution.requestProcessor);
3536

3637
protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =

0 commit comments

Comments
 (0)