2323 Sequence ,
2424 Union ,
2525)
26+ import os
27+ import copy
2628
2729if TYPE_CHECKING :
2830 try :
@@ -250,53 +252,55 @@ def __init__(
250252 """
251253 from google .cloud .aiplatform import initializer
252254
253- # Set up llm config.
254- self ._project = initializer .global_config .project
255- self ._location = initializer .global_config .location
256- self ._model_name = model or "gemini-1.0-pro-001"
257- self ._api_type = api_type or "google"
258- self ._llm_config = llm_config or {
259- "config_list" : [
260- {
261- "project_id" : self ._project ,
262- "location" : self ._location ,
263- "model" : self ._model_name ,
264- "api_type" : self ._api_type ,
265- }
266- ]
255+ self ._tmpl_attrs : dict [str , Any ] = {
256+ "project" : initializer .global_config .project ,
257+ "location" : initializer .global_config .location ,
258+ "model_name" : model ,
259+ "api_type" : api_type or "google" ,
260+ "system_instruction" : system_instruction ,
261+ "runnable_name" : runnable_name ,
262+ "tools" : [],
263+ "ag2_tool_objects" : [],
264+ "runnable" : None ,
265+ "runnable_builder" : runnable_builder ,
266+ "instrumentor" : None ,
267+ "enable_tracing" : enable_tracing ,
268+ "provided_llm_config" : copy .deepcopy (llm_config ),
269+ "provided_runnable_kwargs" : copy .deepcopy (runnable_kwargs ),
267270 }
268- self ._system_instruction = system_instruction
269- self ._runnable_name = runnable_name
270- self ._runnable_kwargs = _prepare_runnable_kwargs (
271- runnable_kwargs = runnable_kwargs ,
272- llm_config = self ._llm_config ,
273- system_instruction = self ._system_instruction ,
274- runnable_name = self ._runnable_name ,
275- )
276-
277- self ._tools = []
278271 if tools :
279- # We validate tools at initialization for actionable feedback before
280- # they are deployed.
281272 _validate_tools (tools )
282- self ._tools = tools
283- self ._ag2_tool_objects = []
284- self ._runnable = None
285- self ._runnable_builder = runnable_builder
286-
287- self ._instrumentor = None
288- self ._enable_tracing = enable_tracing
273+ self ._tmpl_attrs ["tools" ] = tools
289274
290275 def set_up (self ):
291276 """Sets up the agent for execution of queries at runtime.
292277
293278 It initializes the runnable, binds the runnable with tools.
294-
295- This method should not be called for an object that being passed to
296- the ReasoningEngine service for deployment, as it initializes clients
297- that can not be serialized.
279+ Project and Location are sourced from environment variables.
298280 """
299- if self ._enable_tracing :
281+ project = os .environ .get ("GOOGLE_CLOUD_PROJECT" ) or self ._tmpl_attrs .get ("project" )
282+ location = os .environ .get ("GOOGLE_CLOUD_LOCATION" ) or self ._tmpl_attrs .get ("location" )
283+
284+ llm_config = {
285+ "config_list" : [
286+ {
287+ "project_id" : project ,
288+ "location" : location ,
289+ "model" : self ._tmpl_attrs .get ("model_name" ),
290+ "api_type" : self ._tmpl_attrs .get ("api_type" ),
291+ }
292+ ]
293+ }
294+ if self ._tmpl_attrs .get ("provided_llm_config" ):
295+ llm_config = self ._tmpl_attrs .get ("provided_llm_config" )
296+
297+ runnable_kwargs = _prepare_runnable_kwargs (
298+ runnable_kwargs = self ._tmpl_attrs .get ("provided_runnable_kwargs" ),
299+ llm_config = llm_config ,
300+ system_instruction = self ._tmpl_attrs .get ("system_instruction" ),
301+ runnable_name = self ._tmpl_attrs .get ("runnable_name" ),
302+ )
303+ if self ._tmpl_attrs .get ("enable_tracing" ):
300304 from vertexai .reasoning_engines import _utils
301305
302306 cloud_trace_exporter = _utils ._import_cloud_trace_exporter_or_warn ()
@@ -317,9 +321,9 @@ def set_up(self):
317321
318322 credentials , _ = google .auth .default ()
319323 span_exporter = cloud_trace_exporter .CloudTraceSpanExporter (
320- project_id = self . _project ,
324+ project_id = project ,
321325 client = cloud_trace_v2 .TraceServiceClient (
322- credentials = credentials .with_quota_project (self . _project ),
326+ credentials = credentials .with_quota_project (project ),
323327 ),
324328 )
325329 span_processor : SpanProcessor = (
@@ -381,34 +385,35 @@ def set_up(self):
381385 )
382386
383387 # Set up tools.
384- if self ._tools and not self ._ag2_tool_objects :
388+ tools = self ._tmpl_attrs .get ("tools" )
389+ ag2_tool_objects = self ._tmpl_attrs .get ("ag2_tool_objects" )
390+ if tools and not ag2_tool_objects :
385391 from vertexai .reasoning_engines import _utils
386392
387393 autogen_tools = _utils ._import_autogen_tools_or_warn ()
388394 if autogen_tools :
389- for tool in self . _tools :
390- self . _ag2_tool_objects .append (autogen_tools .Tool (func_or_tool = tool ))
395+ for tool in tools :
396+ ag2_tool_objects .append (autogen_tools .Tool (func_or_tool = tool ))
391397
392- # Set up runnable.
393- runnable_builder = self ._runnable_builder or _default_runnable_builder
394- self ._runnable = runnable_builder (
395- ** self ._runnable_kwargs ,
398+ runnable_builder = (
399+ self ._tmpl_attrs .get ("runnable_builder" ) or _default_runnable_builder
400+ )
401+ self ._tmpl_attrs ["runnable" ] = runnable_builder (
402+ ** runnable_kwargs
396403 )
397404
398405 def clone (self ) -> "AG2Agent" :
399406 """Returns a clone of the AG2Agent."""
400- import copy
401-
402407 return AG2Agent (
403- model = self ._model_name ,
404- api_type = self ._api_type ,
405- llm_config = copy .deepcopy (self ._llm_config ),
406- system_instruction = self ._system_instruction ,
407- runnable_name = self ._runnable_name ,
408- tools = copy .deepcopy (self ._tools ),
409- runnable_kwargs = copy .deepcopy (self ._runnable_kwargs ),
410- runnable_builder = self ._runnable_builder ,
411- enable_tracing = self ._enable_tracing ,
408+ model = self ._tmpl_attrs . get ( "model_name" ) ,
409+ api_type = self ._tmpl_attrs . get ( "api_type" ) ,
410+ llm_config = copy .deepcopy (self ._tmpl_attrs . get ( "provided_llm_config" ) ),
411+ system_instruction = self ._tmpl_attrs . get ( "system_instruction" ) ,
412+ runnable_name = self ._tmpl_attrs . get ( "runnable_name" ) ,
413+ tools = copy .deepcopy (self ._tmpl_attrs . get ( "tools" ) ),
414+ runnable_kwargs = copy .deepcopy (self ._tmpl_attrs . get ( "provided_runnable_kwargs" ) ),
415+ runnable_builder = self ._tmpl_attrs . get ( "runnable_builder" ) ,
416+ enable_tracing = self ._tmpl_attrs . get ( "enable_tracing" ) ,
412417 )
413418
414419 def query (
@@ -456,21 +461,21 @@ def query(
456461 )
457462 kwargs .pop ("user_input" )
458463
459- if not self ._runnable :
464+ if not self ._tmpl_attrs . get ( "runnable" ) :
460465 self .set_up ()
461466
467+ response = self ._tmpl_attrs .get ("runnable" ).run (
468+ message = input ,
469+ user_input = False ,
470+ tools = self ._tmpl_attrs .get ("ag2_tool_objects" ),
471+ max_turns = max_turns ,
472+ ** kwargs ,
473+ )
474+
462475 from vertexai .reasoning_engines import _utils
463476
464477 # `.run()` will return a `ChatResult` object, which is a dataclass.
465478 # We need to convert it to a JSON-serializable object.
466479 # More details of `ChatResult` can be found in
467480 # https://docs.ag2.ai/docs/api-reference/autogen/ChatResult.
468- return _utils .dataclass_to_dict (
469- self ._runnable .run (
470- input ,
471- user_input = False ,
472- tools = self ._ag2_tool_objects ,
473- max_turns = max_turns ,
474- ** kwargs ,
475- )
476- )
481+ return _utils .dataclass_to_dict (response )
0 commit comments