Skip to content

Commit 8808643

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(HITL): Let ADK resume after HITL approval is present
feat(HITL): Declining a proposal now correctly intercepts the run fix: Events for HITL are now emitted correctly fix: HITL endless loop when asking for approvals PiperOrigin-RevId: 839858592
1 parent 007e938 commit 8808643

5 files changed

Lines changed: 228 additions & 66 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
public final class Functions {
6565

6666
private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
67-
static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
67+
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
6868
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
6969

7070
/** Generates a unique ID for a function call. */
@@ -147,12 +147,22 @@ public static Maybe<Event> handleFunctionCalls(
147147
Function<FunctionCall, Maybe<Event>> functionCallMapper =
148148
functionCall -> {
149149
BaseTool tool = tools.get(functionCall.name().get());
150+
ToolConfirmation toolConfirmation = toolConfirmations.get(functionCall.id().orElse(null));
150151
ToolContext toolContext =
151152
ToolContext.builder(invocationContext)
152153
.functionCallId(functionCall.id().orElse(""))
153-
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
154+
.toolConfirmation(toolConfirmation)
154155
.build();
155156

157+
if (toolConfirmation != null && !toolConfirmation.confirmed()) {
158+
return Maybe.just(
159+
buildResponseEvent(
160+
tool,
161+
ImmutableMap.of("error", "User declined tool execution for " + tool.name()),
162+
toolContext,
163+
invocationContext));
164+
}
165+
156166
Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());
157167

158168
Maybe<Map<String, Object>> maybeFunctionResult =
@@ -241,6 +251,18 @@ public static Maybe<Event> handleFunctionCalls(
241251
*/
242252
public static Maybe<Event> handleFunctionCallsLive(
243253
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
254+
return handleFunctionCallsLive(invocationContext, functionCallEvent, tools, ImmutableMap.of());
255+
}
256+
257+
/**
258+
* Handles function calls in a live/streaming context with tool confirmations, supporting
259+
* background execution and stream termination.
260+
*/
261+
public static Maybe<Event> handleFunctionCallsLive(
262+
InvocationContext invocationContext,
263+
Event functionCallEvent,
264+
Map<String, BaseTool> tools,
265+
Map<String, ToolConfirmation> toolConfirmations) {
244266
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
245267

246268
for (FunctionCall functionCall : functionCalls) {
@@ -255,7 +277,9 @@ public static Maybe<Event> handleFunctionCallsLive(
255277
ToolContext toolContext =
256278
ToolContext.builder(invocationContext)
257279
.functionCallId(functionCall.id().orElse(""))
280+
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
258281
.build();
282+
259283
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
260284

261285
Maybe<Map<String, Object>> maybeFunctionResult =

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 96 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME;
2020
import static com.google.common.collect.ImmutableList.toImmutableList;
2121
import static com.google.common.collect.ImmutableMap.toImmutableMap;
22+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
2223

2324
import com.fasterxml.jackson.core.JsonProcessingException;
2425
import com.fasterxml.jackson.databind.ObjectMapper;
25-
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
26+
import com.google.adk.JsonBaseModel;
2627
import com.google.adk.agents.InvocationContext;
2728
import com.google.adk.agents.LlmAgent;
2829
import com.google.adk.events.Event;
@@ -31,14 +32,15 @@
3132
import com.google.adk.tools.ToolConfirmation;
3233
import com.google.common.collect.ImmutableList;
3334
import com.google.common.collect.ImmutableMap;
35+
import com.google.common.collect.ImmutableSet;
3436
import com.google.genai.types.Content;
3537
import com.google.genai.types.FunctionCall;
3638
import com.google.genai.types.FunctionResponse;
3739
import com.google.genai.types.Part;
3840
import io.reactivex.rxjava3.core.Maybe;
3941
import io.reactivex.rxjava3.core.Single;
4042
import java.util.Collection;
41-
import java.util.List;
43+
import java.util.HashMap;
4244
import java.util.Map;
4345
import java.util.Objects;
4446
import java.util.Optional;
@@ -49,59 +51,123 @@
4951
public class RequestConfirmationLlmRequestProcessor implements RequestProcessor {
5052
private static final Logger logger =
5153
LoggerFactory.getLogger(RequestConfirmationLlmRequestProcessor.class);
52-
private final ObjectMapper objectMapper;
53-
54-
public RequestConfirmationLlmRequestProcessor() {
55-
objectMapper = new ObjectMapper().registerModule(new Jdk8Module());
56-
}
54+
private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel.getMapper();
5755

5856
@Override
5957
public Single<RequestProcessor.RequestProcessingResult> processRequest(
6058
InvocationContext invocationContext, LlmRequest llmRequest) {
61-
List<Event> events = invocationContext.session().events();
59+
ImmutableList<Event> events = ImmutableList.copyOf(invocationContext.session().events());
6260
if (events.isEmpty()) {
6361
logger.info(
6462
"No events are present in the session. Skipping request confirmation processing.");
6563
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
6664
}
6765

68-
ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses =
69-
filterRequestConfirmationFunctionResponses(events);
66+
ImmutableMap<String, ToolConfirmation> responses = ImmutableMap.of();
67+
int confirmationEventIndex = -1;
68+
for (int i = events.size() - 1; i >= 0; i--) {
69+
Event event = events.get(i);
70+
if (!Objects.equals(event.author(), "user")) {
71+
continue;
72+
}
73+
if (event.functionResponses().isEmpty()) {
74+
continue;
75+
}
76+
responses =
77+
event.functionResponses().stream()
78+
.filter(functionResponse -> functionResponse.id().isPresent())
79+
.filter(
80+
functionResponse ->
81+
Objects.equals(
82+
functionResponse.name().orElse(null),
83+
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
84+
.map(this::maybeCreateToolConfirmationEntry)
85+
.flatMap(Optional::stream)
86+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
87+
confirmationEventIndex = i;
88+
break;
89+
}
90+
91+
// Make it final to enable access from lambda expressions.
92+
final ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses = responses;
93+
7094
if (requestConfirmationFunctionResponses.isEmpty()) {
7195
logger.info("No request confirmation function responses found.");
7296
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
7397
}
7498

75-
for (ImmutableList<FunctionCall> functionCalls :
76-
events.stream()
77-
.map(Event::functionCalls)
78-
.filter(fc -> !fc.isEmpty())
79-
.collect(toImmutableList())) {
99+
for (int i = events.size() - 2; i >= 0; i--) {
100+
Event event = events.get(i);
101+
if (event.functionCalls().isEmpty()) {
102+
continue;
103+
}
104+
105+
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
106+
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
107+
108+
event.functionCalls().stream()
109+
.filter(
110+
fc ->
111+
fc.id().isPresent()
112+
&& requestConfirmationFunctionResponses.containsKey(fc.id().get()))
113+
.forEach(
114+
fc ->
115+
getOriginalFunctionCall(fc)
116+
.ifPresent(
117+
ofc -> {
118+
toolsToResumeWithConfirmation.put(
119+
ofc.id().get(),
120+
requestConfirmationFunctionResponses.get(fc.id().get()));
121+
toolsToResumeWithArgs.put(ofc.id().get(), ofc);
122+
}));
123+
124+
if (toolsToResumeWithConfirmation.isEmpty()) {
125+
continue;
126+
}
127+
128+
// Remove the tools that have already been confirmed.
129+
ImmutableSet<String> alreadyConfirmedIds =
130+
events.subList(confirmationEventIndex + 1, events.size()).stream()
131+
.flatMap(e -> e.functionResponses().stream())
132+
.map(FunctionResponse::id)
133+
.flatMap(Optional::stream)
134+
.collect(toImmutableSet());
135+
toolsToResumeWithConfirmation.keySet().removeAll(alreadyConfirmedIds);
136+
toolsToResumeWithArgs.keySet().removeAll(alreadyConfirmedIds);
80137

81-
ImmutableMap<String, FunctionCall> toolsToResumeWithArgs =
82-
filterToolsToResumeWithArgs(functionCalls, requestConfirmationFunctionResponses);
83-
ImmutableMap<String, ToolConfirmation> toolsToResumeWithConfirmation =
84-
toolsToResumeWithArgs.keySet().stream()
85-
.filter(
86-
id ->
87-
events.stream()
88-
.flatMap(e -> e.functionResponses().stream())
89-
.anyMatch(fr -> Objects.equals(fr.id().orElse(null), id)))
90-
.collect(toImmutableMap(k -> k, requestConfirmationFunctionResponses::get));
91138
if (toolsToResumeWithConfirmation.isEmpty()) {
92-
logger.info("No tools to resume with confirmation.");
93139
continue;
94140
}
95141

96142
return assembleEvent(
97-
invocationContext, toolsToResumeWithArgs.values(), toolsToResumeWithConfirmation)
98-
.map(event -> RequestProcessingResult.create(llmRequest, ImmutableList.of(event)))
143+
invocationContext,
144+
toolsToResumeWithArgs.values(),
145+
ImmutableMap.copyOf(toolsToResumeWithConfirmation))
146+
.map(e -> RequestProcessingResult.create(llmRequest, ImmutableList.of(e)))
99147
.toSingle();
100148
}
101149

102150
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
103151
}
104152

153+
private Optional<FunctionCall> getOriginalFunctionCall(FunctionCall functionCall) {
154+
if (!functionCall.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall")) {
155+
return Optional.empty();
156+
}
157+
try {
158+
FunctionCall originalFunctionCall =
159+
OBJECT_MAPPER.convertValue(
160+
functionCall.args().get().get("originalFunctionCall"), FunctionCall.class);
161+
if (originalFunctionCall.id().isEmpty()) {
162+
return Optional.empty();
163+
}
164+
return Optional.of(originalFunctionCall);
165+
} catch (IllegalArgumentException e) {
166+
logger.warn("Failed to convert originalFunctionCall argument.", e);
167+
return Optional.empty();
168+
}
169+
}
170+
105171
private Maybe<Event> assembleEvent(
106172
InvocationContext invocationContext,
107173
Collection<FunctionCall> functionCalls,
@@ -128,58 +194,26 @@ private Maybe<Event> assembleEvent(
128194
invocationContext, functionCallEvent, toolsBuilder.buildOrThrow(), toolConfirmations);
129195
}
130196

131-
private ImmutableMap<String, ToolConfirmation> filterRequestConfirmationFunctionResponses(
132-
List<Event> events) {
133-
return events.stream()
134-
.filter(event -> Objects.equals(event.author(), "user"))
135-
.flatMap(event -> event.functionResponses().stream())
136-
.filter(functionResponse -> functionResponse.id().isPresent())
137-
.filter(
138-
functionResponse ->
139-
Objects.equals(
140-
functionResponse.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
141-
.map(this::maybeCreateToolConfirmationEntry)
142-
.flatMap(Optional::stream)
143-
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
144-
}
145-
146197
private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmationEntry(
147198
FunctionResponse functionResponse) {
148199
Map<String, Object> responseMap = functionResponse.response().orElse(ImmutableMap.of());
149200
if (responseMap.size() != 1 || !responseMap.containsKey("response")) {
150201
return Optional.of(
151202
Map.entry(
152203
functionResponse.id().get(),
153-
objectMapper.convertValue(responseMap, ToolConfirmation.class)));
204+
OBJECT_MAPPER.convertValue(responseMap, ToolConfirmation.class)));
154205
}
155206

156207
try {
157208
return Optional.of(
158209
Map.entry(
159210
functionResponse.id().get(),
160-
objectMapper.readValue(
211+
OBJECT_MAPPER.readValue(
161212
(String) responseMap.get("response"), ToolConfirmation.class)));
162213
} catch (JsonProcessingException e) {
163214
logger.error("Failed to parse tool confirmation response", e);
164215
}
165216

166217
return Optional.empty();
167218
}
168-
169-
private ImmutableMap<String, FunctionCall> filterToolsToResumeWithArgs(
170-
ImmutableList<FunctionCall> functionCalls,
171-
Map<String, ToolConfirmation> requestConfirmationFunctionResponses) {
172-
return functionCalls.stream()
173-
.filter(fc -> fc.id().isPresent())
174-
.filter(fc -> requestConfirmationFunctionResponses.containsKey(fc.id().get()))
175-
.filter(
176-
fc -> Objects.equals(fc.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
177-
.filter(fc -> fc.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall"))
178-
.collect(
179-
toImmutableMap(
180-
fc -> fc.id().get(),
181-
fc ->
182-
objectMapper.convertValue(
183-
fc.args().get().get("originalFunctionCall"), FunctionCall.class)));
184-
}
185219
}

core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class SingleFlow extends BaseLlmFlow {
3131
new Identity(),
3232
new Contents(),
3333
new Examples(),
34+
new RequestConfirmationLlmRequestProcessor(),
3435
CodeExecution.requestProcessor);
3536

3637
protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =

core/src/main/java/com/google/adk/tools/ToolContext.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,20 @@ public Builder toBuilder() {
127127
.toolConfirmation(toolConfirmation.orElse(null));
128128
}
129129

130+
@Override
131+
public String toString() {
132+
return "ToolContext{"
133+
+ "invocationContext="
134+
+ invocationContext
135+
+ ", eventActions="
136+
+ eventActions
137+
+ ", functionCallId="
138+
+ functionCallId
139+
+ ", toolConfirmation="
140+
+ toolConfirmation
141+
+ '}';
142+
}
143+
130144
/** Builder for {@link ToolContext}. */
131145
public static final class Builder {
132146
private final InvocationContext invocationContext;

0 commit comments

Comments
 (0)