|
38 | 38 | import static org.mockito.Mockito.when; |
39 | 39 |
|
40 | 40 | import com.google.adk.agents.BaseAgent; |
| 41 | +import com.google.adk.agents.Callbacks.AfterModelCallback; |
41 | 42 | import com.google.adk.agents.InvocationContext; |
42 | 43 | import com.google.adk.agents.LiveRequestQueue; |
43 | 44 | import com.google.adk.agents.LlmAgent; |
|
47 | 48 | import com.google.adk.artifacts.BaseArtifactService; |
48 | 49 | import com.google.adk.events.Event; |
49 | 50 | import com.google.adk.flows.llmflows.Functions; |
| 51 | +import com.google.adk.models.LlmRequest; |
50 | 52 | import com.google.adk.models.LlmResponse; |
51 | 53 | import com.google.adk.plugins.BasePlugin; |
52 | 54 | import com.google.adk.sessions.BaseSessionService; |
| 55 | +import com.google.adk.sessions.GetSessionConfig; |
| 56 | +import com.google.adk.sessions.InMemorySessionService; |
| 57 | +import com.google.adk.sessions.ListEventsResponse; |
| 58 | +import com.google.adk.sessions.ListSessionsResponse; |
53 | 59 | import com.google.adk.sessions.Session; |
54 | 60 | import com.google.adk.sessions.SessionKey; |
55 | 61 | import com.google.adk.summarizer.EventsCompactionConfig; |
|
85 | 91 | import java.util.Optional; |
86 | 92 | import java.util.UUID; |
87 | 93 | import java.util.concurrent.ConcurrentHashMap; |
| 94 | +import java.util.concurrent.ConcurrentMap; |
88 | 95 | import java.util.concurrent.atomic.AtomicInteger; |
89 | 96 | import java.util.concurrent.atomic.AtomicReference; |
90 | 97 | import org.junit.After; |
@@ -860,6 +867,146 @@ public void runAsync_concurrentCalls_firstFails_secondSucceeds() throws Exceptio |
860 | 867 | subscriber2.assertValue(agentEvent); |
861 | 868 | } |
862 | 869 |
|
| 870 | + /** |
| 871 | + * A slow appendEvent must not let the next LLM step start with a stale session missing the |
| 872 | + * previous step's function-response event. |
| 873 | + */ |
| 874 | + @Test |
| 875 | + public void runAsync_slowAppendEvent_doesNotCauseStaleSessionInNextStep() throws Exception { |
| 876 | + TestLlm raceTestLlm = |
| 877 | + createTestLlm( |
| 878 | + createFunctionCallLlmResponse("call_1", echoTool.name(), ImmutableMap.of("arg", "v1")), |
| 879 | + createTextLlmResponse("done")); |
| 880 | + |
| 881 | + LlmAgent agentForRace = |
| 882 | + createTestAgentBuilder(raceTestLlm).tools(ImmutableList.of(echoTool)).build(); |
| 883 | + |
| 884 | + BaseSessionService delayedSessionService = |
| 885 | + new AppendDelayingSessionService(new InMemorySessionService(), 50); |
| 886 | + |
| 887 | + Runner runnerForRace = |
| 888 | + Runner.builder() |
| 889 | + .app(App.builder().name("test").rootAgent(agentForRace).build()) |
| 890 | + .sessionService(delayedSessionService) |
| 891 | + .build(); |
| 892 | + Session raceSession = |
| 893 | + runnerForRace.sessionService().createSession("test", "user").blockingGet(); |
| 894 | + |
| 895 | + var unused = |
| 896 | + runnerForRace |
| 897 | + .runAsync("user", raceSession.id(), createContent("start")) |
| 898 | + .toList() |
| 899 | + .blockingGet(); |
| 900 | + |
| 901 | + ImmutableList<LlmRequest> requests = ImmutableList.copyOf(raceTestLlm.getRequests()); |
| 902 | + assertThat(requests).hasSize(2); |
| 903 | + |
| 904 | + // Second LLM request must see the function response from step 1. |
| 905 | + boolean foundToolResponse = |
| 906 | + requests.get(1).contents().stream() |
| 907 | + .flatMap(c -> c.parts().stream().flatMap(List::stream)) |
| 908 | + .anyMatch(part -> part.functionResponse().isPresent()); |
| 909 | + assertThat(foundToolResponse).isTrue(); |
| 910 | + } |
| 911 | + |
| 912 | + /** |
| 913 | + * When an LlmAgent transfers control to a sub-LlmAgent, the sub-agent's events flow back up |
| 914 | + * through the parent's {@code BaseLlmFlow.run()} pipeline. Each event must be appended to the |
| 915 | + * session exactly once. |
| 916 | + */ |
| 917 | + @Test |
| 918 | + public void runAsync_transferToSubAgent_eventsAppendedOnce() throws Exception { |
| 919 | + LlmAgent subAgent = |
| 920 | + createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response"))) |
| 921 | + .name("sub-agent") |
| 922 | + .build(); |
| 923 | + |
| 924 | + // Force a transfer to sub-agent using an afterModelCallback. |
| 925 | + AfterModelCallback transferCallback = |
| 926 | + (ctx, response) -> { |
| 927 | + ctx.eventActions().setTransferToAgent(subAgent.name()); |
| 928 | + return Maybe.empty(); |
| 929 | + }; |
| 930 | + |
| 931 | + TestLlm rootTestLlm = createTestLlm(createTextLlmResponse("initial")); |
| 932 | + LlmAgent rootAgent = |
| 933 | + createTestAgentBuilder(rootTestLlm) |
| 934 | + .subAgents(subAgent) |
| 935 | + .afterModelCallback(ImmutableList.of(transferCallback)) |
| 936 | + .build(); |
| 937 | + |
| 938 | + Runner transferRunner = |
| 939 | + Runner.builder().app(App.builder().name("test").rootAgent(rootAgent).build()).build(); |
| 940 | + Session transferSession = |
| 941 | + transferRunner.sessionService().createSession("test", "user").blockingGet(); |
| 942 | + |
| 943 | + var unused = |
| 944 | + transferRunner |
| 945 | + .runAsync("user", transferSession.id(), createContent("start")) |
| 946 | + .toList() |
| 947 | + .blockingGet(); |
| 948 | + |
| 949 | + Session finalSession = |
| 950 | + transferRunner |
| 951 | + .sessionService() |
| 952 | + .getSession( |
| 953 | + transferSession.appName(), |
| 954 | + transferSession.userId(), |
| 955 | + transferSession.id(), |
| 956 | + Optional.empty()) |
| 957 | + .blockingGet(); |
| 958 | + |
| 959 | + // Each event id should appear at most once in the session. |
| 960 | + List<String> eventIds = finalSession.events().stream().map(Event::id).toList(); |
| 961 | + assertThat(eventIds).containsNoDuplicates(); |
| 962 | + } |
| 963 | + |
| 964 | + /** {@link BaseSessionService} that delays {@link #appendEvent} to surface ordering bugs. */ |
| 965 | + private static final class AppendDelayingSessionService implements BaseSessionService { |
| 966 | + private final BaseSessionService delegate; |
| 967 | + private final long appendDelayMs; |
| 968 | + |
| 969 | + AppendDelayingSessionService(BaseSessionService delegate, long appendDelayMs) { |
| 970 | + this.delegate = delegate; |
| 971 | + this.appendDelayMs = appendDelayMs; |
| 972 | + } |
| 973 | + |
| 974 | + // Delegates to the underlying BaseSessionService createSession overload, which is itself |
| 975 | + // deprecated; suppressed because the wrapper must preserve the same signature. |
| 976 | + @SuppressWarnings("deprecation") |
| 977 | + @Override |
| 978 | + public Single<Session> createSession( |
| 979 | + String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) { |
| 980 | + return delegate.createSession(appName, userId, state, sessionId); |
| 981 | + } |
| 982 | + |
| 983 | + @Override |
| 984 | + public Maybe<Session> getSession( |
| 985 | + String appName, String userId, String sessionId, Optional<GetSessionConfig> config) { |
| 986 | + return delegate.getSession(appName, userId, sessionId, config); |
| 987 | + } |
| 988 | + |
| 989 | + @Override |
| 990 | + public Single<ListSessionsResponse> listSessions(String appName, String userId) { |
| 991 | + return delegate.listSessions(appName, userId); |
| 992 | + } |
| 993 | + |
| 994 | + @Override |
| 995 | + public Completable deleteSession(String appName, String userId, String sessionId) { |
| 996 | + return delegate.deleteSession(appName, userId, sessionId); |
| 997 | + } |
| 998 | + |
| 999 | + @Override |
| 1000 | + public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) { |
| 1001 | + return delegate.listEvents(appName, userId, sessionId); |
| 1002 | + } |
| 1003 | + |
| 1004 | + @Override |
| 1005 | + public Single<Event> appendEvent(Session session, Event event) { |
| 1006 | + return delegate.appendEvent(session, event).delay(appendDelayMs, MILLISECONDS); |
| 1007 | + } |
| 1008 | + } |
| 1009 | + |
863 | 1010 | @Test |
864 | 1011 | public void runAsync_withSessionKey_success() { |
865 | 1012 | var events = |
|
0 commit comments