Skip to content

Commit 882a4a7

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Refactor AG2Agent and ADK templates to use environment variables for project/location.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#6596 from googleapis:release-please--branches--main b82c8bd PiperOrigin-RevId: 901257696
1 parent 3c55f26 commit 882a4a7

4 files changed

Lines changed: 111 additions & 100 deletions

File tree

vertexai/agent_engines/templates/adk.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,10 +811,12 @@ def set_up(self):
811811
# to disable bound token sharing.
812812
os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = "false"
813813
# --- END BOUND TOKEN PATCH ---
814-
project = self._tmpl_attrs.get("project")
814+
project = os.environ["GOOGLE_CLOUD_PROJECT"] or self._tmpl_attrs.get("project")
815815
if project:
816816
os.environ["GOOGLE_CLOUD_PROJECT"] = project
817-
location = self._tmpl_attrs.get("location")
817+
location = os.environ["GOOGLE_CLOUD_LOCATION"] or self._tmpl_attrs.get("location")
818+
if location and "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
819+
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location
818820
if location:
819821
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
820822
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location

vertexai/agent_engines/templates/ag2.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Sequence,
2424
Union,
2525
)
26+
import os
27+
import copy
2628

2729
if TYPE_CHECKING:
2830
try:
@@ -351,45 +353,49 @@ def __init__(
351353
"instrumentor": None,
352354
"instrumentor_builder": instrumentor_builder,
353355
"enable_tracing": enable_tracing,
356+
"provided_llm_config": copy.deepcopy(llm_config),
357+
"provided_runnable_kwargs": copy.deepcopy(runnable_kwargs),
354358
}
355-
self._tmpl_attrs["llm_config"] = llm_config or {
356-
"config_list": [
357-
{
358-
"project_id": self._tmpl_attrs.get("project"),
359-
"location": self._tmpl_attrs.get("location"),
360-
"model": self._tmpl_attrs.get("model_name"),
361-
"api_type": self._tmpl_attrs.get("api_type"),
362-
}
363-
]
364-
}
365-
self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
366-
runnable_kwargs=runnable_kwargs,
367-
llm_config=self._tmpl_attrs.get("llm_config"),
368-
system_instruction=self._tmpl_attrs.get("system_instruction"),
369-
runnable_name=self._tmpl_attrs.get("runnable_name"),
370-
)
371359
if tools:
372-
# We validate tools at initialization for actionable feedback before
373-
# they are deployed.
374360
_validate_tools(tools)
375361
self._tmpl_attrs["tools"] = tools
376362

377363
def set_up(self):
378364
"""Sets up the agent for execution of queries at runtime.
379365
380366
It initializes the runnable, binds the runnable with tools.
381-
382-
This method should not be called for an object that being passed to
383-
the ReasoningEngine service for deployment, as it initializes clients
384-
that can not be serialized.
367+
Project and Location are sourced from environment variables.
385368
"""
369+
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
370+
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
371+
372+
llm_config = {
373+
"config_list": [
374+
{
375+
"project_id": project,
376+
"location": location,
377+
"model": self._tmpl_attrs.get("model_name"),
378+
"api_type": self._tmpl_attrs.get("api_type"),
379+
}
380+
]
381+
}
382+
if self._tmpl_attrs.get("provided_llm_config"):
383+
llm_config = self._tmpl_attrs.get("provided_llm_config")
384+
385+
runnable_kwargs = _prepare_runnable_kwargs(
386+
runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"),
387+
llm_config=llm_config,
388+
system_instruction=self._tmpl_attrs.get("system_instruction"),
389+
runnable_name=self._tmpl_attrs.get("runnable_name"),
390+
)
391+
386392
if self._tmpl_attrs.get("enable_tracing"):
387393
instrumentor_builder = (
388394
self._tmpl_attrs.get("instrumentor_builder")
389395
or _default_instrumentor_builder
390396
)
391397
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
392-
project_id=self._tmpl_attrs.get("project")
398+
project_id=project,
393399
)
394400

395401
# Set up tools.
@@ -408,21 +414,20 @@ def set_up(self):
408414
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
409415
)
410416
self._tmpl_attrs["runnable"] = runnable_builder(
411-
**self._tmpl_attrs.get("runnable_kwargs")
417+
**runnable_kwargs
412418
)
413419

414420
def clone(self) -> "AG2Agent":
415421
"""Returns a clone of the AG2Agent."""
416-
import copy
417422

418423
return AG2Agent(
419424
model=self._tmpl_attrs.get("model_name"),
420425
api_type=self._tmpl_attrs.get("api_type"),
421-
llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
426+
llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")),
422427
system_instruction=self._tmpl_attrs.get("system_instruction"),
423428
runnable_name=self._tmpl_attrs.get("runnable_name"),
424429
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
425-
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
430+
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")),
426431
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
427432
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
428433
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,8 @@ def set_up(self):
725725
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
726726

727727
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
728-
project = self._tmpl_attrs.get("project")
729-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
730-
location = self._tmpl_attrs.get("location")
728+
project = os.environ["GOOGLE_CLOUD_PROJECT"] or self._tmpl_attrs.get("project")
729+
location = os.environ["GOOGLE_CLOUD_LOCATION"] or self._tmpl_attrs.get("location")
731730
if location:
732731
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
733732
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location

vertexai/preview/reasoning_engines/templates/ag2.py

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Sequence,
2424
Union,
2525
)
26+
import os
27+
import copy
2628

2729
if TYPE_CHECKING:
2830
try:
@@ -250,53 +252,55 @@ def __init__(
250252
"""
251253
from google.cloud.aiplatform import initializer
252254

253-
# Set up llm config.
254-
self._project = initializer.global_config.project
255-
self._location = initializer.global_config.location
256-
self._model_name = model or "gemini-1.0-pro-001"
257-
self._api_type = api_type or "google"
258-
self._llm_config = llm_config or {
259-
"config_list": [
260-
{
261-
"project_id": self._project,
262-
"location": self._location,
263-
"model": self._model_name,
264-
"api_type": self._api_type,
265-
}
266-
]
255+
self._tmpl_attrs: dict[str, Any] = {
256+
"project": initializer.global_config.project,
257+
"location": initializer.global_config.location,
258+
"model_name": model,
259+
"api_type": api_type or "google",
260+
"system_instruction": system_instruction,
261+
"runnable_name": runnable_name,
262+
"tools": [],
263+
"ag2_tool_objects": [],
264+
"runnable": None,
265+
"runnable_builder": runnable_builder,
266+
"instrumentor": None,
267+
"enable_tracing": enable_tracing,
268+
"provided_llm_config": copy.deepcopy(llm_config),
269+
"provided_runnable_kwargs": copy.deepcopy(runnable_kwargs),
267270
}
268-
self._system_instruction = system_instruction
269-
self._runnable_name = runnable_name
270-
self._runnable_kwargs = _prepare_runnable_kwargs(
271-
runnable_kwargs=runnable_kwargs,
272-
llm_config=self._llm_config,
273-
system_instruction=self._system_instruction,
274-
runnable_name=self._runnable_name,
275-
)
276-
277-
self._tools = []
278271
if tools:
279-
# We validate tools at initialization for actionable feedback before
280-
# they are deployed.
281272
_validate_tools(tools)
282-
self._tools = tools
283-
self._ag2_tool_objects = []
284-
self._runnable = None
285-
self._runnable_builder = runnable_builder
286-
287-
self._instrumentor = None
288-
self._enable_tracing = enable_tracing
273+
self._tmpl_attrs["tools"] = tools
289274

290275
def set_up(self):
291276
"""Sets up the agent for execution of queries at runtime.
292277
293278
It initializes the runnable, binds the runnable with tools.
294-
295-
This method should not be called for an object that being passed to
296-
the ReasoningEngine service for deployment, as it initializes clients
297-
that can not be serialized.
279+
Project and Location are sourced from environment variables.
298280
"""
299-
if self._enable_tracing:
281+
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
282+
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
283+
284+
llm_config = {
285+
"config_list": [
286+
{
287+
"project_id": project,
288+
"location": location,
289+
"model": self._tmpl_attrs.get("model_name"),
290+
"api_type": self._tmpl_attrs.get("api_type"),
291+
}
292+
]
293+
}
294+
if self._tmpl_attrs.get("provided_llm_config"):
295+
llm_config = self._tmpl_attrs.get("provided_llm_config")
296+
297+
runnable_kwargs = _prepare_runnable_kwargs(
298+
runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"),
299+
llm_config=llm_config,
300+
system_instruction=self._tmpl_attrs.get("system_instruction"),
301+
runnable_name=self._tmpl_attrs.get("runnable_name"),
302+
)
303+
if self._tmpl_attrs.get("enable_tracing"):
300304
from vertexai.reasoning_engines import _utils
301305

302306
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
@@ -317,9 +321,9 @@ def set_up(self):
317321

318322
credentials, _ = google.auth.default()
319323
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
320-
project_id=self._project,
324+
project_id=project,
321325
client=cloud_trace_v2.TraceServiceClient(
322-
credentials=credentials.with_quota_project(self._project),
326+
credentials=credentials.with_quota_project(project),
323327
),
324328
)
325329
span_processor: SpanProcessor = (
@@ -381,34 +385,35 @@ def set_up(self):
381385
)
382386

383387
# Set up tools.
384-
if self._tools and not self._ag2_tool_objects:
388+
tools = self._tmpl_attrs.get("tools")
389+
ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects")
390+
if tools and not ag2_tool_objects:
385391
from vertexai.reasoning_engines import _utils
386392

387393
autogen_tools = _utils._import_autogen_tools_or_warn()
388394
if autogen_tools:
389-
for tool in self._tools:
390-
self._ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
395+
for tool in tools:
396+
ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
391397

392-
# Set up runnable.
393-
runnable_builder = self._runnable_builder or _default_runnable_builder
394-
self._runnable = runnable_builder(
395-
**self._runnable_kwargs,
398+
runnable_builder = (
399+
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
400+
)
401+
self._tmpl_attrs["runnable"] = runnable_builder(
402+
**runnable_kwargs
396403
)
397404

398405
def clone(self) -> "AG2Agent":
399406
"""Returns a clone of the AG2Agent."""
400-
import copy
401-
402407
return AG2Agent(
403-
model=self._model_name,
404-
api_type=self._api_type,
405-
llm_config=copy.deepcopy(self._llm_config),
406-
system_instruction=self._system_instruction,
407-
runnable_name=self._runnable_name,
408-
tools=copy.deepcopy(self._tools),
409-
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
410-
runnable_builder=self._runnable_builder,
411-
enable_tracing=self._enable_tracing,
408+
model=self._tmpl_attrs.get("model_name"),
409+
api_type=self._tmpl_attrs.get("api_type"),
410+
llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")),
411+
system_instruction=self._tmpl_attrs.get("system_instruction"),
412+
runnable_name=self._tmpl_attrs.get("runnable_name"),
413+
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
414+
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")),
415+
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
416+
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
412417
)
413418

414419
def query(
@@ -456,21 +461,21 @@ def query(
456461
)
457462
kwargs.pop("user_input")
458463

459-
if not self._runnable:
464+
if not self._tmpl_attrs.get("runnable"):
460465
self.set_up()
461466

467+
response = self._tmpl_attrs.get("runnable").run(
468+
message=input,
469+
user_input=False,
470+
tools=self._tmpl_attrs.get("ag2_tool_objects"),
471+
max_turns=max_turns,
472+
**kwargs,
473+
)
474+
462475
from vertexai.reasoning_engines import _utils
463476

464477
# `.run()` will return a `ChatResult` object, which is a dataclass.
465478
# We need to convert it to a JSON-serializable object.
466479
# More details of `ChatResult` can be found in
467480
# https://docs.ag2.ai/docs/api-reference/autogen/ChatResult.
468-
return _utils.dataclass_to_dict(
469-
self._runnable.run(
470-
input,
471-
user_input=False,
472-
tools=self._ag2_tool_objects,
473-
max_turns=max_turns,
474-
**kwargs,
475-
)
476-
)
481+
return _utils.dataclass_to_dict(response)

0 commit comments

Comments
 (0)