|
16 | 16 |
|
17 | 17 | package com.google.adk.flows.llmflows; |
18 | 18 |
|
| 19 | +import static java.nio.charset.StandardCharsets.UTF_8; |
| 20 | + |
| 21 | +import com.fasterxml.jackson.core.JsonProcessingException; |
| 22 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 23 | +import com.fasterxml.jackson.databind.ObjectMapper; |
19 | 24 | import com.google.adk.agents.ActiveStreamingTool; |
20 | 25 | import com.google.adk.agents.BaseAgent; |
21 | 26 | import com.google.adk.agents.CallbackContext; |
|
28 | 33 | import com.google.adk.agents.ReadonlyContext; |
29 | 34 | import com.google.adk.agents.RunConfig.StreamingMode; |
30 | 35 | import com.google.adk.events.Event; |
| 36 | +import com.google.adk.events.EventActions; |
31 | 37 | import com.google.adk.flows.BaseFlow; |
32 | 38 | import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; |
33 | 39 | import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; |
|
41 | 47 | import com.google.adk.tools.ToolContext; |
42 | 48 | import com.google.common.collect.ImmutableList; |
43 | 49 | import com.google.common.collect.Iterables; |
| 50 | +import com.google.genai.types.Content; |
44 | 51 | import com.google.genai.types.FunctionResponse; |
| 52 | +import com.google.genai.types.Part; |
45 | 53 | import io.opentelemetry.api.trace.Span; |
46 | 54 | import io.opentelemetry.api.trace.StatusCode; |
47 | 55 | import io.opentelemetry.context.Context; |
|
54 | 62 | import io.reactivex.rxjava3.observers.DisposableCompletableObserver; |
55 | 63 | import io.reactivex.rxjava3.schedulers.Schedulers; |
56 | 64 | import java.util.ArrayList; |
| 65 | +import java.util.HashMap; |
57 | 66 | import java.util.List; |
| 67 | +import java.util.Map; |
58 | 68 | import java.util.Optional; |
59 | 69 | import java.util.Set; |
60 | 70 | import java.util.concurrent.atomic.AtomicReference; |
|
64 | 74 | /** A basic flow that calls the LLM in a loop until a final response is generated. */ |
65 | 75 | public abstract class BaseLlmFlow implements BaseFlow { |
66 | 76 | private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class); |
| 77 | + private static final ObjectMapper objectMapper = new ObjectMapper(); |
67 | 78 |
|
68 | 79 | protected final List<RequestProcessor> requestProcessors; |
69 | 80 | protected final List<ResponseProcessor> responseProcessors; |
@@ -349,14 +360,19 @@ private Single<LlmResponse> handleAfterModelCallback( |
349 | 360 |
|
350 | 361 | Maybe<LlmResponse> callbackResult = |
351 | 362 | Maybe.defer( |
352 | | - () -> |
353 | | - Flowable.fromIterable(callbacks) |
354 | | - .concatMapMaybe( |
355 | | - callback -> |
| 363 | + () -> { |
| 364 | + Single<LlmResponse> currentResponse = Single.just(llmResponse); |
| 365 | + for (AfterModelCallback callback : callbacks) { |
| 366 | + currentResponse = |
| 367 | + currentResponse.flatMap( |
| 368 | + resp -> |
356 | 369 | callback |
357 | | - .call(callbackContext, llmResponse) |
358 | | - .compose(Tracing.withContext(currentContext))) |
359 | | - .firstElement()); |
| 370 | + .call(callbackContext, resp) |
| 371 | + .compose(Tracing.withContext(currentContext)) |
| 372 | + .defaultIfEmpty(resp)); |
| 373 | + } |
| 374 | + return currentResponse.toMaybe(); |
| 375 | + }); |
360 | 376 |
|
361 | 377 | return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); |
362 | 378 | } |
@@ -461,14 +477,37 @@ public Flowable<Event> run(InvocationContext invocationContext) { |
461 | 477 |
|
462 | 478 | private Flowable<Event> run( |
463 | 479 | Context spanContext, InvocationContext invocationContext, int stepsCompleted) { |
464 | | - Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache(); |
| 480 | + Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext); |
| 481 | + |
| 482 | + Flowable<Event> processedEvents = |
| 483 | + currentStepEvents |
| 484 | + .concatMap( |
| 485 | + event -> { |
| 486 | + if (invocationContext.session().events().stream() |
| 487 | + .anyMatch(e -> e.id() != null && e.id().equals(event.id()))) { |
| 488 | + logger.debug("Event {} already in session, skipping append", event.id()); |
| 489 | + return Flowable.just(event); |
| 490 | + } |
| 491 | + return invocationContext |
| 492 | + .sessionService() |
| 493 | + .appendEvent(invocationContext.session(), event) |
| 494 | + .flatMap( |
| 495 | + registeredEvent -> |
| 496 | + invocationContext |
| 497 | + .pluginManager() |
| 498 | + .onEventCallback(invocationContext, registeredEvent) |
| 499 | + .defaultIfEmpty(registeredEvent)) |
| 500 | + .toFlowable(); |
| 501 | + }) |
| 502 | + .cache(); |
| 503 | + |
465 | 504 | if (stepsCompleted + 1 >= maxSteps) { |
466 | 505 | logger.debug("Ending flow execution because max steps reached."); |
467 | | - return currentStepEvents; |
| 506 | + return processedEvents; |
468 | 507 | } |
469 | 508 |
|
470 | | - return currentStepEvents.concatWith( |
471 | | - currentStepEvents |
| 509 | + return processedEvents.concatWith( |
| 510 | + processedEvents |
472 | 511 | .toList() |
473 | 512 | .flatMapPublisher( |
474 | 513 | eventList -> { |
@@ -685,22 +724,75 @@ private Flowable<Event> buildPostprocessingEvents( |
685 | 724 |
|
686 | 725 | Event modelResponseEvent = |
687 | 726 | buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); |
688 | | - if (modelResponseEvent.functionCalls().isEmpty()) { |
689 | | - return processorEvents.concatWith(Flowable.just(modelResponseEvent)); |
| 727 | + |
| 728 | + if (context.agent() instanceof LlmAgent agent) { |
| 729 | + Optional<String> outputKeyOpt = agent.outputKey(); |
| 730 | + if (outputKeyOpt.isPresent() && modelResponseEvent.content().isPresent()) { |
| 731 | + Content content = modelResponseEvent.content().get(); |
| 732 | + Map<String, Object> extractedDelta = new HashMap<>(); |
| 733 | + List<Part> cleanParts = new ArrayList<>(); |
| 734 | + boolean metadataFound = false; |
| 735 | + for (Part part : content.parts().orElse(ImmutableList.of())) { |
| 736 | + if (part.inlineData().isPresent() |
| 737 | + && part.inlineData() |
| 738 | + .get() |
| 739 | + .mimeType() |
| 740 | + .orElse("") |
| 741 | + .equals("application/json+metadata")) { |
| 742 | + metadataFound = true; |
| 743 | + byte[] data = part.inlineData().get().data().orElse(null); |
| 744 | + if (data != null) { |
| 745 | + String json = new String(data, UTF_8); |
| 746 | + try { |
| 747 | + Map<String, Object> metadata = |
| 748 | + objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {}); |
| 749 | + extractedDelta.putAll(metadata); |
| 750 | + } catch (JsonProcessingException e) { |
| 751 | + logger.error("Failed to parse metadata from inlineData", e); |
| 752 | + } |
| 753 | + } |
| 754 | + } else { |
| 755 | + cleanParts.add(part); |
| 756 | + } |
| 757 | + } |
| 758 | + |
| 759 | + if (metadataFound) { |
| 760 | + Event.Builder updatedEventBuilder = modelResponseEvent.toBuilder(); |
| 761 | + Content newContent = |
| 762 | + Content.builder().role(content.role().orElse("model")).parts(cleanParts).build(); |
| 763 | + updatedEventBuilder.content(newContent); |
| 764 | + |
| 765 | + if (!extractedDelta.isEmpty() && modelResponseEvent.finalResponse()) { |
| 766 | + Map<String, Object> newStateDelta = |
| 767 | + new HashMap<>(modelResponseEvent.actions().stateDelta()); |
| 768 | + newStateDelta.putAll(extractedDelta); |
| 769 | + EventActions updatedActions = |
| 770 | + modelResponseEvent.actions().toBuilder().stateDelta(newStateDelta).build(); |
| 771 | + updatedEventBuilder.actions(updatedActions); |
| 772 | + } |
| 773 | + modelResponseEvent = updatedEventBuilder.build(); |
| 774 | + } |
| 775 | + } |
| 776 | + } |
| 777 | + final Event finalModelResponseEvent = modelResponseEvent; |
| 778 | + |
| 779 | + if (finalModelResponseEvent.functionCalls().isEmpty()) { |
| 780 | + return processorEvents.concatWith(Flowable.just(finalModelResponseEvent)); |
690 | 781 | } |
691 | 782 |
|
692 | 783 | Flowable<Event> functionEvents; |
693 | 784 | try (Scope scope = parentContext.makeCurrent()) { |
694 | 785 | Maybe<Event> maybeFunctionResponseEvent = |
695 | 786 | context.runConfig().streamingMode() == StreamingMode.BIDI |
696 | | - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) |
697 | | - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); |
| 787 | + ? Functions.handleFunctionCallsLive( |
| 788 | + context, finalModelResponseEvent, llmRequest.tools()) |
| 789 | + : Functions.handleFunctionCalls(context, finalModelResponseEvent, llmRequest.tools()); |
698 | 790 | functionEvents = |
699 | 791 | maybeFunctionResponseEvent.flatMapPublisher( |
700 | 792 | functionResponseEvent -> { |
701 | 793 | Optional<Event> toolConfirmationEvent = |
702 | 794 | Functions.generateRequestConfirmationEvent( |
703 | | - context, modelResponseEvent, functionResponseEvent); |
| 795 | + context, finalModelResponseEvent, functionResponseEvent); |
704 | 796 | List<Event> events = new ArrayList<>(); |
705 | 797 | toolConfirmationEvent.ifPresent(events::add); |
706 | 798 | events.add(functionResponseEvent); |
|
0 commit comments