1919import static com .google .adk .flows .llmflows .Functions .REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ;
2020import static com .google .common .collect .ImmutableList .toImmutableList ;
2121import static com .google .common .collect .ImmutableMap .toImmutableMap ;
22+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
2223
2324import com .fasterxml .jackson .core .JsonProcessingException ;
2425import com .fasterxml .jackson .databind .ObjectMapper ;
25- import com .fasterxml . jackson . datatype . jdk8 . Jdk8Module ;
26+ import com .google . adk . JsonBaseModel ;
2627import com .google .adk .agents .InvocationContext ;
2728import com .google .adk .agents .LlmAgent ;
2829import com .google .adk .events .Event ;
3132import com .google .adk .tools .ToolConfirmation ;
3233import com .google .common .collect .ImmutableList ;
3334import com .google .common .collect .ImmutableMap ;
35+ import com .google .common .collect .ImmutableSet ;
3436import com .google .genai .types .Content ;
3537import com .google .genai .types .FunctionCall ;
3638import com .google .genai .types .FunctionResponse ;
3739import com .google .genai .types .Part ;
3840import io .reactivex .rxjava3 .core .Maybe ;
3941import io .reactivex .rxjava3 .core .Single ;
4042import java .util .Collection ;
41- import java .util .List ;
43+ import java .util .HashMap ;
4244import java .util .Map ;
4345import java .util .Objects ;
4446import java .util .Optional ;
4951public class RequestConfirmationLlmRequestProcessor implements RequestProcessor {
5052 private static final Logger logger =
5153 LoggerFactory .getLogger (RequestConfirmationLlmRequestProcessor .class );
52- private final ObjectMapper objectMapper ;
53-
54- public RequestConfirmationLlmRequestProcessor () {
55- objectMapper = new ObjectMapper ().registerModule (new Jdk8Module ());
56- }
54+ private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel .getMapper ();
55+ private static final String ORIGINAL_FUNCTION_CALL = "originalFunctionCall" ;
5756
5857 @ Override
5958 public Single <RequestProcessor .RequestProcessingResult > processRequest (
6059 InvocationContext invocationContext , LlmRequest llmRequest ) {
61- List <Event > events = invocationContext .session ().events ();
60+ ImmutableList <Event > events = ImmutableList . copyOf ( invocationContext .session ().events () );
6261 if (events .isEmpty ()) {
6362 logger .info (
6463 "No events are present in the session. Skipping request confirmation processing." );
6564 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
6665 }
6766
68- ImmutableMap <String , ToolConfirmation > requestConfirmationFunctionResponses =
69- filterRequestConfirmationFunctionResponses (events );
67+ ImmutableMap <String , ToolConfirmation > responses = ImmutableMap .of ();
68+ int confirmationEventIndex = -1 ;
69+ for (int i = events .size () - 1 ; i >= 0 ; i --) {
70+ Event event = events .get (i );
71+ if (!Objects .equals (event .author (), "user" )) {
72+ continue ;
73+ }
74+ if (event .functionResponses ().isEmpty ()) {
75+ continue ;
76+ }
77+ responses =
78+ event .functionResponses ().stream ()
79+ .filter (functionResponse -> functionResponse .id ().isPresent ())
80+ .filter (
81+ functionResponse ->
82+ Objects .equals (
83+ functionResponse .name ().orElse (null ),
84+ REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
85+ .map (this ::maybeCreateToolConfirmationEntry )
86+ .flatMap (Optional ::stream )
87+ .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
88+ confirmationEventIndex = i ;
89+ break ;
90+ }
91+
92+ // Make it final to enable access from lambda expressions.
93+ final ImmutableMap <String , ToolConfirmation > requestConfirmationFunctionResponses = responses ;
94+
7095 if (requestConfirmationFunctionResponses .isEmpty ()) {
7196 logger .info ("No request confirmation function responses found." );
7297 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
7398 }
7499
75- for (ImmutableList <FunctionCall > functionCalls :
76- events .stream ()
77- .map (Event ::functionCalls )
78- .filter (fc -> !fc .isEmpty ())
79- .collect (toImmutableList ())) {
100+ for (int i = events .size () - 2 ; i >= 0 ; i --) {
101+ Event event = events .get (i );
102+ if (event .functionCalls ().isEmpty ()) {
103+ continue ;
104+ }
105+
106+ Map <String , ToolConfirmation > toolsToResumeWithConfirmation = new HashMap <>();
107+ Map <String , FunctionCall > toolsToResumeWithArgs = new HashMap <>();
108+
109+ event .functionCalls ().stream ()
110+ .filter (
111+ fc ->
112+ fc .id ().isPresent ()
113+ && requestConfirmationFunctionResponses .containsKey (fc .id ().get ()))
114+ .forEach (
115+ fc ->
116+ getOriginalFunctionCall (fc )
117+ .ifPresent (
118+ ofc -> {
119+ toolsToResumeWithConfirmation .put (
120+ ofc .id ().get (),
121+ requestConfirmationFunctionResponses .get (fc .id ().get ()));
122+ toolsToResumeWithArgs .put (ofc .id ().get (), ofc );
123+ }));
124+
125+ if (toolsToResumeWithConfirmation .isEmpty ()) {
126+ continue ;
127+ }
128+
129+ // Remove the tools that have already been confirmed.
130+ ImmutableSet <String > alreadyConfirmedIds =
131+ events .subList (confirmationEventIndex + 1 , events .size ()).stream ()
132+ .flatMap (e -> e .functionResponses ().stream ())
133+ .map (FunctionResponse ::id )
134+ .flatMap (Optional ::stream )
135+ .collect (toImmutableSet ());
136+ toolsToResumeWithConfirmation .keySet ().removeAll (alreadyConfirmedIds );
137+ toolsToResumeWithArgs .keySet ().removeAll (alreadyConfirmedIds );
80138
81- ImmutableMap <String , FunctionCall > toolsToResumeWithArgs =
82- filterToolsToResumeWithArgs (functionCalls , requestConfirmationFunctionResponses );
83- ImmutableMap <String , ToolConfirmation > toolsToResumeWithConfirmation =
84- toolsToResumeWithArgs .keySet ().stream ()
85- .filter (
86- id ->
87- events .stream ()
88- .flatMap (e -> e .functionResponses ().stream ())
89- .anyMatch (fr -> Objects .equals (fr .id ().orElse (null ), id )))
90- .collect (toImmutableMap (k -> k , requestConfirmationFunctionResponses ::get ));
91139 if (toolsToResumeWithConfirmation .isEmpty ()) {
92- logger .info ("No tools to resume with confirmation." );
93140 continue ;
94141 }
95142
96143 return assembleEvent (
97- invocationContext , toolsToResumeWithArgs .values (), toolsToResumeWithConfirmation )
98- .map (event -> RequestProcessingResult .create (llmRequest , ImmutableList .of (event )))
144+ invocationContext ,
145+ toolsToResumeWithArgs .values (),
146+ ImmutableMap .copyOf (toolsToResumeWithConfirmation ))
147+ .map (e -> RequestProcessingResult .create (llmRequest , ImmutableList .of (e )))
99148 .toSingle ();
100149 }
101150
102151 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
103152 }
104153
154+ private Optional <FunctionCall > getOriginalFunctionCall (FunctionCall functionCall ) {
155+ if (!functionCall .args ().orElse (ImmutableMap .of ()).containsKey (ORIGINAL_FUNCTION_CALL )) {
156+ return Optional .empty ();
157+ }
158+ try {
159+ FunctionCall originalFunctionCall =
160+ OBJECT_MAPPER .convertValue (
161+ functionCall .args ().get ().get (ORIGINAL_FUNCTION_CALL ), FunctionCall .class );
162+ if (originalFunctionCall .id ().isEmpty ()) {
163+ return Optional .empty ();
164+ }
165+ return Optional .of (originalFunctionCall );
166+ } catch (IllegalArgumentException e ) {
167+ logger .warn ("Failed to convert originalFunctionCall argument." , e );
168+ return Optional .empty ();
169+ }
170+ }
171+
105172 private Maybe <Event > assembleEvent (
106173 InvocationContext invocationContext ,
107174 Collection <FunctionCall > functionCalls ,
108175 Map <String , ToolConfirmation > toolConfirmations ) {
109- ImmutableMap . Builder <String , BaseTool > toolsBuilder = ImmutableMap . builder () ;
176+ Single < ImmutableMap <String , BaseTool >> toolsMapSingle ;
110177 if (invocationContext .agent () instanceof LlmAgent llmAgent ) {
111- for (BaseTool tool : llmAgent .tools ()) {
112- toolsBuilder .put (tool .name (), tool );
113- }
178+ toolsMapSingle =
179+ llmAgent
180+ .tools ()
181+ .map (
182+ toolList ->
183+ toolList .stream ().collect (toImmutableMap (BaseTool ::name , tool -> tool )));
184+ } else {
185+ toolsMapSingle = Single .just (ImmutableMap .of ());
114186 }
115187
116188 var functionCallEvent =
@@ -124,23 +196,10 @@ private Maybe<Event> assembleEvent(
124196 .build ())
125197 .build ();
126198
127- return Functions .handleFunctionCalls (
128- invocationContext , functionCallEvent , toolsBuilder .buildOrThrow (), toolConfirmations );
129- }
130-
131- private ImmutableMap <String , ToolConfirmation > filterRequestConfirmationFunctionResponses (
132- List <Event > events ) {
133- return events .stream ()
134- .filter (event -> Objects .equals (event .author (), "user" ))
135- .flatMap (event -> event .functionResponses ().stream ())
136- .filter (functionResponse -> functionResponse .id ().isPresent ())
137- .filter (
138- functionResponse ->
139- Objects .equals (
140- functionResponse .name ().orElse (null ), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
141- .map (this ::maybeCreateToolConfirmationEntry )
142- .flatMap (Optional ::stream )
143- .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
199+ return toolsMapSingle .flatMapMaybe (
200+ toolsMap ->
201+ Functions .handleFunctionCalls (
202+ invocationContext , functionCallEvent , toolsMap , toolConfirmations ));
144203 }
145204
146205 private Optional <Map .Entry <String , ToolConfirmation >> maybeCreateToolConfirmationEntry (
@@ -150,36 +209,19 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio
150209 return Optional .of (
151210 Map .entry (
152211 functionResponse .id ().get (),
153- objectMapper .convertValue (responseMap , ToolConfirmation .class )));
212+ OBJECT_MAPPER .convertValue (responseMap , ToolConfirmation .class )));
154213 }
155214
156215 try {
157216 return Optional .of (
158217 Map .entry (
159218 functionResponse .id ().get (),
160- objectMapper .readValue (
219+ OBJECT_MAPPER .readValue (
161220 (String ) responseMap .get ("response" ), ToolConfirmation .class )));
162221 } catch (JsonProcessingException e ) {
163222 logger .error ("Failed to parse tool confirmation response" , e );
164223 }
165224
166225 return Optional .empty ();
167226 }
168-
169- private ImmutableMap <String , FunctionCall > filterToolsToResumeWithArgs (
170- ImmutableList <FunctionCall > functionCalls ,
171- Map <String , ToolConfirmation > requestConfirmationFunctionResponses ) {
172- return functionCalls .stream ()
173- .filter (fc -> fc .id ().isPresent ())
174- .filter (fc -> requestConfirmationFunctionResponses .containsKey (fc .id ().get ()))
175- .filter (
176- fc -> Objects .equals (fc .name ().orElse (null ), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
177- .filter (fc -> fc .args ().orElse (ImmutableMap .of ()).containsKey ("originalFunctionCall" ))
178- .collect (
179- toImmutableMap (
180- fc -> fc .id ().get (),
181- fc ->
182- objectMapper .convertValue (
183- fc .args ().get ().get ("originalFunctionCall" ), FunctionCall .class )));
184- }
185227}
0 commit comments