Skip to content

Commit 32dc8e5

Browse files
r4ineemazas-google
authored andcommitted
Add confirmed function call to next LLM call
1 parent f49260e commit 32dc8e5

2 files changed

Lines changed: 59 additions & 3 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import io.reactivex.rxjava3.core.Single;
4343
import java.util.Collection;
4444
import java.util.HashMap;
45+
import java.util.List;
4546
import java.util.Map;
4647
import java.util.Objects;
4748
import java.util.Optional;
@@ -109,8 +110,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
109110
continue;
110111
}
111112

112-
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
113-
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
113+
final Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
114+
final Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
114115

115116
event.functionCalls().stream()
116117
.filter(
@@ -163,6 +164,16 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
163164
// Create an updated LlmRequest including the new event's content
164165
ImmutableList.Builder<Content> updatedContentsBuilder =
165166
ImmutableList.<Content>builder().addAll(llmRequest.contents());
167+
168+
final List<Part> functionCalls =
169+
toolsToResumeWithArgs.values().stream()
170+
.map(functionCall -> Part.builder().functionCall(functionCall).build())
171+
.collect(toImmutableList());
172+
Content functionCallContent =
173+
Content.builder().role("model").parts(functionCalls).build();
174+
// add function call
175+
updatedContentsBuilder.add(functionCallContent);
176+
// add function response
166177
assembledEvent.content().ifPresent(updatedContentsBuilder::add);
167178

168179
LlmRequest updatedLlmRequest =

core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,48 @@ public class RequestConfirmationLlmRequestProcessorTest {
100100
.build()))
101101
.build();
102102

103+
private static final Event FUNCTION_CALL_EVENT =
104+
Event.builder()
105+
.author("model")
106+
.content(
107+
Content.builder()
108+
.role("model")
109+
.parts(
110+
Part.builder()
111+
.functionCall(
112+
FunctionCall.builder()
113+
.id(ORIGINAL_FUNCTION_CALL_ID)
114+
.name(ECHO_TOOL_NAME)
115+
.args(ImmutableMap.of("say", "hello"))
116+
.build())
117+
.build())
118+
.build())
119+
.build();
120+
121+
private static final Event FUNCTION_RESPONSE_EVENT =
122+
Event.builder()
123+
.author("user")
124+
.content(
125+
Content.builder()
126+
.role("user")
127+
.parts(
128+
Part.builder()
129+
.functionResponse(
130+
FunctionResponse.builder()
131+
.id(ORIGINAL_FUNCTION_CALL_ID)
132+
.name(ECHO_TOOL_NAME)
133+
.response(
134+
ImmutableMap.of("result", ImmutableMap.of("say", "hello")))
135+
.build())
136+
.build())
137+
.build())
138+
.build();
139+
103140
private static final RequestConfirmationLlmRequestProcessor processor =
104141
new RequestConfirmationLlmRequestProcessor();
105142

106143
@Test
107-
public void runAsync_withConfirmation_callsOriginalFunction() {
144+
public void runAsync_withConfirmation_callsOriginalFunctionAndAppendsToUpdatedRequest() {
108145
LlmAgent agent = createAgentWithEchoTool();
109146
Session session =
110147
Session.builder("session_id")
@@ -124,6 +161,14 @@ public void runAsync_withConfirmation_callsOriginalFunction() {
124161
assertThat(fr.id()).hasValue(ORIGINAL_FUNCTION_CALL_ID);
125162
assertThat(fr.name()).hasValue(ECHO_TOOL_NAME);
126163
assertThat(fr.response()).hasValue(ImmutableMap.of("result", ImmutableMap.of("say", "hello")));
164+
assertThat(result.updatedRequest())
165+
.isEqualTo(
166+
LlmRequest.builder()
167+
.contents(
168+
ImmutableList.of(
169+
FUNCTION_CALL_EVENT.content().get(),
170+
FUNCTION_RESPONSE_EVENT.content().get()))
171+
.build());
127172
}
128173

129174
@Test

0 commit comments

Comments
 (0)