Skip to content

Commit bea1638

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Refactor AG2Agent and ADK templates to use environment variables for project/location.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#6596 from googleapis:release-please--branches--main b82c8bd PiperOrigin-RevId: 901257696
1 parent 3c55f26 commit bea1638

5 files changed

Lines changed: 131 additions & 122 deletions

File tree

tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,11 @@ def test_initialization(self):
150150
agent = reasoning_engines.AG2Agent(
151151
model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME
152152
)
153-
assert agent._model_name == _TEST_MODEL
154-
assert agent._runnable_name == _TEST_RUNNABLE_NAME
155-
assert agent._project == _TEST_PROJECT
156-
assert agent._location == _TEST_LOCATION
157-
assert agent._runnable is None
153+
assert agent._tmpl_attrs["model_name"] == _TEST_MODEL
154+
assert agent._tmpl_attrs["runnable_name"] == _TEST_RUNNABLE_NAME
155+
assert agent._tmpl_attrs["project"] == _TEST_PROJECT
156+
assert agent._tmpl_attrs["location"] == _TEST_LOCATION
157+
assert agent._tmpl_attrs["runnable"] is None
158158

159159
def test_initialization_with_tools(self, autogen_tools_mock):
160160
tools = [
@@ -168,22 +168,22 @@ def test_initialization_with_tools(self, autogen_tools_mock):
168168
tools=tools,
169169
runnable_builder=lambda **kwargs: kwargs,
170170
)
171-
assert agent._runnable is None
172-
assert agent._tools
173-
assert not agent._ag2_tool_objects
171+
assert agent._tmpl_attrs["runnable"] is None
172+
assert agent._tmpl_attrs["tools"]
173+
assert not agent._tmpl_attrs["ag2_tool_objects"]
174174
agent.set_up()
175-
assert agent._runnable is not None
176-
assert agent._ag2_tool_objects
175+
assert agent._tmpl_attrs["runnable"] is not None
176+
assert agent._tmpl_attrs["ag2_tool_objects"]
177177

178178
def test_set_up(self):
179179
agent = reasoning_engines.AG2Agent(
180180
model=_TEST_MODEL,
181181
runnable_name=_TEST_RUNNABLE_NAME,
182182
runnable_builder=lambda **kwargs: kwargs,
183183
)
184-
assert agent._runnable is None
184+
assert agent._tmpl_attrs["runnable"] is None
185185
agent.set_up()
186-
assert agent._runnable is not None
186+
assert agent._tmpl_attrs["runnable"] is not None
187187

188188
def test_clone(self):
189189
agent = reasoning_engines.AG2Agent(
@@ -192,26 +192,26 @@ def test_clone(self):
192192
runnable_builder=lambda **kwargs: kwargs,
193193
)
194194
agent.set_up()
195-
assert agent._runnable is not None
195+
assert agent._tmpl_attrs["runnable"] is not None
196196
agent_clone = agent.clone()
197-
assert agent._runnable is not None
198-
assert agent_clone._runnable is None
197+
assert agent._tmpl_attrs["runnable"] is not None
198+
assert agent_clone._tmpl_attrs["runnable"] is None
199199
agent_clone.set_up()
200-
assert agent_clone._runnable is not None
200+
assert agent_clone._tmpl_attrs["runnable"] is not None
201201

202202
def test_query(self, dataclasses_asdict_mock):
203203
agent = reasoning_engines.AG2Agent(
204204
model=_TEST_MODEL,
205205
runnable_name=_TEST_RUNNABLE_NAME,
206206
)
207-
agent._runnable = mock.Mock()
207+
agent._tmpl_attrs["runnable"] = mock.Mock()
208208
mocks = mock.Mock()
209-
mocks.attach_mock(mock=agent._runnable, attribute="run")
209+
mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="run")
210210
agent.query(input="test query")
211211
mocks.assert_has_calls(
212212
[
213213
mock.call.run.run(
214-
{"content": "test query"},
214+
message={"content": "test query"},
215215
user_input=False,
216216
tools=[],
217217
max_turns=None,
@@ -233,10 +233,10 @@ def test_enable_tracing(
233233
runnable_name=_TEST_RUNNABLE_NAME,
234234
enable_tracing=True,
235235
)
236-
assert agent._instrumentor is None
236+
assert agent._tmpl_attrs["instrumentor"] is None
237237
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
238238
# agent.set_up()
239-
# assert agent._instrumentor is not None
239+
# assert agent._tmpl_attrs["instrumentor"] is not None
240240
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
241241

242242
@pytest.mark.usefixtures("caplog")
@@ -246,7 +246,7 @@ def test_enable_tracing_warning(self, caplog, autogen_instrumentor_none_mock):
246246
runnable_name=_TEST_RUNNABLE_NAME,
247247
enable_tracing=True,
248248
)
249-
assert agent._instrumentor is None
249+
assert agent._tmpl_attrs["instrumentor"] is None
250250
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
251251
# agent.set_up()
252252
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text

vertexai/agent_engines/templates/adk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,10 +811,10 @@ def set_up(self):
811811
# to disable bound token sharing.
812812
os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = "false"
813813
# --- END BOUND TOKEN PATCH ---
814-
project = self._tmpl_attrs.get("project")
814+
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
815815
if project:
816816
os.environ["GOOGLE_CLOUD_PROJECT"] = project
817-
location = self._tmpl_attrs.get("location")
817+
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
818818
if location:
819819
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
820820
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location

vertexai/agent_engines/templates/ag2.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Sequence,
2424
Union,
2525
)
26+
import os
27+
import copy
2628

2729
if TYPE_CHECKING:
2830
try:
@@ -351,45 +353,49 @@ def __init__(
351353
"instrumentor": None,
352354
"instrumentor_builder": instrumentor_builder,
353355
"enable_tracing": enable_tracing,
356+
"provided_llm_config": copy.deepcopy(llm_config),
357+
"provided_runnable_kwargs": copy.deepcopy(runnable_kwargs),
354358
}
355-
self._tmpl_attrs["llm_config"] = llm_config or {
356-
"config_list": [
357-
{
358-
"project_id": self._tmpl_attrs.get("project"),
359-
"location": self._tmpl_attrs.get("location"),
360-
"model": self._tmpl_attrs.get("model_name"),
361-
"api_type": self._tmpl_attrs.get("api_type"),
362-
}
363-
]
364-
}
365-
self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
366-
runnable_kwargs=runnable_kwargs,
367-
llm_config=self._tmpl_attrs.get("llm_config"),
368-
system_instruction=self._tmpl_attrs.get("system_instruction"),
369-
runnable_name=self._tmpl_attrs.get("runnable_name"),
370-
)
371359
if tools:
372-
# We validate tools at initialization for actionable feedback before
373-
# they are deployed.
374360
_validate_tools(tools)
375361
self._tmpl_attrs["tools"] = tools
376362

377363
def set_up(self):
378364
"""Sets up the agent for execution of queries at runtime.
379365
380366
It initializes the runnable, binds the runnable with tools.
381-
382-
This method should not be called for an object that being passed to
383-
the ReasoningEngine service for deployment, as it initializes clients
384-
that can not be serialized.
367+
Project and Location are sourced from environment variables.
385368
"""
369+
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
370+
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
371+
372+
llm_config = {
373+
"config_list": [
374+
{
375+
"project_id": project,
376+
"location": location,
377+
"model": self._tmpl_attrs.get("model_name"),
378+
"api_type": self._tmpl_attrs.get("api_type"),
379+
}
380+
]
381+
}
382+
if self._tmpl_attrs.get("provided_llm_config"):
383+
llm_config = self._tmpl_attrs.get("provided_llm_config")
384+
385+
runnable_kwargs = _prepare_runnable_kwargs(
386+
runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"),
387+
llm_config=llm_config,
388+
system_instruction=self._tmpl_attrs.get("system_instruction"),
389+
runnable_name=self._tmpl_attrs.get("runnable_name"),
390+
)
391+
386392
if self._tmpl_attrs.get("enable_tracing"):
387393
instrumentor_builder = (
388394
self._tmpl_attrs.get("instrumentor_builder")
389395
or _default_instrumentor_builder
390396
)
391397
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
392-
project_id=self._tmpl_attrs.get("project")
398+
project_id=project,
393399
)
394400

395401
# Set up tools.
@@ -408,21 +414,20 @@ def set_up(self):
408414
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
409415
)
410416
self._tmpl_attrs["runnable"] = runnable_builder(
411-
**self._tmpl_attrs.get("runnable_kwargs")
417+
**runnable_kwargs
412418
)
413419

414420
def clone(self) -> "AG2Agent":
415421
"""Returns a clone of the AG2Agent."""
416-
import copy
417422

418423
return AG2Agent(
419424
model=self._tmpl_attrs.get("model_name"),
420425
api_type=self._tmpl_attrs.get("api_type"),
421-
llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
426+
llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")),
422427
system_instruction=self._tmpl_attrs.get("system_instruction"),
423428
runnable_name=self._tmpl_attrs.get("runnable_name"),
424429
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
425-
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
430+
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")),
426431
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
427432
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
428433
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,8 @@ def set_up(self):
725725
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
726726

727727
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
728-
project = self._tmpl_attrs.get("project")
729-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
730-
location = self._tmpl_attrs.get("location")
728+
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
729+
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
731730
if location:
732731
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
733732
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location

0 commit comments

Comments
 (0)