diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..fdda5219d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -461,14 +461,31 @@ public Flowable run(InvocationContext invocationContext) { private Flowable run( Context spanContext, InvocationContext invocationContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(spanContext, invocationContext).cache(); + Flowable currentStepEvents = runOneStep(spanContext, invocationContext); + + Flowable processedEvents = + currentStepEvents + .concatMap( + event -> + invocationContext + .sessionService() + .appendEvent(invocationContext.session(), event) + .flatMap( + registeredEvent -> + invocationContext + .pluginManager() + .onEventCallback(invocationContext, registeredEvent) + .defaultIfEmpty(registeredEvent)) + .toFlowable()) + .cache(); + if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); - return currentStepEvents; + return processedEvents; } - return currentStepEvents.concatWith( - currentStepEvents + return processedEvents.concatWith( + processedEvents .toList() .flatMapPublisher( eventList -> { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 0b0e5b4d5..49af2a122 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -71,8 +71,10 @@ public final class Functions { private static final Logger logger = LoggerFactory.getLogger(Functions.class); /** Generates a unique ID for a function call. */ - public static String generateClientFunctionCallId() { - return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID(); + public static String generateClientFunctionCallId(FunctionCall functionCall) { + String source = + functionCall.name().orElse("") + functionCall.args().orElse(ImmutableMap.of()).toString(); + return AF_FUNCTION_CALL_ID_PREFIX + UUID.nameUUIDFromBytes(source.getBytes()).toString(); } /** @@ -101,7 +103,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) { FunctionCall functionCall = part.functionCall().get(); if (functionCall.id().isEmpty() || functionCall.id().get().isEmpty()) { FunctionCall updatedFunctionCall = - functionCall.toBuilder().id(generateClientFunctionCallId()).build(); + functionCall.toBuilder().id(generateClientFunctionCallId(functionCall)).build(); newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build()); modified = true; } else { @@ -621,7 +623,7 @@ private static Event buildResponseEvent( .build(); return Event.builder() - .id(Event.generateEventId()) + .id(toolContext.functionCallId().orElseGet(Event::generateEventId)) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) @@ -657,7 +659,7 @@ public static Optional generateRequestConfirmationEvent( .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)) .entrySet()) { - FunctionCall requestConfirmationFunctionCall = + FunctionCall.Builder builder = FunctionCall.builder() .name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME) .args( @@ -665,9 +667,9 @@ public static Optional generateRequestConfirmationEvent( "originalFunctionCall", functionCallsById.get(entry.getKey()), "toolConfirmation", - entry.getValue())) - .id(generateClientFunctionCallId()) - .build(); + entry.getValue())); + FunctionCall requestConfirmationFunctionCall = + builder.id(generateClientFunctionCallId(builder.build())).build(); longRunningToolIds.add(requestConfirmationFunctionCall.id().get()); parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build()); @@ -680,8 +682,15 @@ public static Optional generateRequestConfirmationEvent( var contentBuilder = Content.builder().parts(parts); functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role); + String deterministicId = + "req-conf-" + + functionResponseEvent.actions().requestedToolConfirmations().keySet().stream() + .sorted() + .collect(java.util.stream.Collectors.joining("-")); + return Optional.of( Event.builder() + .id(deterministicId) .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 44a281f72..f6fe08c2b 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -68,9 +68,12 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The main class for the GenAI Agents runner. */ public class Runner { + private static final Logger logger = LoggerFactory.getLogger(Runner.class); private final BaseAgent agent; private final String appName; private final BaseArtifactService artifactService; @@ -570,19 +573,28 @@ private Flowable runAgentWithUpdatedSession( .agent() .runAsync(contextWithUpdatedSession) .concatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack after deprecating runAsync with Session. - copySessionStates(updatedSession, initialContext.session()); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback(contextWithUpdatedSession, registeredEvent) - .defaultIfEmpty(registeredEvent); - }) - .toFlowable()); + agentEvent -> { + // TODO: remove this hack after deprecating runAsync with Session. + copySessionStates(updatedSession, initialContext.session()); + + // TODO: b/502182243 - Investigate if appendEvent should be made idempotent in + // SessionService to avoid this check. + if (updatedSession.events().stream() + .anyMatch(e -> e.id() != null && e.id().equals(agentEvent.id()))) { + logger.debug("Event {} already in session, skipping append", agentEvent.id()); + return io.reactivex.rxjava3.core.Flowable.just(agentEvent); + } + return this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> { + return contextWithUpdatedSession + .pluginManager() + .onEventCallback(contextWithUpdatedSession, registeredEvent) + .defaultIfEmpty(registeredEvent); + }) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and skip agent Context capturedContext = Context.current(); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index ff75c97b0..5f3c7295d 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -46,9 +46,12 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.GetSessionConfig; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; @@ -80,6 +83,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -588,12 +592,22 @@ public void onToolErrorCallback_error() { @Test public void onEventCallback_success() { when(plugin.onEventCallback(any(), any())) - .thenReturn(Maybe.just(TestUtils.createEvent("form plugin"))); + .thenAnswer( + invocation -> { + Event event = invocation.getArgument(1); + return Maybe.just( + Event.builder() + .id(event.id()) + .invocationId(event.invocationId()) + .author("model") + .content(createContent("from plugin")) + .build()); + }); List events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin"); + assertThat(simplifyEvents(events)).containsExactly("model: from plugin"); verify(plugin).onEventCallback(any(), any()); } @@ -1686,4 +1700,109 @@ public void runner_executesSaveArtifactFlow() { // agent was run assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); } + + @Test + public void runAsync_ensuresSequentialConsistencyForTools() { + // Arrange + TestLlm testLlm = + createTestLlm( + createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), + createTextLlmResponse("Final response")); + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools( + ImmutableList.of( + FunctionTool.create(RaceConditionTools.class, "tool1"), + FunctionTool.create(RaceConditionTools.class, "tool2"))) + .build(); + + BaseSessionService delegate = new InMemorySessionService(); + BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 0); + + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .sessionService(delayedSessionService) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + // Act + var unused = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) + .toList() + .blockingGet(); + + // Assert + ImmutableList requests = ImmutableList.copyOf(testLlm.getRequests()); + assertThat(requests).hasSize(2); + + // Second request should contain the result of tool1 + LlmRequest secondRequest = requests.get(1); + List history = secondRequest.contents(); + + boolean foundToolResponse = + history.stream() + .flatMap(content -> content.parts().stream().flatMap(List::stream)) + .filter(part -> part.functionResponse().isPresent()) + .map(part -> part.functionResponse().get()) + .anyMatch( + response -> + response.name().orElse("").equals("tool1") + && response + .response() + .map( + r -> + java.util.Objects.equals( + r, ImmutableMap.of("result", "result_value1"))) + .orElse(false)); + + assertThat(foundToolResponse).isTrue(); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + private static BaseSessionService createDelayedSessionService( + BaseSessionService delegate, long delayMs) { + BaseSessionService delayedSessionService = mock(BaseSessionService.class); + when(delayedSessionService.createSession(anyString(), anyString(), any(Map.class), anyString())) + .thenAnswer( + inv -> + delegate.createSession( + (String) inv.getArgument(0), + (String) inv.getArgument(1), + (Map) inv.getArgument(2), + (String) inv.getArgument(3))); + when(delayedSessionService.createSession(anyString(), anyString())) + .thenAnswer( + inv -> + delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1))); + when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any())) + .thenAnswer( + inv -> + delegate.getSession( + (String) inv.getArgument(0), + (String) inv.getArgument(1), + (String) inv.getArgument(2), + (Optional) inv.getArgument(3))); + when(delayedSessionService.appendEvent(any(), any())) + .thenAnswer( + inv -> + delegate + .appendEvent(inv.getArgument(0), inv.getArgument(1)) + .delay(delayMs, MILLISECONDS)); + return delayedSessionService; + } + + public static class RaceConditionTools { + private RaceConditionTools() {} + + public static ImmutableMap tool1(String arg) { + return ImmutableMap.of("result", "result_" + arg); + } + + public static ImmutableMap tool2(String input) { + return ImmutableMap.of("status", "received_" + input); + } + } }