Skip to content

Commit 2f61aa1

Browse files
committed
Add confirmed function call to next LLM call
1 parent 9611f89 commit 2f61aa1

2 files changed

Lines changed: 59 additions & 3 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import io.reactivex.rxjava3.core.Single;
4343
import java.util.Collection;
4444
import java.util.HashMap;
45+
import java.util.List;
4546
import java.util.Map;
4647
import java.util.Objects;
4748
import java.util.Optional;
@@ -109,8 +110,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
109110
continue;
110111
}
111112

112-
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
113-
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
113+
final Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
114+
final Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
114115

115116
event.functionCalls().stream()
116117
.filter(
@@ -163,6 +164,16 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
163164
// Create an updated LlmRequest including the new event's content
164165
ImmutableList.Builder<Content> updatedContentsBuilder =
165166
ImmutableList.<Content>builder().addAll(llmRequest.contents());
167+
168+
final List<Part> functionCalls =
169+
toolsToResumeWithArgs.values().stream()
170+
.map(functionCall -> Part.builder().functionCall(functionCall).build())
171+
.collect(toImmutableList());
172+
Content functionCallContent =
173+
Content.builder().role("model").parts(functionCalls).build();
174+
// add function call
175+
updatedContentsBuilder.add(functionCallContent);
176+
// add function response
166177
assembledEvent.content().ifPresent(updatedContentsBuilder::add);
167178

168179
LlmRequest updatedLlmRequest =

core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,48 @@ public class RequestConfirmationLlmRequestProcessorTest {
102102
.build()))
103103
.build();
104104

105+
private static final Event FUNCTION_CALL_EVENT =
106+
Event.builder()
107+
.author("model")
108+
.content(
109+
Content.builder()
110+
.role("model")
111+
.parts(
112+
Part.builder()
113+
.functionCall(
114+
FunctionCall.builder()
115+
.id(ORIGINAL_FUNCTION_CALL_ID)
116+
.name(ECHO_TOOL_NAME)
117+
.args(ImmutableMap.of("say", "hello"))
118+
.build())
119+
.build())
120+
.build())
121+
.build();
122+
123+
private static final Event FUNCTION_RESPONSE_EVENT =
124+
Event.builder()
125+
.author("user")
126+
.content(
127+
Content.builder()
128+
.role("user")
129+
.parts(
130+
Part.builder()
131+
.functionResponse(
132+
FunctionResponse.builder()
133+
.id(ORIGINAL_FUNCTION_CALL_ID)
134+
.name(ECHO_TOOL_NAME)
135+
.response(
136+
ImmutableMap.of("result", ImmutableMap.of("say", "hello")))
137+
.build())
138+
.build())
139+
.build())
140+
.build();
141+
105142
private static final RequestConfirmationLlmRequestProcessor processor =
106143
new RequestConfirmationLlmRequestProcessor();
107144

108145
@Test
109-
public void runAsync_withConfirmation_callsOriginalFunction() {
146+
public void runAsync_withConfirmation_callsOriginalFunctionAndAppendsToUpdatedRequest() {
110147
LlmAgent agent = createAgentWithEchoTool();
111148
Session session =
112149
Session.builder("session_id")
@@ -126,6 +163,14 @@ public void runAsync_withConfirmation_callsOriginalFunction() {
126163
assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID);
127164
assertThat(fr.name()).hasValue(ECHO_TOOL_NAME);
128165
assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello")));
166+
assertThat(result.updatedRequest())
167+
.isEqualTo(
168+
LlmRequest.builder()
169+
.contents(
170+
ImmutableList.of(
171+
FUNCTION_CALL_EVENT.content().get(),
172+
FUNCTION_RESPONSE_EVENT.content().get()))
173+
.build());
129174
}
130175

131176
@Test

0 commit comments

Comments
 (0)