Skip to content

Commit 42a1a0c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: update sdk to support a2a 1.0
PiperOrigin-RevId: 890388363
1 parent 62656c2 commit 42a1a0c

File tree

4 files changed

+172
-54
lines changed

4 files changed

+172
-54
lines changed

vertexai/_genai/_agent_engines_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise(
632632
class_method = _to_proto(schema_dict)
633633
class_method[_MODE_KEY_IN_SCHEMA] = mode
634634
if hasattr(agent, "agent_card"):
635-
class_method[_A2A_AGENT_CARD] = getattr(
636-
agent, "agent_card"
637-
).model_dump_json()
635+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
636+
getattr(agent, "agent_card")
637+
)
638638
class_methods_spec.append(class_method)
639639

640640
return class_methods_spec

vertexai/_genai/agent_engines.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,10 +1834,13 @@ def _create_config(
18341834
agent_card = getattr(agent, "agent_card")
18351835
if agent_card:
18361836
try:
1837-
agent_engine_spec["agent_card"] = agent_card.model_dump(
1838-
exclude_none=True
1837+
from google.protobuf import json_format
1838+
import json
1839+
1840+
agent_engine_spec["agent_card"] = json.loads(
1841+
json_format.MessageToJson(agent_card)
18391842
)
1840-
except TypeError as e:
1843+
except Exception as e:
18411844
raise ValueError(
18421845
f"Failed to convert agent card to dict (serialization error): {e}"
18431846
) from e

vertexai/agent_engines/_agent_engines.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,28 @@
119119
try:
120120
from a2a.types import (
121121
AgentCard,
122-
TransportProtocol,
122+
AgentInterface,
123123
Message,
124124
TaskIdParams,
125125
TaskQueryParams,
126126
)
127+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
127128
from a2a.client import ClientConfig, ClientFactory
128129

129130
AgentCard = AgentCard
131+
AgentInterface = AgentInterface
130132
TransportProtocol = TransportProtocol
133+
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT
131134
Message = Message
132135
ClientConfig = ClientConfig
133136
ClientFactory = ClientFactory
134137
TaskIdParams = TaskIdParams
135138
TaskQueryParams = TaskQueryParams
136139
except (ImportError, AttributeError):
137140
AgentCard = None
141+
AgentInterface = None
138142
TransportProtocol = None
143+
PROTOCOL_VERSION_CURRENT = None
139144
Message = None
140145
ClientConfig = None
141146
ClientFactory = None
@@ -1735,17 +1740,20 @@ async def _method(self, **kwargs) -> Any:
17351740
a2a_agent_card = AgentCard(**json.loads(agent_card))
17361741

17371742
# A2A + AE integration currently only supports Rest API.
1738-
if (
1739-
a2a_agent_card.preferred_transport
1740-
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
1741-
):
1743+
if a2a_agent_card.supported_interfaces and a2a_agent_card.supported_interfaces[0].protocol_binding != TransportProtocol.HTTP_JSON:
17421744
raise ValueError(
1743-
"Only HTTP+JSON is supported for preferred transport on agent card "
1745+
"Only HTTP+JSON is supported for primary interface on agent card "
17441746
)
17451747

1746-
# Set preferred transport to HTTP+JSON if not set.
1747-
if not hasattr(a2a_agent_card, "preferred_transport"):
1748-
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1748+
# Set primary interface to HTTP+JSON if not set.
1749+
if not a2a_agent_card.supported_interfaces:
1750+
a2a_agent_card.supported_interfaces = []
1751+
a2a_agent_card.supported_interfaces.append(
1752+
AgentInterface(
1753+
protocol_binding=TransportProtocol.HTTP_JSON,
1754+
protocol_version=PROTOCOL_VERSION_CURRENT,
1755+
)
1756+
)
17491757

17501758
# AE cannot support streaming yet. Turn off streaming for now.
17511759
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
@@ -1759,12 +1767,13 @@ async def _method(self, **kwargs) -> Any:
17591767

17601768
# agent_card is set on the class_methods before set_up is invoked.
17611769
# Ensure that the agent_card url is set correctly before the client is created.
1762-
a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1770+
url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1771+
a2a_agent_card.supported_interfaces[0].url = url
17631772

17641773
# Using a2a client, inject the auth token from the global config.
17651774
config = ClientConfig(
17661775
supported_transports=[
1767-
TransportProtocol.http_json,
1776+
TransportProtocol.HTTP_JSON,
17681777
],
17691778
use_client_preference=True,
17701779
httpx_client=httpx.AsyncClient(
@@ -1977,9 +1986,10 @@ def _generate_class_methods_spec_or_raise(
19771986
class_method[_MODE_KEY_IN_SCHEMA] = mode
19781987
# A2A agent card is a special case, when running in A2A mode,
19791988
if hasattr(agent_engine, "agent_card"):
1980-
class_method[_A2A_AGENT_CARD] = getattr(
1981-
agent_engine, "agent_card"
1982-
).model_dump_json()
1989+
from google.protobuf import json_format
1990+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
1991+
getattr(agent_engine, "agent_card")
1992+
)
19831993
class_methods_spec.append(class_method)
19841994

19851995
return class_methods_spec

vertexai/preview/reasoning_engines/templates/a2a.py

Lines changed: 139 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def create_agent_card(
8787
provided.
8888
"""
8989
# pylint: disable=g-import-not-at-top
90-
from a2a.types import AgentCard, AgentCapabilities, TransportProtocol
90+
from a2a.types import AgentCard, AgentCapabilities, AgentInterface
91+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
9192

9293
# Check if a dictionary was provided.
9394
if agent_card:
@@ -98,14 +99,20 @@ def create_agent_card(
9899
return AgentCard(
99100
name=agent_name,
100101
description=description,
101-
url="http://localhost:9999/",
102102
version="1.0.0",
103103
default_input_modes=default_input_modes or ["text/plain"],
104104
default_output_modes=default_output_modes or ["application/json"],
105-
capabilities=AgentCapabilities(streaming=streaming),
105+
capabilities=AgentCapabilities(
106+
streaming=streaming, extended_agent_card=True
107+
),
106108
skills=skills,
107-
preferred_transport=TransportProtocol.http_json, # Http Only.
108-
supports_authenticated_extended_card=True,
109+
supported_interfaces=[
110+
AgentInterface(
111+
url="http://localhost:9999/",
112+
protocol_binding=TransportProtocol.HTTP_JSON,
113+
protocol_version=PROTOCOL_VERSION_CURRENT,
114+
)
115+
],
109116
)
110117

111118
# Raise an error if insufficient data is provided.
@@ -162,6 +169,21 @@ async def cancel(
162169
)
163170

164171

172+
def _is_version_enabled(agent_card: "AgentCard", version: str) -> bool:
173+
"""Checks if a specific version compatibility should be enabled for the A2aAgent."""
174+
from a2a.utils.constants import TransportProtocol
175+
176+
if not agent_card.supported_interfaces:
177+
return False
178+
for interface in agent_card.supported_interfaces:
179+
if (
180+
interface.protocol_version == version
181+
and interface.protocol_binding == TransportProtocol.HTTP_JSON
182+
):
183+
return True
184+
return False
185+
186+
165187
class A2aAgent:
166188
"""A class to initialize and set up an Agent-to-Agent application."""
167189

@@ -181,14 +203,15 @@ def __init__(
181203
"""Initializes the A2A agent."""
182204
# pylint: disable=g-import-not-at-top
183205
from google.cloud.aiplatform import initializer
184-
from a2a.types import TransportProtocol
206+
from a2a.utils.constants import TransportProtocol
185207

186208
if (
187-
agent_card.preferred_transport
188-
and agent_card.preferred_transport != TransportProtocol.http_json
209+
agent_card.supported_interfaces
210+
and agent_card.supported_interfaces[0].protocol_binding
211+
!= TransportProtocol.HTTP_JSON
189212
):
190213
raise ValueError(
191-
"Only HTTP+JSON is supported for preferred transport on agent card "
214+
"Only HTTP+JSON is supported for the primary interface on agent card "
192215
)
193216

194217
self._tmpl_attrs: dict[str, Any] = {
@@ -244,7 +267,21 @@ def set_up(self):
244267
agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine")
245268
version = "v1beta1"
246269

247-
self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
270+
new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
271+
if not self.agent_card.supported_interfaces:
272+
from a2a.types import AgentInterface
273+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
274+
275+
self.agent_card.supported_interfaces.append(
276+
AgentInterface(
277+
url=new_url,
278+
protocol_binding=TransportProtocol.HTTP_JSON,
279+
protocol_version=PROTOCOL_VERSION_CURRENT,
280+
)
281+
)
282+
else:
283+
# primary interface must be HTTP+JSON
284+
self.agent_card.supported_interfaces[0].url = new_url
248285
self._tmpl_attrs["agent_card"] = self.agent_card
249286

250287
# Create the agent executor if a builder is provided.
@@ -286,17 +323,30 @@ def set_up(self):
286323

287324
# a2a_rest_adapter is used to register the A2A API routes in the
288325
# Reasoning Engine API router.
289-
self.a2a_rest_adapter = RESTAdapter(
290-
agent_card=self.agent_card,
291-
http_handler=self._tmpl_attrs.get("request_handler"),
292-
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
293-
)
326+
if _is_version_enabled(self.agent_card, "1.0"):
327+
self.a2a_rest_adapter = RESTAdapter(
328+
agent_card=self.agent_card,
329+
http_handler=self._tmpl_attrs.get("request_handler"),
330+
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
331+
)
294332

295-
# rest_handler is used to handle the A2A API requests.
296-
self.rest_handler = RESTHandler(
297-
agent_card=self.agent_card,
298-
request_handler=self._tmpl_attrs.get("request_handler"),
299-
)
333+
# rest_handler is used to handle the A2A API requests.
334+
self.rest_handler = RESTHandler(
335+
agent_card=self.agent_card,
336+
request_handler=self._tmpl_attrs.get("request_handler"),
337+
)
338+
339+
# v0.3 handlers will be deprecated in the future.
340+
if _is_version_enabled(self.agent_card, "0.3"):
341+
from a2a.compat.v0_3.rest_adapter import REST03Adapter
342+
from a2a.compat.v0_3.rest_handler import REST03Handler
343+
import functools
344+
345+
self.v03_rest_adapter = REST03Adapter(
346+
agent_card=self.agent_card,
347+
http_handler=self._tmpl_attrs.get("request_handler"),
348+
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
349+
)
300350

301351
async def on_message_send(
302352
self,
@@ -330,18 +380,25 @@ async def handle_authenticated_agent_card(
330380

331381
def register_operations(self) -> Dict[str, List[str]]:
332382
"""Registers the operations of the A2A Agent."""
333-
routes = {
334-
"a2a_extension": [
335-
"on_message_send",
336-
"on_get_task",
337-
"on_cancel_task",
338-
]
339-
}
340-
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
341-
routes["a2a_extension"].append("on_message_send_stream")
342-
routes["a2a_extension"].append("on_resubscribe_to_task")
343-
if self.agent_card.supports_authenticated_extended_card:
344-
routes["a2a_extension"].append("handle_authenticated_agent_card")
383+
routes = {"a2a_extension": []}
384+
385+
if _is_version_enabled(self.agent_card, "1.0"):
386+
routes["a2a_extension"].extend(
387+
[
388+
"on_message_send",
389+
"on_get_task",
390+
"on_cancel_task",
391+
]
392+
)
393+
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
394+
routes["a2a_extension"].append("on_message_send_stream")
395+
routes["a2a_extension"].append("on_subscribe_to_task")
396+
if (
397+
self.agent_card.capabilities
398+
and self.agent_card.capabilities.extended_agent_card
399+
):
400+
routes["a2a_extension"].append("handle_authenticated_agent_card")
401+
345402
return routes
346403

347404
async def on_message_send_stream(
@@ -353,11 +410,59 @@ async def on_message_send_stream(
353410
async for chunk in self.rest_handler.on_message_send_stream(request, context):
354411
yield chunk
355412

356-
async def on_resubscribe_to_task(
413+
async def on_subscribe_to_task(
357414
self,
358415
request: "Request",
359416
context: "ServerCallContext",
360417
) -> AsyncIterator[str]:
361418
"""Handles A2A task resubscription requests via SSE."""
362-
async for chunk in self.rest_handler.on_resubscribe_to_task(request, context):
419+
async for chunk in self.rest_handler.on_subscribe_to_task(request, context):
363420
yield chunk
421+
422+
def __getstate__(self):
423+
"""Serializes AgentCard proto to a dictionary."""
424+
from google.protobuf import json_format
425+
import json
426+
427+
state = self.__dict__.copy()
428+
429+
def _to_dict_if_proto(obj):
430+
if hasattr(obj, "DESCRIPTOR"):
431+
return {
432+
"__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj))
433+
}
434+
return obj
435+
436+
state["agent_card"] = _to_dict_if_proto(state.get("agent_card"))
437+
if "_tmpl_attrs" in state:
438+
tmpl_attrs = state["_tmpl_attrs"].copy()
439+
tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card"))
440+
tmpl_attrs["extended_agent_card"] = _to_dict_if_proto(
441+
tmpl_attrs.get("extended_agent_card")
442+
)
443+
state["_tmpl_attrs"] = tmpl_attrs
444+
445+
return state
446+
447+
def __setstate__(self, state):
448+
"""Deserializes AgentCard proto from a dictionary."""
449+
from google.protobuf import json_format
450+
from a2a.types import AgentCard
451+
452+
def _from_dict_if_proto(obj):
453+
if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj:
454+
agent_card = AgentCard()
455+
json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card)
456+
return agent_card
457+
return obj
458+
459+
state["agent_card"] = _from_dict_if_proto(state.get("agent_card"))
460+
if "_tmpl_attrs" in state:
461+
state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto(
462+
state["_tmpl_attrs"].get("agent_card")
463+
)
464+
state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto(
465+
state["_tmpl_attrs"].get("extended_agent_card")
466+
)
467+
468+
self.__dict__.update(state)

0 commit comments

Comments
 (0)