Skip to content

Commit ec8e095

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: update agent engine utils in vertex_ai and agentplatform to support python-a2a sdk 1.0
PiperOrigin-RevId: 913139425
1 parent 47d8fb0 commit ec8e095

2 files changed

Lines changed: 319 additions & 190 deletions

File tree

agentplatform/_genai/_agent_engines_utils.py

Lines changed: 148 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,18 @@
111111

112112

113113
try:
114-
from a2a.types import (
115-
AgentCard,
116-
TransportProtocol,
117-
Message,
118-
TaskIdParams,
119-
TaskQueryParams,
120-
)
121-
from a2a.client import ClientConfig, ClientFactory
122-
123-
AgentCard = AgentCard
124-
TransportProtocol = TransportProtocol
125-
Message = Message
126-
ClientConfig = ClientConfig
127-
ClientFactory = ClientFactory
128-
TaskIdParams = TaskIdParams
129-
TaskQueryParams = TaskQueryParams
114+
from a2a.types import AgentCard
115+
from a2a.client import ClientConfig, ClientFactory
116+
from a2a.utils.constants import TransportProtocol
130117
except (ImportError, AttributeError):
131-
AgentCard = None
132-
TransportProtocol = None
133-
Message = None
134-
ClientConfig = None
135-
ClientFactory = None
136-
TaskIdParams = None
137-
TaskQueryParams = None
118+
AgentCard = None
119+
TransportProtocol = None
120+
ClientConfig = None
121+
ClientFactory = None
122+
SendMessageRequest = None
123+
GetTaskRequest = None
124+
CancelTaskRequest = None
125+
GetExtendedAgentCardRequest = None
138126
try:
139127
from autogen.agentchat import chat
140128

@@ -1760,162 +1748,149 @@ def _method(self: genai_types.AgentEngine, **kwargs) -> Iterator[Any]: # type:
17601748
def _wrap_async_stream_query_operation(
17611749
*, method_name: str
17621750
) -> Callable[..., AsyncIterator[Any]]:
1763-
"""Wraps an Agent Engine method, creating an async callable for `stream_query` API.
1764-
1765-
This function creates a callable object that executes the specified
1766-
Agent Engine method using the `stream_query` API. It handles the
1767-
creation of the API request and the processing of the API response.
1768-
1769-
The reserved keyword argument `http_options` is consumed by this
1770-
wrapper (rather than being forwarded to the deployed agent as part of
1771-
`input`) and is propagated to the underlying HTTP call.
1751+
"""Wraps an Agent Engine method, creating an async callable for `stream_query` API.
1752+
1753+
This function creates a callable object that executes the specified
1754+
Agent Engine method using the `stream_query` API. It handles the
1755+
creation of the API request and the processing of the API response.
1756+
1757+
The reserved keyword argument `http_options` is consumed by this
1758+
wrapper (rather than being forwarded to the deployed agent as part of
1759+
`input`) and is propagated to the underlying HTTP call.
1760+
1761+
Args:
1762+
method_name: The name of the Agent Engine method to call.
1763+
doc: Documentation string for the method.
1764+
1765+
Returns:
1766+
A callable object that executes the method on the Agent Engine via
1767+
the `stream_query` API.
1768+
"""
1769+
1770+
async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]: # type: ignore[no-untyped-def]
1771+
if not self.api_client:
1772+
raise ValueError("api_client is not initialized.")
1773+
if not self.api_resource:
1774+
raise ValueError("api_resource is not initialized.")
1775+
http_options = kwargs.pop("http_options", None)
1776+
async for http_response in self.api_client._async_stream_query(
1777+
name=self.api_resource.name,
1778+
config={
1779+
"class_method": method_name,
1780+
"input": kwargs,
1781+
"include_all_fields": True,
1782+
"http_options": http_options,
1783+
},
1784+
):
1785+
for line in _yield_parsed_json(http_response=http_response):
1786+
if line is not None:
1787+
yield line
1788+
1789+
return _method
1790+
1791+
1792+
def _wrap_a2a_operation(
1793+
method_name: str, agent_card: str
1794+
) -> Callable[..., list[Any]]:
1795+
"""Wraps an Agent Engine method, creating a callable for A2A API.
1796+
1797+
Args:
1798+
method_name: The name of the Agent Engine method to call.
1799+
agent_card: The agent card to use for the A2A API call.
1800+
Example: { 'name': 'Sample Agent', 'description': ( 'A helpful
1801+
assistant agent that can answer questions.' ),
1802+
'supportedInterfaces': [{ 'url': 'http://localhost:8080/a2a/rest/',
1803+
'protocolBinding': 'HTTP+JSON', 'protocolVersion': '1.0', }],
1804+
'version': '1.0.0', 'capabilities': { 'streaming': True,
1805+
'pushNotifications': False, 'extendedAgentCard': True, },
1806+
'defaultInputModes': ['text'], 'defaultOutputModes': ['text'],
1807+
'skills': [{ 'id': 'question_answer', 'name': 'Q&A Agent',
1808+
'description': ( 'A helpful assistant agent that can answer
1809+
questions.' ), 'tags': ['Question-Answer'], 'examples': [ 'Who is
1810+
leading 2025 F1 Standings?', 'Where can i find an active volcano?',
1811+
], 'inputModes': ['text'], 'outputModes': ['text'], }], }
1812+
1813+
Returns:
1814+
A callable object that executes the method on the Agent Engine via
1815+
the A2A API.
1816+
"""
1817+
1818+
async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
1819+
if not self.api_client:
1820+
raise ValueError("api_client is not initialized.")
1821+
if not self.api_resource:
1822+
raise ValueError("api_resource is not initialized.")
1823+
1824+
a2a_agent_card = AgentCard()
1825+
json_format.ParseDict(
1826+
json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True
1827+
)
17721828

1773-
Args:
1774-
method_name: The name of the Agent Engine method to call.
1775-
doc: Documentation string for the method.
1829+
if a2a_agent_card.supported_interfaces:
1830+
interface = a2a_agent_card.supported_interfaces[0]
1831+
if interface.protocol_binding != TransportProtocol.HTTP_JSON:
1832+
raise ValueError(
1833+
"Only HTTP+JSON is supported for preferred transport on agent card"
1834+
)
1835+
else:
1836+
raise ValueError("Agent card does not define any supported interfaces.")
17761837

1777-
Returns:
1778-
A callable object that executes the method on the Agent Engine via
1779-
the `stream_query` API.
1780-
"""
1838+
base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
1839+
api_version = self.api_client._api_client._http_options.api_version
1840+
a2a_agent_card.supported_interfaces[0].url = (
1841+
f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1842+
)
17811843

1782-
async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]: # type: ignore[no-untyped-def]
1783-
if not self.api_client:
1784-
raise ValueError("api_client is not initialized.")
1785-
if not self.api_resource:
1786-
raise ValueError("api_resource is not initialized.")
1787-
http_options = kwargs.pop("http_options", None)
1788-
async for http_response in self.api_client._async_stream_query(
1789-
name=self.api_resource.name,
1790-
config={
1791-
"class_method": method_name,
1792-
"input": kwargs,
1793-
"include_all_fields": True,
1794-
"http_options": http_options,
1844+
config = ClientConfig(
1845+
supported_protocol_bindings=[
1846+
TransportProtocol.HTTP_JSON,
1847+
],
1848+
use_client_preference=True,
1849+
httpx_client=httpx.AsyncClient(
1850+
headers={
1851+
"Authorization": (
1852+
f"Bearer {self.api_client._api_client._credentials.token}"
1853+
)
17951854
},
1796-
):
1797-
for line in _yield_parsed_json(http_response=http_response):
1798-
if line is not None:
1799-
yield line
1800-
1801-
return _method
1802-
1803-
1804-
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1805-
"""Wraps an Agent Engine method, creating a callable for A2A API.
1806-
1807-
Args:
1808-
method_name: The name of the Agent Engine method to call.
1809-
agent_card: The agent card to use for the A2A API call.
1810-
Example:
1811-
{'additionalInterfaces': None,
1812-
'capabilities': {'extensions': None,
1813-
'pushNotifications': None,
1814-
'stateTransitionHistory': None,
1815-
'streaming': False},
1816-
'defaultInputModes': ['text'],
1817-
'defaultOutputModes': ['text'],
1818-
'description': (
1819-
'A helpful assistant agent that can answer questions.'
1820-
),
1821-
'documentationUrl': None,
1822-
'iconUrl': None,
1823-
'name': 'Q&A Agent',
1824-
'preferredTransport': 'JSONRPC',
1825-
'protocolVersion': '0.3.0',
1826-
'provider': None,
1827-
'security': None,
1828-
'securitySchemes': None,
1829-
'signatures': None,
1830-
'skills': [{
1831-
'description': (
1832-
'A helpful assistant agent that can answer questions.'
1833-
),
1834-
'examples': ['Who is leading 2025 F1 Standings?',
1835-
'Where can i find an active volcano?'],
1836-
'id': 'question_answer',
1837-
'inputModes': None,
1838-
'name': 'Q&A Agent',
1839-
'outputModes': None,
1840-
'security': None,
1841-
'tags': ['Question-Answer']}],
1842-
'supportsAuthenticatedExtendedCard': True,
1843-
'url': 'http://localhost:8080/',
1844-
'version': '1.0.0'}
1845-
Returns:
1846-
A callable object that executes the method on the Agent Engine via
1847-
the A2A API.
1848-
"""
1849-
1850-
async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
1851-
"""Wraps an Agent Engine method, creating a callable for A2A API."""
1852-
if not self.api_client:
1853-
raise ValueError("api_client is not initialized.")
1854-
if not self.api_resource:
1855-
raise ValueError("api_resource is not initialized.")
1856-
a2a_agent_card = AgentCard(**json.loads(agent_card))
1857-
# A2A + AE integration currently only supports Rest API.
1858-
if (
1859-
a2a_agent_card.preferred_transport
1860-
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
1861-
):
1862-
raise ValueError(
1863-
"Only HTTP+JSON is supported for preferred transport on agent card "
1864-
)
1865-
1866-
# Set preferred transport to HTTP+JSON if not set.
1867-
if not hasattr(a2a_agent_card, "preferred_transport"):
1868-
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1869-
1870-
if not hasattr(a2a_agent_card.capabilities, "streaming"):
1871-
a2a_agent_card.capabilities.streaming = False
1872-
1873-
# agent_card is set on the class_methods before set_up is invoked.
1874-
# Ensure that the agent_card url is set correctly before the client is created.
1875-
base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
1876-
api_version = self.api_client._api_client._http_options.api_version
1877-
a2a_agent_card.url = f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1878-
1879-
# Using a2a client, inject the auth token from the global config.
1880-
config = ClientConfig(
1881-
supported_transports=[
1882-
TransportProtocol.http_json,
1883-
],
1884-
use_client_preference=True,
1885-
httpx_client=httpx.AsyncClient(
1886-
headers={
1887-
"Authorization": (
1888-
f"Bearer {self.api_client._api_client._credentials.token}"
1889-
)
1890-
},
1891-
timeout=(
1892-
self.api_client._api_client._http_options.timeout / 1000.0
1893-
if self.api_client._api_client._http_options.timeout
1894-
else None
1895-
),
1855+
timeout=(
1856+
self.api_client._api_client._http_options.timeout / 1000.0
1857+
if self.api_client._api_client._http_options.timeout
1858+
else None
18961859
),
1897-
)
1898-
factory = ClientFactory(config)
1899-
client = factory.create(a2a_agent_card)
1900-
1901-
if method_name == "on_message_send":
1902-
response = client.send_message(Message(**kwargs))
1903-
chunks = []
1904-
async for chunk in response:
1905-
chunks.append(chunk)
1906-
return chunks
1907-
elif method_name == "on_get_task":
1908-
response = await client.get_task(TaskQueryParams(**kwargs))
1909-
elif method_name == "on_cancel_task":
1910-
response = await client.cancel_task(TaskIdParams(**kwargs))
1911-
elif method_name == "handle_authenticated_agent_card":
1912-
response = await client.get_card()
1913-
else:
1914-
raise ValueError(f"Unknown method name: {method_name}")
1915-
1916-
return response
1860+
),
1861+
)
1862+
factory = ClientFactory(config)
1863+
client = factory.create(a2a_agent_card)
1864+
1865+
context = kwargs.pop("context", None)
1866+
if context is not None:
1867+
from a2a.client.client import ClientCallContext
1868+
1869+
if not isinstance(context, ClientCallContext):
1870+
actual_context = ClientCallContext()
1871+
if hasattr(context, "state"):
1872+
actual_context.state = context.state
1873+
elif isinstance(context, dict):
1874+
actual_context.state = context
1875+
context = actual_context
1876+
1877+
req = kwargs["request"]
1878+
if method_name == "on_message_send":
1879+
response = client.send_message(req, context=context)
1880+
chunks = []
1881+
async for chunk in response:
1882+
chunks.append(chunk)
1883+
return chunks
1884+
elif method_name == "on_get_task":
1885+
return await client.get_task(req, context=context)
1886+
elif method_name == "on_cancel_task":
1887+
return await client.cancel_task(req, context=context)
1888+
elif method_name == "on_get_extended_agent_card":
1889+
return await client.get_extended_agent_card(req, context=context)
1890+
else:
1891+
raise ValueError(f"Unknown method name: {method_name}")
19171892

1918-
return _method # type: ignore[return-value]
1893+
return _method # type: ignore[return-value]
19191894

19201895

19211896
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:

0 commit comments

Comments
 (0)