Skip to content

Commit 02a08a1

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Support plugins in Java AgentTool similar to Python's implementation
PiperOrigin-RevId: 904686960
1 parent 4009905 commit 02a08a1

2 files changed

Lines changed: 189 additions & 6 deletions

File tree

core/src/main/java/com/google/adk/tools/AgentTool.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2727
import com.google.adk.agents.LlmAgent;
2828
import com.google.adk.events.Event;
29+
import com.google.adk.plugins.Plugin;
2930
import com.google.adk.runner.InMemoryRunner;
3031
import com.google.adk.runner.Runner;
3132
import com.google.adk.sessions.State;
@@ -46,6 +47,7 @@ public class AgentTool extends BaseTool {
4647

4748
private final BaseAgent agent;
4849
private final boolean skipSummarization;
50+
private final boolean includePlugins;
4951

5052
public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
5153
throws ConfigurationException {
@@ -62,21 +64,34 @@ public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
6264
}
6365

6466
BaseAgent agent = resolvedAgents.get(0);
65-
return AgentTool.create(agent, args.getOrDefault("skipSummarization", false).booleanValue());
67+
return AgentTool.create(
68+
agent,
69+
args.getOrDefault("skipSummarization", false).booleanValue(),
70+
args.getOrDefault("includePlugins", false).booleanValue());
71+
}
72+
73+
public static AgentTool create(
74+
BaseAgent agent, boolean skipSummarization, boolean includePlugins) {
75+
return new AgentTool(agent, skipSummarization, includePlugins);
6676
}
6777

6878
public static AgentTool create(BaseAgent agent, boolean skipSummarization) {
69-
return new AgentTool(agent, skipSummarization);
79+
return new AgentTool(agent, skipSummarization, /* includePlugins= */ false);
7080
}
7181

7282
public static AgentTool create(BaseAgent agent) {
73-
return new AgentTool(agent, false);
83+
return new AgentTool(agent, /* skipSummarization= */ false, /* includePlugins= */ false);
7484
}
7585

7686
protected AgentTool(BaseAgent agent, boolean skipSummarization) {
87+
this(agent, skipSummarization, /* includePlugins= */ false);
88+
}
89+
90+
protected AgentTool(BaseAgent agent, boolean skipSummarization, boolean includePlugins) {
7791
super(agent.name(), agent.description());
7892
this.agent = agent;
7993
this.skipSummarization = skipSummarization;
94+
this.includePlugins = includePlugins;
8095
}
8196

8297
@VisibleForTesting
@@ -159,9 +174,11 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
159174
content = Content.fromParts(Part.fromText(input.toString()));
160175
}
161176

162-
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName());
163-
// Session state is final, can't update to toolContext state
164-
// session.toBuilder().setState(toolContext.getState());
177+
ImmutableList<Plugin> plugins =
178+
this.includePlugins
179+
? ImmutableList.of(toolContext.invocationContext().pluginManager())
180+
: ImmutableList.of();
181+
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName(), plugins);
165182
return runner
166183
.sessionService()
167184
.createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null)

core/src/test/java/com/google/adk/tools/AgentToolTest.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import com.google.adk.agents.LlmAgent;
2929
import com.google.adk.agents.SequentialAgent;
3030
import com.google.adk.models.LlmResponse;
31+
import com.google.adk.plugins.Plugin;
32+
import com.google.adk.plugins.PluginManager;
3133
import com.google.adk.sessions.InMemorySessionService;
3234
import com.google.adk.sessions.Session;
3335
import com.google.adk.testing.TestLlm;
@@ -41,6 +43,7 @@
4143
import io.reactivex.rxjava3.core.Flowable;
4244
import io.reactivex.rxjava3.core.Maybe;
4345
import java.util.Map;
46+
import java.util.concurrent.atomic.AtomicBoolean;
4447
import org.junit.Before;
4548
import org.junit.Test;
4649
import org.junit.runner.RunWith;
@@ -704,6 +707,169 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() {
704707
.build());
705708
}
706709

710+
@Test
711+
public void call_withIncludePluginsTrue_propagatesPlugins() throws Exception {
712+
AtomicBoolean callbackCalled = new AtomicBoolean(false);
713+
Plugin mockPlugin =
714+
new Plugin() {
715+
@Override
716+
public String getName() {
717+
return "mock_plugin";
718+
}
719+
720+
@Override
721+
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
722+
callbackCalled.set(true);
723+
return Maybe.empty();
724+
}
725+
};
726+
LlmAgent testAgent =
727+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
728+
.name("agent_name")
729+
.description("agent description")
730+
.build();
731+
AgentTool agentTool =
732+
AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ true);
733+
Session session =
734+
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
735+
InvocationContext invocationContext =
736+
InvocationContext.builder()
737+
.invocationId(InvocationContext.newInvocationContextId())
738+
.agent(testAgent)
739+
.session(session)
740+
.sessionService(sessionService)
741+
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
742+
.build();
743+
ToolContext toolContext = ToolContext.builder(invocationContext).build();
744+
745+
Map<String, Object> unused =
746+
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();
747+
748+
assertThat(callbackCalled.get()).isTrue();
749+
}
750+
751+
@Test
752+
public void call_withIncludePluginsFalse_doesNotPropagatePlugins() throws Exception {
753+
AtomicBoolean callbackCalled = new AtomicBoolean(false);
754+
Plugin mockPlugin =
755+
new Plugin() {
756+
@Override
757+
public String getName() {
758+
return "mock_plugin";
759+
}
760+
761+
@Override
762+
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
763+
callbackCalled.set(true);
764+
return Maybe.empty();
765+
}
766+
};
767+
LlmAgent testAgent =
768+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
769+
.name("agent_name")
770+
.description("agent description")
771+
.build();
772+
AgentTool agentTool =
773+
AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ false);
774+
Session session =
775+
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
776+
InvocationContext invocationContext =
777+
InvocationContext.builder()
778+
.invocationId(InvocationContext.newInvocationContextId())
779+
.agent(testAgent)
780+
.session(session)
781+
.sessionService(sessionService)
782+
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
783+
.build();
784+
ToolContext toolContext = ToolContext.builder(invocationContext).build();
785+
786+
Map<String, Object> unused =
787+
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();
788+
789+
assertThat(callbackCalled.get()).isFalse();
790+
}
791+
792+
@Test
793+
public void call_createWithAgentOnly_defaultsIncludePluginsToFalse() throws Exception {
794+
AtomicBoolean callbackCalled = new AtomicBoolean(false);
795+
Plugin mockPlugin =
796+
new Plugin() {
797+
@Override
798+
public String getName() {
799+
return "mock_plugin";
800+
}
801+
802+
@Override
803+
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
804+
callbackCalled.set(true);
805+
return Maybe.empty();
806+
}
807+
};
808+
LlmAgent testAgent =
809+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
810+
.name("agent_name")
811+
.description("agent description")
812+
.build();
813+
AgentTool agentTool = AgentTool.create(testAgent);
814+
Session session =
815+
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
816+
InvocationContext invocationContext =
817+
InvocationContext.builder()
818+
.invocationId(InvocationContext.newInvocationContextId())
819+
.agent(testAgent)
820+
.session(session)
821+
.sessionService(sessionService)
822+
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
823+
.build();
824+
ToolContext toolContext = ToolContext.builder(invocationContext).build();
825+
826+
Map<String, Object> unused =
827+
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();
828+
829+
assertThat(callbackCalled.get()).isFalse();
830+
}
831+
832+
@Test
833+
public void call_createWithAgentAndSkipSummarization_defaultsIncludePluginsToFalse()
834+
throws Exception {
835+
AtomicBoolean callbackCalled = new AtomicBoolean(false);
836+
Plugin mockPlugin =
837+
new Plugin() {
838+
@Override
839+
public String getName() {
840+
return "mock_plugin";
841+
}
842+
843+
@Override
844+
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
845+
callbackCalled.set(true);
846+
return Maybe.empty();
847+
}
848+
};
849+
LlmAgent testAgent =
850+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
851+
.name("agent_name")
852+
.description("agent description")
853+
.build();
854+
AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true);
855+
Session session =
856+
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
857+
InvocationContext invocationContext =
858+
InvocationContext.builder()
859+
.invocationId(InvocationContext.newInvocationContextId())
860+
.agent(testAgent)
861+
.session(session)
862+
.sessionService(sessionService)
863+
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
864+
.build();
865+
ToolContext toolContext = ToolContext.builder(invocationContext).build();
866+
867+
Map<String, Object> unused =
868+
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();
869+
870+
assertThat(callbackCalled.get()).isFalse();
871+
}
872+
707873
private ToolContext createToolContext(BaseAgent agent) {
708874
Session session =
709875
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();

0 commit comments

Comments
 (0)