@@ -87,7 +87,8 @@ def create_agent_card(
8787 provided.
8888 """
8989 # pylint: disable=g-import-not-at-top
90- from a2a .types import AgentCard , AgentCapabilities , TransportProtocol
90+ from a2a .types import AgentCard , AgentCapabilities , AgentInterface
91+ from a2a .utils .constants import TransportProtocol , PROTOCOL_VERSION_CURRENT
9192
9293 # Check if a dictionary was provided.
9394 if agent_card :
@@ -98,14 +99,20 @@ def create_agent_card(
9899 return AgentCard (
99100 name = agent_name ,
100101 description = description ,
101- url = "http://localhost:9999/" ,
102102 version = "1.0.0" ,
103103 default_input_modes = default_input_modes or ["text/plain" ],
104104 default_output_modes = default_output_modes or ["application/json" ],
105- capabilities = AgentCapabilities (streaming = streaming ),
105+ capabilities = AgentCapabilities (
106+ streaming = streaming , extended_agent_card = True
107+ ),
106108 skills = skills ,
107- preferred_transport = TransportProtocol .http_json , # Http Only.
108- supports_authenticated_extended_card = True ,
109+ supported_interfaces = [
110+ AgentInterface (
111+ url = "http://localhost:9999/" ,
112+ protocol_binding = TransportProtocol .HTTP_JSON ,
113+ protocol_version = PROTOCOL_VERSION_CURRENT ,
114+ )
115+ ],
109116 )
110117
111118 # Raise an error if insufficient data is provided.
@@ -162,6 +169,21 @@ async def cancel(
162169 )
163170
164171
172+ def _is_version_enabled (agent_card : "AgentCard" , version : str ) -> bool :
173+ """Checks if a specific version compatibility should be enabled for the A2aAgent."""
174+ from a2a .utils .constants import TransportProtocol
175+
176+ if not agent_card .supported_interfaces :
177+ return False
178+ for interface in agent_card .supported_interfaces :
179+ if (
180+ interface .protocol_version == version
181+ and interface .protocol_binding == TransportProtocol .HTTP_JSON
182+ ):
183+ return True
184+ return False
185+
186+
165187class A2aAgent :
166188 """A class to initialize and set up an Agent-to-Agent application."""
167189
@@ -181,14 +203,15 @@ def __init__(
181203 """Initializes the A2A agent."""
182204 # pylint: disable=g-import-not-at-top
183205 from google .cloud .aiplatform import initializer
184- from a2a .types import TransportProtocol
206+ from a2a .utils . constants import TransportProtocol
185207
186208 if (
187- agent_card .preferred_transport
188- and agent_card .preferred_transport != TransportProtocol .http_json
209+ agent_card .supported_interfaces
210+ and agent_card .supported_interfaces [0 ].protocol_binding
211+ != TransportProtocol .HTTP_JSON
189212 ):
190213 raise ValueError (
191- "Only HTTP+JSON is supported for preferred transport on agent card "
214+ "Only HTTP+JSON is supported for the primary interface on agent card "
192215 )
193216
194217 self ._tmpl_attrs : dict [str , Any ] = {
@@ -244,7 +267,21 @@ def set_up(self):
244267 agent_engine_id = os .getenv ("GOOGLE_CLOUD_AGENT_ENGINE_ID" , "test-agent-engine" )
245268 version = "v1beta1"
246269
247- self .agent_card .url = f"https://{ location } -aiplatform.googleapis.com/{ version } /projects/{ project } /locations/{ location } /reasoningEngines/{ agent_engine_id } /a2a"
270+ new_url = f"https://{ location } -aiplatform.googleapis.com/{ version } /projects/{ project } /locations/{ location } /reasoningEngines/{ agent_engine_id } /a2a"
271+ if not self .agent_card .supported_interfaces :
272+ from a2a .types import AgentInterface
273+ from a2a .utils .constants import TransportProtocol , PROTOCOL_VERSION_CURRENT
274+
275+ self .agent_card .supported_interfaces .append (
276+ AgentInterface (
277+ url = new_url ,
278+ protocol_binding = TransportProtocol .HTTP_JSON ,
279+ protocol_version = PROTOCOL_VERSION_CURRENT ,
280+ )
281+ )
282+ else :
283+ # primary interface must be HTTP+JSON
284+ self .agent_card .supported_interfaces [0 ].url = new_url
248285 self ._tmpl_attrs ["agent_card" ] = self .agent_card
249286
250287 # Create the agent executor if a builder is provided.
@@ -286,17 +323,30 @@ def set_up(self):
286323
287324 # a2a_rest_adapter is used to register the A2A API routes in the
288325 # Reasoning Engine API router.
289- self .a2a_rest_adapter = RESTAdapter (
290- agent_card = self .agent_card ,
291- http_handler = self ._tmpl_attrs .get ("request_handler" ),
292- extended_agent_card = self ._tmpl_attrs .get ("extended_agent_card" ),
293- )
326+ if _is_version_enabled (self .agent_card , "1.0" ):
327+ self .a2a_rest_adapter = RESTAdapter (
328+ agent_card = self .agent_card ,
329+ http_handler = self ._tmpl_attrs .get ("request_handler" ),
330+ extended_agent_card = self ._tmpl_attrs .get ("extended_agent_card" ),
331+ )
294332
295- # rest_handler is used to handle the A2A API requests.
296- self .rest_handler = RESTHandler (
297- agent_card = self .agent_card ,
298- request_handler = self ._tmpl_attrs .get ("request_handler" ),
299- )
333+ # rest_handler is used to handle the A2A API requests.
334+ self .rest_handler = RESTHandler (
335+ agent_card = self .agent_card ,
336+ request_handler = self ._tmpl_attrs .get ("request_handler" ),
337+ )
338+
339+ # v0.3 handlers will be deprecated in the future.
340+ if _is_version_enabled (self .agent_card , "0.3" ):
341+ from a2a .compat .v0_3 .rest_adapter import REST03Adapter
342+ from a2a .compat .v0_3 .rest_handler import REST03Handler
343+ import functools
344+
345+ self .v03_rest_adapter = REST03Adapter (
346+ agent_card = self .agent_card ,
347+ http_handler = self ._tmpl_attrs .get ("request_handler" ),
348+ extended_agent_card = self ._tmpl_attrs .get ("extended_agent_card" ),
349+ )
300350
301351 async def on_message_send (
302352 self ,
@@ -330,18 +380,25 @@ async def handle_authenticated_agent_card(
330380
331381 def register_operations (self ) -> Dict [str , List [str ]]:
332382 """Registers the operations of the A2A Agent."""
333- routes = {
334- "a2a_extension" : [
335- "on_message_send" ,
336- "on_get_task" ,
337- "on_cancel_task" ,
338- ]
339- }
340- if self .agent_card .capabilities and self .agent_card .capabilities .streaming :
341- routes ["a2a_extension" ].append ("on_message_send_stream" )
342- routes ["a2a_extension" ].append ("on_resubscribe_to_task" )
343- if self .agent_card .supports_authenticated_extended_card :
344- routes ["a2a_extension" ].append ("handle_authenticated_agent_card" )
383+ routes = {"a2a_extension" : []}
384+
385+ if _is_version_enabled (self .agent_card , "1.0" ):
386+ routes ["a2a_extension" ].extend (
387+ [
388+ "on_message_send" ,
389+ "on_get_task" ,
390+ "on_cancel_task" ,
391+ ]
392+ )
393+ if self .agent_card .capabilities and self .agent_card .capabilities .streaming :
394+ routes ["a2a_extension" ].append ("on_message_send_stream" )
395+ routes ["a2a_extension" ].append ("on_subscribe_to_task" )
396+ if (
397+ self .agent_card .capabilities
398+ and self .agent_card .capabilities .extended_agent_card
399+ ):
400+ routes ["a2a_extension" ].append ("handle_authenticated_agent_card" )
401+
345402 return routes
346403
347404 async def on_message_send_stream (
@@ -353,11 +410,59 @@ async def on_message_send_stream(
353410 async for chunk in self .rest_handler .on_message_send_stream (request , context ):
354411 yield chunk
355412
356- async def on_resubscribe_to_task (
413+ async def on_subscribe_to_task (
357414 self ,
358415 request : "Request" ,
359416 context : "ServerCallContext" ,
360417 ) -> AsyncIterator [str ]:
361418 """Handles A2A task resubscription requests via SSE."""
362- async for chunk in self .rest_handler .on_resubscribe_to_task (request , context ):
419+ async for chunk in self .rest_handler .on_subscribe_to_task (request , context ):
363420 yield chunk
421+
422+ def __getstate__ (self ):
423+ """Serializes AgentCard proto to a dictionary."""
424+ from google .protobuf import json_format
425+ import json
426+
427+ state = self .__dict__ .copy ()
428+
429+ def _to_dict_if_proto (obj ):
430+ if hasattr (obj , "DESCRIPTOR" ):
431+ return {
432+ "__protobuf_AgentCard__" : json .loads (json_format .MessageToJson (obj ))
433+ }
434+ return obj
435+
436+ state ["agent_card" ] = _to_dict_if_proto (state .get ("agent_card" ))
437+ if "_tmpl_attrs" in state :
438+ tmpl_attrs = state ["_tmpl_attrs" ].copy ()
439+ tmpl_attrs ["agent_card" ] = _to_dict_if_proto (tmpl_attrs .get ("agent_card" ))
440+ tmpl_attrs ["extended_agent_card" ] = _to_dict_if_proto (
441+ tmpl_attrs .get ("extended_agent_card" )
442+ )
443+ state ["_tmpl_attrs" ] = tmpl_attrs
444+
445+ return state
446+
447+ def __setstate__ (self , state ):
448+ """Deserializes AgentCard proto from a dictionary."""
449+ from google .protobuf import json_format
450+ from a2a .types import AgentCard
451+
452+ def _from_dict_if_proto (obj ):
453+ if isinstance (obj , dict ) and "__protobuf_AgentCard__" in obj :
454+ agent_card = AgentCard ()
455+ json_format .ParseDict (obj ["__protobuf_AgentCard__" ], agent_card )
456+ return agent_card
457+ return obj
458+
459+ state ["agent_card" ] = _from_dict_if_proto (state .get ("agent_card" ))
460+ if "_tmpl_attrs" in state :
461+ state ["_tmpl_attrs" ]["agent_card" ] = _from_dict_if_proto (
462+ state ["_tmpl_attrs" ].get ("agent_card" )
463+ )
464+ state ["_tmpl_attrs" ]["extended_agent_card" ] = _from_dict_if_proto (
465+ state ["_tmpl_attrs" ].get ("extended_agent_card" )
466+ )
467+
468+ self .__dict__ .update (state )
0 commit comments