|
46 | 46 | import com.google.adk.artifacts.BaseArtifactService; |
47 | 47 | import com.google.adk.events.Event; |
48 | 48 | import com.google.adk.flows.llmflows.Functions; |
| 49 | +import com.google.adk.models.LlmRequest; |
49 | 50 | import com.google.adk.models.LlmResponse; |
50 | 51 | import com.google.adk.plugins.BasePlugin; |
51 | 52 | import com.google.adk.sessions.BaseSessionService; |
| 53 | +import com.google.adk.sessions.GetSessionConfig; |
| 54 | +import com.google.adk.sessions.InMemorySessionService; |
52 | 55 | import com.google.adk.sessions.Session; |
53 | 56 | import com.google.adk.sessions.SessionKey; |
54 | 57 | import com.google.adk.summarizer.EventsCompactionConfig; |
@@ -588,12 +591,22 @@ public void onToolErrorCallback_error() { |
588 | 591 | @Test |
589 | 592 | public void onEventCallback_success() { |
590 | 593 | when(plugin.onEventCallback(any(), any())) |
591 | | - .thenReturn(Maybe.just(TestUtils.createEvent("form plugin"))); |
| 594 | + .thenAnswer( |
| 595 | + invocation -> { |
| 596 | + Event event = invocation.getArgument(1); |
| 597 | + return Maybe.just( |
| 598 | + Event.builder() |
| 599 | + .id(event.id()) |
| 600 | + .invocationId(event.invocationId()) |
| 601 | + .author("model") |
| 602 | + .content(createContent("from plugin")) |
| 603 | + .build()); |
| 604 | + }); |
592 | 605 |
|
593 | 606 | List<Event> events = |
594 | 607 | runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); |
595 | 608 |
|
596 | | - assertThat(simplifyEvents(events)).containsExactly("author: content for event form plugin"); |
| 609 | + assertThat(simplifyEvents(events)).containsExactly("model: from plugin"); |
597 | 610 |
|
598 | 611 | verify(plugin).onEventCallback(any(), any()); |
599 | 612 | } |
@@ -1686,4 +1699,105 @@ public void runner_executesSaveArtifactFlow() { |
1686 | 1699 | // agent was run |
1687 | 1700 | assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); |
1688 | 1701 | } |
| 1702 | + |
| 1703 | + @Test |
| 1704 | + public void runAsync_ensuresSequentialConsistencyForTools() { |
| 1705 | + // Arrange |
| 1706 | + TestLlm testLlm = |
| 1707 | + createTestLlm( |
| 1708 | + createFunctionCallLlmResponse("call_1", "tool1", ImmutableMap.of("arg", "value1")), |
| 1709 | + createTextLlmResponse("Final response")); |
| 1710 | + |
| 1711 | + LlmAgent agent = |
| 1712 | + createTestAgentBuilder(testLlm) |
| 1713 | + .tools( |
| 1714 | + ImmutableList.of( |
| 1715 | + FunctionTool.create(RaceConditionTools.class, "tool1"), |
| 1716 | + FunctionTool.create(RaceConditionTools.class, "tool2"))) |
| 1717 | + .build(); |
| 1718 | + |
| 1719 | + BaseSessionService delegate = new InMemorySessionService(); |
| 1720 | + BaseSessionService delayedSessionService = createDelayedSessionService(delegate, 100); |
| 1721 | + |
| 1722 | + Runner runner = |
| 1723 | + Runner.builder() |
| 1724 | + .app(App.builder().name("test").rootAgent(agent).build()) |
| 1725 | + .sessionService(delayedSessionService) |
| 1726 | + .build(); |
| 1727 | + Session session = runner.sessionService().createSession("test", "user").blockingGet(); |
| 1728 | + |
| 1729 | + // Act |
| 1730 | + var unused = |
| 1731 | + runner |
| 1732 | + .runAsync("user", session.id(), Content.fromParts(Part.fromText("start"))) |
| 1733 | + .toList() |
| 1734 | + .blockingGet(); |
| 1735 | + |
| 1736 | + // Assert |
| 1737 | + ImmutableList<LlmRequest> requests = ImmutableList.copyOf(testLlm.getRequests()); |
| 1738 | + assertThat(requests).hasSize(2); |
| 1739 | + |
| 1740 | + // Second request should contain the result of tool1 |
| 1741 | + LlmRequest secondRequest = requests.get(1); |
| 1742 | + List<Content> history = secondRequest.contents(); |
| 1743 | + |
| 1744 | + boolean foundToolResponse = |
| 1745 | + history.stream() |
| 1746 | + .flatMap(content -> content.parts().stream().flatMap(List::stream)) |
| 1747 | + .filter(part -> part.functionResponse().isPresent()) |
| 1748 | + .map(part -> part.functionResponse().get()) |
| 1749 | + .anyMatch( |
| 1750 | + response -> |
| 1751 | + response.name().orElse("").equals("tool1") |
| 1752 | + && response |
| 1753 | + .response() |
| 1754 | + .orElse(null) |
| 1755 | + .equals(ImmutableMap.of("result", "result_value1"))); |
| 1756 | + |
| 1757 | + assertThat(foundToolResponse).isTrue(); |
| 1758 | + } |
| 1759 | + |
| 1760 | + private static BaseSessionService createDelayedSessionService( |
| 1761 | + BaseSessionService delegate, long delayMs) { |
| 1762 | + BaseSessionService delayedSessionService = mock(BaseSessionService.class); |
| 1763 | + when(delayedSessionService.createSession(anyString(), anyString(), any(), anyString())) |
| 1764 | + .thenAnswer( |
| 1765 | + inv -> |
| 1766 | + delegate.createSession( |
| 1767 | + inv.getArgument(0), |
| 1768 | + inv.getArgument(1), |
| 1769 | + inv.getArgument(2), |
| 1770 | + inv.getArgument(3))); |
| 1771 | + when(delayedSessionService.createSession(anyString(), anyString())) |
| 1772 | + .thenAnswer( |
| 1773 | + inv -> |
| 1774 | + delegate.createSession((String) inv.getArgument(0), (String) inv.getArgument(1))); |
| 1775 | + when(delayedSessionService.getSession(anyString(), anyString(), anyString(), any())) |
| 1776 | + .thenAnswer( |
| 1777 | + inv -> |
| 1778 | + delegate.getSession( |
| 1779 | + (String) inv.getArgument(0), |
| 1780 | + (String) inv.getArgument(1), |
| 1781 | + (String) inv.getArgument(2), |
| 1782 | + (Optional<GetSessionConfig>) inv.getArgument(3))); |
| 1783 | + when(delayedSessionService.appendEvent(any(), any())) |
| 1784 | + .thenAnswer( |
| 1785 | + inv -> |
| 1786 | + delegate |
| 1787 | + .appendEvent(inv.getArgument(0), inv.getArgument(1)) |
| 1788 | + .delay(delayMs, MILLISECONDS)); |
| 1789 | + return delayedSessionService; |
| 1790 | + } |
| 1791 | + |
| 1792 | + public static class RaceConditionTools { |
| 1793 | + private RaceConditionTools() {} |
| 1794 | + |
| 1795 | + public static ImmutableMap<String, Object> tool1(String arg) { |
| 1796 | + return ImmutableMap.of("result", "result_" + arg); |
| 1797 | + } |
| 1798 | + |
| 1799 | + public static ImmutableMap<String, Object> tool2(String input) { |
| 1800 | + return ImmutableMap.of("status", "received_" + input); |
| 1801 | + } |
| 1802 | + } |
1689 | 1803 | } |
0 commit comments