Skip to content

Commit 952ff68

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fix ADK Runner race condition for sequential tool execution
PiperOrigin-RevId: 901312735
1 parent 4009905 commit 952ff68

7 files changed

Lines changed: 475 additions & 57 deletions

File tree

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static java.util.stream.Collectors.joining;
2222

2323
import com.fasterxml.jackson.core.JsonProcessingException;
24+
import com.fasterxml.jackson.databind.ObjectMapper;
2425
import com.google.adk.SchemaUtils;
2526
import com.google.adk.agents.Callbacks.AfterAgentCallbackSync;
2627
import com.google.adk.agents.Callbacks.AfterModelCallback;
@@ -50,12 +51,15 @@
5051
import com.google.adk.flows.llmflows.SingleFlow;
5152
import com.google.adk.models.BaseLlm;
5253
import com.google.adk.models.LlmRegistry;
54+
import com.google.adk.models.LlmResponse;
5355
import com.google.adk.models.Model;
5456
import com.google.adk.tools.BaseTool;
5557
import com.google.adk.tools.BaseToolset;
5658
import com.google.common.base.Preconditions;
5759
import com.google.common.collect.ImmutableList;
60+
import com.google.common.collect.ImmutableMap;
5861
import com.google.errorprone.annotations.CanIgnoreReturnValue;
62+
import com.google.genai.types.Blob;
5963
import com.google.genai.types.Content;
6064
import com.google.genai.types.GenerateContentConfig;
6165
import com.google.genai.types.Part;
@@ -64,6 +68,7 @@
6468
import io.reactivex.rxjava3.core.Flowable;
6569
import io.reactivex.rxjava3.core.Maybe;
6670
import io.reactivex.rxjava3.core.Single;
71+
import java.nio.charset.StandardCharsets;
6772
import java.util.ArrayList;
6873
import java.util.List;
6974
import java.util.Map;
@@ -134,7 +139,13 @@ protected LlmAgent(Builder builder) {
134139
this.disallowTransferToParent = requireNonNullElse(builder.disallowTransferToParent, false);
135140
this.disallowTransferToPeers = requireNonNullElse(builder.disallowTransferToPeers, false);
136141
this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of());
137-
this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of());
142+
List<AfterModelCallback> afterCallbacks = new ArrayList<>();
143+
if (builder.outputKey != null) {
144+
afterCallbacks.add(
145+
new OutputKeySaverCallback(builder.outputKey, Optional.ofNullable(builder.outputSchema)));
146+
}
147+
afterCallbacks.addAll(requireNonNullElse(builder.afterModelCallback, ImmutableList.of()));
148+
this.afterModelCallback = ImmutableList.copyOf(afterCallbacks);
138149
this.onModelErrorCallback =
139150
requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of());
140151
this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of());
@@ -610,41 +621,69 @@ protected BaseLlmFlow determineLlmFlow() {
610621
}
611622
}
612623

613-
private void maybeSaveOutputToState(Event event) {
614-
if (outputKey().isPresent() && event.finalResponse() && event.content().isPresent()) {
615-
// Concatenate text from all parts, excluding thoughts.
616-
Object output;
624+
private static class OutputKeySaverCallback implements AfterModelCallback {
625+
private static final ObjectMapper objectMapper = new ObjectMapper();
626+
private final String outputKey;
627+
private final Optional<Schema> outputSchema;
628+
629+
private OutputKeySaverCallback(String outputKey, Optional<Schema> outputSchema) {
630+
this.outputKey = outputKey;
631+
this.outputSchema = outputSchema;
632+
}
633+
634+
@Override
635+
public Maybe<LlmResponse> call(CallbackContext context, LlmResponse response) {
636+
if (response.content().isEmpty()) {
637+
return Maybe.empty();
638+
}
639+
640+
Content originalContent = response.content().get();
617641
String rawResult =
618-
event.content().flatMap(Content::parts).orElseGet(ImmutableList::of).stream()
642+
originalContent.parts().orElse(ImmutableList.of()).stream()
619643
.filter(part -> !isThought(part))
620644
.map(part -> part.text().orElse(""))
621645
.collect(joining());
622646

623-
Optional<Schema> outputSchema = outputSchema();
647+
Object output;
624648
if (outputSchema.isPresent()) {
625649
try {
626650
Map<String, Object> validatedMap =
627651
SchemaUtils.validateOutputSchema(rawResult, outputSchema.get());
628652
output = validatedMap;
629-
} catch (JsonProcessingException e) {
630-
logger.error(
631-
"LlmAgent output for outputKey '{}' was not valid JSON, despite an outputSchema being"
632-
+ " present. Saving raw output to state.",
633-
outputKey().get(),
634-
e);
635-
output = rawResult;
636-
} catch (IllegalArgumentException e) {
653+
} catch (JsonProcessingException | IllegalArgumentException e) {
637654
logger.error(
638655
"LlmAgent output for outputKey '{}' did not match the outputSchema. Saving raw output"
639656
+ " to state.",
640-
outputKey().get(),
657+
outputKey,
641658
e);
642659
output = rawResult;
643660
}
644661
} else {
645662
output = rawResult;
646663
}
647-
event.actions().stateDelta().put(outputKey().get(), output);
664+
665+
String jsonMetadata;
666+
try {
667+
jsonMetadata = objectMapper.writeValueAsString(ImmutableMap.of(outputKey, output));
668+
} catch (JsonProcessingException e) {
669+
return Maybe.error(e);
670+
}
671+
672+
Part stateDeltaPart =
673+
Part.builder()
674+
.inlineData(
675+
Blob.builder()
676+
.data(jsonMetadata.getBytes(StandardCharsets.UTF_8))
677+
.mimeType("application/json+metadata")
678+
.build())
679+
.build();
680+
681+
List<Part> newParts = new ArrayList<>(originalContent.parts().orElse(ImmutableList.of()));
682+
newParts.add(stateDeltaPart);
683+
684+
Content newContent = originalContent.toBuilder().parts(newParts).build();
685+
686+
return Maybe.just(response.toBuilder().content(newContent).build());
648687
}
649688
}
650689

@@ -654,12 +693,12 @@ private static boolean isThought(Part part) {
654693

655694
@Override
656695
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
657-
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
696+
return llmFlow.run(invocationContext);
658697
}
659698

660699
@Override
661700
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
662-
return llmFlow.runLive(invocationContext).doOnNext(this::maybeSaveOutputToState);
701+
return llmFlow.runLive(invocationContext);
663702
}
664703

665704
/**

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616

1717
package com.google.adk.flows.llmflows;
1818

19+
import static java.nio.charset.StandardCharsets.UTF_8;
20+
21+
import com.fasterxml.jackson.core.JsonProcessingException;
22+
import com.fasterxml.jackson.core.type.TypeReference;
23+
import com.fasterxml.jackson.databind.ObjectMapper;
1924
import com.google.adk.agents.ActiveStreamingTool;
2025
import com.google.adk.agents.BaseAgent;
2126
import com.google.adk.agents.CallbackContext;
@@ -28,6 +33,7 @@
2833
import com.google.adk.agents.ReadonlyContext;
2934
import com.google.adk.agents.RunConfig.StreamingMode;
3035
import com.google.adk.events.Event;
36+
import com.google.adk.events.EventActions;
3137
import com.google.adk.flows.BaseFlow;
3238
import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult;
3339
import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult;
@@ -41,7 +47,9 @@
4147
import com.google.adk.tools.ToolContext;
4248
import com.google.common.collect.ImmutableList;
4349
import com.google.common.collect.Iterables;
50+
import com.google.genai.types.Content;
4451
import com.google.genai.types.FunctionResponse;
52+
import com.google.genai.types.Part;
4553
import io.opentelemetry.api.trace.Span;
4654
import io.opentelemetry.api.trace.StatusCode;
4755
import io.opentelemetry.context.Context;
@@ -54,7 +62,9 @@
5462
import io.reactivex.rxjava3.observers.DisposableCompletableObserver;
5563
import io.reactivex.rxjava3.schedulers.Schedulers;
5664
import java.util.ArrayList;
65+
import java.util.HashMap;
5766
import java.util.List;
67+
import java.util.Map;
5868
import java.util.Optional;
5969
import java.util.Set;
6070
import java.util.concurrent.atomic.AtomicReference;
@@ -64,6 +74,7 @@
6474
/** A basic flow that calls the LLM in a loop until a final response is generated. */
6575
public abstract class BaseLlmFlow implements BaseFlow {
6676
private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class);
77+
private static final ObjectMapper objectMapper = new ObjectMapper();
6778

6879
protected final List<RequestProcessor> requestProcessors;
6980
protected final List<ResponseProcessor> responseProcessors;
@@ -349,14 +360,19 @@ private Single<LlmResponse> handleAfterModelCallback(
349360

350361
Maybe<LlmResponse> callbackResult =
351362
Maybe.defer(
352-
() ->
353-
Flowable.fromIterable(callbacks)
354-
.concatMapMaybe(
355-
callback ->
363+
() -> {
364+
Single<LlmResponse> currentResponse = Single.just(llmResponse);
365+
for (AfterModelCallback callback : callbacks) {
366+
currentResponse =
367+
currentResponse.flatMap(
368+
resp ->
356369
callback
357-
.call(callbackContext, llmResponse)
358-
.compose(Tracing.withContext(currentContext)))
359-
.firstElement());
370+
.call(callbackContext, resp)
371+
.compose(Tracing.withContext(currentContext))
372+
.defaultIfEmpty(resp));
373+
}
374+
return currentResponse.toMaybe();
375+
});
360376

361377
return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
362378
}
@@ -461,14 +477,37 @@ public Flowable<Event> run(InvocationContext invocationContext) {
461477

462478
private Flowable<Event> run(
463479
Context spanContext, InvocationContext invocationContext, int stepsCompleted) {
464-
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext).cache();
480+
Flowable<Event> currentStepEvents = runOneStep(spanContext, invocationContext);
481+
482+
Flowable<Event> processedEvents =
483+
currentStepEvents
484+
.concatMap(
485+
event -> {
486+
if (invocationContext.session().events().stream()
487+
.anyMatch(e -> e.id() != null && e.id().equals(event.id()))) {
488+
logger.debug("Event {} already in session, skipping append", event.id());
489+
return Flowable.just(event);
490+
}
491+
return invocationContext
492+
.sessionService()
493+
.appendEvent(invocationContext.session(), event)
494+
.flatMap(
495+
registeredEvent ->
496+
invocationContext
497+
.pluginManager()
498+
.onEventCallback(invocationContext, registeredEvent)
499+
.defaultIfEmpty(registeredEvent))
500+
.toFlowable();
501+
})
502+
.cache();
503+
465504
if (stepsCompleted + 1 >= maxSteps) {
466505
logger.debug("Ending flow execution because max steps reached.");
467-
return currentStepEvents;
506+
return processedEvents;
468507
}
469508

470-
return currentStepEvents.concatWith(
471-
currentStepEvents
509+
return processedEvents.concatWith(
510+
processedEvents
472511
.toList()
473512
.flatMapPublisher(
474513
eventList -> {
@@ -685,22 +724,75 @@ private Flowable<Event> buildPostprocessingEvents(
685724

686725
Event modelResponseEvent =
687726
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
688-
if (modelResponseEvent.functionCalls().isEmpty()) {
689-
return processorEvents.concatWith(Flowable.just(modelResponseEvent));
727+
728+
if (context.agent() instanceof LlmAgent agent) {
729+
Optional<String> outputKeyOpt = agent.outputKey();
730+
if (outputKeyOpt.isPresent() && modelResponseEvent.content().isPresent()) {
731+
Content content = modelResponseEvent.content().get();
732+
Map<String, Object> extractedDelta = new HashMap<>();
733+
List<Part> cleanParts = new ArrayList<>();
734+
boolean metadataFound = false;
735+
for (Part part : content.parts().orElse(ImmutableList.of())) {
736+
if (part.inlineData().isPresent()
737+
&& part.inlineData()
738+
.get()
739+
.mimeType()
740+
.orElse("")
741+
.equals("application/json+metadata")) {
742+
metadataFound = true;
743+
byte[] data = part.inlineData().get().data().orElse(null);
744+
if (data != null) {
745+
String json = new String(data, UTF_8);
746+
try {
747+
Map<String, Object> metadata =
748+
objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
749+
extractedDelta.putAll(metadata);
750+
} catch (JsonProcessingException e) {
751+
logger.error("Failed to parse metadata from inlineData", e);
752+
}
753+
}
754+
} else {
755+
cleanParts.add(part);
756+
}
757+
}
758+
759+
if (metadataFound) {
760+
Event.Builder updatedEventBuilder = modelResponseEvent.toBuilder();
761+
Content newContent =
762+
Content.builder().role(content.role().orElse("model")).parts(cleanParts).build();
763+
updatedEventBuilder.content(newContent);
764+
765+
if (!extractedDelta.isEmpty() && modelResponseEvent.finalResponse()) {
766+
Map<String, Object> newStateDelta =
767+
new HashMap<>(modelResponseEvent.actions().stateDelta());
768+
newStateDelta.putAll(extractedDelta);
769+
EventActions updatedActions =
770+
modelResponseEvent.actions().toBuilder().stateDelta(newStateDelta).build();
771+
updatedEventBuilder.actions(updatedActions);
772+
}
773+
modelResponseEvent = updatedEventBuilder.build();
774+
}
775+
}
776+
}
777+
final Event finalModelResponseEvent = modelResponseEvent;
778+
779+
if (finalModelResponseEvent.functionCalls().isEmpty()) {
780+
return processorEvents.concatWith(Flowable.just(finalModelResponseEvent));
690781
}
691782

692783
Flowable<Event> functionEvents;
693784
try (Scope scope = parentContext.makeCurrent()) {
694785
Maybe<Event> maybeFunctionResponseEvent =
695786
context.runConfig().streamingMode() == StreamingMode.BIDI
696-
? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools())
697-
: Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
787+
? Functions.handleFunctionCallsLive(
788+
context, finalModelResponseEvent, llmRequest.tools())
789+
: Functions.handleFunctionCalls(context, finalModelResponseEvent, llmRequest.tools());
698790
functionEvents =
699791
maybeFunctionResponseEvent.flatMapPublisher(
700792
functionResponseEvent -> {
701793
Optional<Event> toolConfirmationEvent =
702794
Functions.generateRequestConfirmationEvent(
703-
context, modelResponseEvent, functionResponseEvent);
795+
context, finalModelResponseEvent, functionResponseEvent);
704796
List<Event> events = new ArrayList<>();
705797
toolConfirmationEvent.ifPresent(events::add);
706798
events.add(functionResponseEvent);

0 commit comments

Comments
 (0)