Skip to content

Commit b4b52f3

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: update agent engine utils to support python-a2a sdk 1.0
PiperOrigin-RevId: 913139425
1 parent a99f340 commit b4b52f3

1 file changed

Lines changed: 171 additions & 17 deletions

File tree

vertexai/_genai/_agent_engines_utils.py

Lines changed: 171 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,30 +110,50 @@
110110

111111

112112
try:
113-
from a2a.types import (
114-
AgentCard,
115-
TransportProtocol,
116-
Message,
117-
TaskIdParams,
118-
TaskQueryParams,
119-
)
120-
from a2a.client import ClientConfig, ClientFactory
121-
122-
AgentCard = AgentCard
123-
TransportProtocol = TransportProtocol
124-
Message = Message
125-
ClientConfig = ClientConfig
126-
ClientFactory = ClientFactory
127-
TaskIdParams = TaskIdParams
128-
TaskQueryParams = TaskQueryParams
113+
from a2a.utils.constants import TransportProtocol as _A2aVersionTest # noqa: F401
114+
115+
_A2A_SDK_VERSION: Optional[str] = "1.0"
116+
except ImportError:
117+
try:
118+
from a2a.types import TransportProtocol as _A2aVersionTest # noqa: F401
119+
120+
_A2A_SDK_VERSION = "0.3"
121+
except ImportError:
122+
_A2A_SDK_VERSION = None
123+
124+
try:
125+
if _A2A_SDK_VERSION == "1.0":
126+
from a2a.types import (
127+
AgentCard,
128+
Message,
129+
)
130+
from a2a.client import ClientConfig, ClientFactory
131+
from a2a.utils.constants import TransportProtocol
132+
from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams
133+
elif _A2A_SDK_VERSION == "0.3":
134+
from a2a.types import (
135+
AgentCard,
136+
TransportProtocol,
137+
Message,
138+
TaskIdParams,
139+
TaskQueryParams,
140+
)
141+
from a2a.client import ClientConfig, ClientFactory
142+
else:
143+
raise ImportError
129144
except (ImportError, AttributeError):
145+
_A2A_SDK_VERSION = None
130146
AgentCard = None
131147
TransportProtocol = None
132148
Message = None
133149
ClientConfig = None
134150
ClientFactory = None
135151
TaskIdParams = None
136152
TaskQueryParams = None
153+
SendMessageRequest = None
154+
GetTaskRequest = None
155+
CancelTaskRequest = None
156+
GetExtendedAgentCardRequest = None
137157

138158
_ACTIONS_KEY = "actions"
139159
_ACTION_APPEND = "append"
@@ -1737,7 +1757,9 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
17371757
return _method
17381758

17391759

1740-
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1760+
def _wrap_a2a_operation_v03(
1761+
method_name: str, agent_card: str
1762+
) -> Callable[..., list[Any]]:
17411763
"""Wraps an Agent Engine method, creating a callable for A2A API.
17421764
17431765
Args:
@@ -1854,6 +1876,138 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18541876
return _method # type: ignore[return-value]
18551877

18561878

1879+
def _wrap_a2a_operation_v10(
1880+
method_name: str, agent_card: str
1881+
) -> Callable[..., list[Any]]:
1882+
"""Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+).
1883+
1884+
Args:
1885+
method_name: The name of the Agent Engine method to call.
1886+
agent_card: The agent card JSON string to use for the A2A API call.
1887+
Example:
1888+
{
1889+
'name': 'Sample Agent',
1890+
'description': (
1891+
'A helpful assistant agent that can answer questions.'
1892+
),
1893+
'supportedInterfaces': [{
1894+
'url': 'http://localhost:8080/a2a/rest/',
1895+
'protocolBinding': 'HTTP+JSON',
1896+
'protocolVersion': '1.0',
1897+
}],
1898+
'version': '1.0.0',
1899+
'capabilities': {
1900+
'streaming': True,
1901+
'pushNotifications': False,
1902+
'extendedAgentCard': True,
1903+
},
1904+
'defaultInputModes': ['text'],
1905+
'defaultOutputModes': ['text'],
1906+
'skills': [{
1907+
'id': 'question_answer',
1908+
'name': 'Q&A Agent',
1909+
'description': (
1910+
'A helpful assistant agent that can answer questions.'
1911+
),
1912+
'tags': ['Question-Answer'],
1913+
'examples': [
1914+
'Who is leading 2025 F1 Standings?',
1915+
'Where can i find an active volcano?',
1916+
],
1917+
'inputModes': ['text'],
1918+
'outputModes': ['text'],
1919+
}],
1920+
}
1921+
1922+
Returns:
1923+
A callable object that executes the method on the Agent Engine via
1924+
the A2A API.
1925+
"""
1926+
1927+
async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
1928+
if not self.api_client:
1929+
raise ValueError("api_client is not initialized.")
1930+
if not self.api_resource:
1931+
raise ValueError("api_resource is not initialized.")
1932+
1933+
a2a_agent_card = AgentCard()
1934+
json_format.ParseDict(
1935+
json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True
1936+
)
1937+
1938+
if a2a_agent_card.supported_interfaces:
1939+
interface = a2a_agent_card.supported_interfaces[0]
1940+
if interface.protocol_binding != TransportProtocol.HTTP_JSON:
1941+
raise ValueError(
1942+
"Only HTTP+JSON is supported for preferred transport on agent card"
1943+
)
1944+
else:
1945+
raise ValueError("Agent card does not define any supported interfaces.")
1946+
1947+
base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
1948+
api_version = self.api_client._api_client._http_options.api_version
1949+
a2a_agent_card.supported_interfaces[0].url = (
1950+
f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1951+
)
1952+
1953+
config = ClientConfig(
1954+
supported_protocol_bindings=[
1955+
TransportProtocol.HTTP_JSON,
1956+
],
1957+
use_client_preference=True,
1958+
httpx_client=httpx.AsyncClient(
1959+
headers={
1960+
"Authorization": (
1961+
f"Bearer {self.api_client._api_client._credentials.token}"
1962+
)
1963+
},
1964+
timeout=(
1965+
self.api_client._api_client._http_options.timeout / 1000.0
1966+
if self.api_client._api_client._http_options.timeout
1967+
else None
1968+
),
1969+
),
1970+
)
1971+
factory = ClientFactory(config)
1972+
client = factory.create(a2a_agent_card)
1973+
1974+
context = kwargs.pop("context", None)
1975+
if context is not None:
1976+
from a2a.client.client import ClientCallContext
1977+
1978+
if not isinstance(context, ClientCallContext):
1979+
actual_context = ClientCallContext()
1980+
if hasattr(context, "state"):
1981+
actual_context.state = context.state
1982+
elif isinstance(context, dict):
1983+
actual_context.state = context
1984+
context = actual_context
1985+
1986+
req = kwargs["request"]
1987+
if method_name == "on_message_send":
1988+
response = client.send_message(req, context=context)
1989+
chunks = []
1990+
async for chunk in response:
1991+
chunks.append(chunk)
1992+
return chunks
1993+
elif method_name == "on_get_task":
1994+
return await client.get_task(req, context=context)
1995+
elif method_name == "on_cancel_task":
1996+
return await client.cancel_task(req, context=context)
1997+
elif method_name == "on_get_extended_agent_card":
1998+
return await client.get_extended_agent_card(req, context=context)
1999+
else:
2000+
raise ValueError(f"Unknown method name: {method_name}")
2001+
2002+
return _method # type: ignore[return-value]
2003+
2004+
2005+
if _A2A_SDK_VERSION == "1.0":
2006+
_wrap_a2a_operation = _wrap_a2a_operation_v10
2007+
else:
2008+
_wrap_a2a_operation = _wrap_a2a_operation_v03
2009+
2010+
18572011
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
18582012
"""Converts the body of the HTTP Response message to JSON format.
18592013

0 commit comments

Comments
 (0)