Skip to content

Commit 4c8ea3b

Browse files
committed
Add confirmed function to next LLM call
1 parent 03e77a7 commit 4c8ea3b

2 files changed

Lines changed: 79 additions & 4 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@
3939
import com.google.genai.types.Part;
4040
import io.reactivex.rxjava3.core.Maybe;
4141
import io.reactivex.rxjava3.core.Single;
42+
import java.util.ArrayList;
4243
import java.util.Collection;
4344
import java.util.HashMap;
45+
import java.util.List;
4446
import java.util.Map;
4547
import java.util.Objects;
4648
import java.util.Optional;
@@ -103,8 +105,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
103105
continue;
104106
}
105107

106-
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
107-
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
108+
final Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
109+
final Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
108110

109111
event.functionCalls().stream()
110112
.filter(
@@ -144,7 +146,35 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
144146
invocationContext,
145147
toolsToResumeWithArgs.values(),
146148
ImmutableMap.copyOf(toolsToResumeWithConfirmation))
147-
.map(e -> RequestProcessingResult.create(llmRequest, ImmutableList.of(e)))
149+
.map(
150+
responseEvent -> {
151+
final List<Content> updatedContent = new ArrayList<>(llmRequest.contents());
152+
final List<Part> functionCalls =
153+
toolsToResumeWithArgs.values().stream()
154+
.map(functionCall -> Part.builder().functionCall(functionCall).build())
155+
.collect(toImmutableList());
156+
157+
Content functionCallContent =
158+
Content.builder().role("model").parts(functionCalls).build();
159+
160+
// append function call to next LLM request
161+
updatedContent.add(functionCallContent);
162+
// append function response to next LLM request
163+
updatedContent.add(
164+
Content.builder()
165+
.role("user")
166+
.parts(
167+
responseEvent.functionResponses().stream()
168+
.map(
169+
functionResponse ->
170+
Part.builder().functionResponse(functionResponse).build())
171+
.collect(toImmutableList()))
172+
.build());
173+
174+
return RequestProcessingResult.create(
175+
llmRequest.toBuilder().contents(updatedContent).build(),
176+
ImmutableList.of(responseEvent));
177+
})
148178
.toSingle();
149179
}
150180

core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,48 @@ public class RequestConfirmationLlmRequestProcessorTest {
102102
.build()))
103103
.build();
104104

105+
private static final Event FUNCTION_CALL_EVENT =
106+
Event.builder()
107+
.author("model")
108+
.content(
109+
Content.builder()
110+
.role("model")
111+
.parts(
112+
Part.builder()
113+
.functionCall(
114+
FunctionCall.builder()
115+
.id(ORIGINAL_FUNCTION_CALL_ID)
116+
.name(ECHO_TOOL_NAME)
117+
.args(ImmutableMap.of("say", "hello"))
118+
.build())
119+
.build())
120+
.build())
121+
.build();
122+
123+
private static final Event FUNCTION_RESPONSE_EVENT =
124+
Event.builder()
125+
.author("user")
126+
.content(
127+
Content.builder()
128+
.role("user")
129+
.parts(
130+
Part.builder()
131+
.functionResponse(
132+
FunctionResponse.builder()
133+
.id(ORIGINAL_FUNCTION_CALL_ID)
134+
.name(ECHO_TOOL_NAME)
135+
.response(
136+
ImmutableMap.of("result", ImmutableMap.of("say", "hello")))
137+
.build())
138+
.build())
139+
.build())
140+
.build();
141+
105142
private static final RequestConfirmationLlmRequestProcessor processor =
106143
new RequestConfirmationLlmRequestProcessor();
107144

108145
@Test
109-
public void runAsync_withConfirmation_callsOriginalFunction() {
146+
public void runAsync_withConfirmation_callsOriginalFunctionAndAppendsToUpdatedRequest() {
110147
LlmAgent agent = createAgentWithEchoTool();
111148
Session session =
112149
Session.builder("session_id")
@@ -126,6 +163,14 @@ public void runAsync_withConfirmation_callsOriginalFunction() {
126163
assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID);
127164
assertThat(fr.name()).hasValue(ECHO_TOOL_NAME);
128165
assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello")));
166+
assertThat(result.updatedRequest())
167+
.isEqualTo(
168+
LlmRequest.builder()
169+
.contents(
170+
ImmutableList.of(
171+
FUNCTION_CALL_EVENT.content().get(),
172+
FUNCTION_RESPONSE_EVENT.content().get()))
173+
.build());
129174
}
130175

131176
@Test

0 commit comments

Comments
 (0)