Skip to content

Commit c245d28

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: update agent engine utils to support python-a2a sdk 1.0
PiperOrigin-RevId: 913139425
1 parent 4ba222b commit c245d28

1 file changed

Lines changed: 166 additions & 25 deletions

File tree

vertexai/_genai/_agent_engines_utils.py

Lines changed: 166 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -110,30 +110,46 @@
110110

111111

112112
try:
113-
from a2a.types import (
114-
AgentCard,
115-
TransportProtocol,
116-
Message,
117-
TaskIdParams,
118-
TaskQueryParams,
119-
)
120-
from a2a.client import ClientConfig, ClientFactory
121-
122-
AgentCard = AgentCard
123-
TransportProtocol = TransportProtocol
124-
Message = Message
125-
ClientConfig = ClientConfig
126-
ClientFactory = ClientFactory
127-
TaskIdParams = TaskIdParams
128-
TaskQueryParams = TaskQueryParams
129-
except (ImportError, AttributeError):
130-
AgentCard = None
131-
TransportProtocol = None
132-
Message = None
133-
ClientConfig = None
134-
ClientFactory = None
135-
TaskIdParams = None
136-
TaskQueryParams = None
113+
from a2a.utils.constants import TransportProtocol as _TpTest
114+
115+
_A2A_SDK_VERSION = "1.0"
116+
except ImportError:
117+
try:
118+
from a2a.types import TransportProtocol as _TpTest
119+
120+
_A2A_SDK_VERSION = "0.3"
121+
except ImportError:
122+
_A2A_SDK_VERSION = None
123+
124+
if _A2A_SDK_VERSION == "1.0":
125+
from a2a.types import (
126+
AgentCard,
127+
Message,
128+
)
129+
from a2a.client import ClientConfig, ClientFactory
130+
from a2a.utils.constants import TransportProtocol
131+
from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams
132+
elif _A2A_SDK_VERSION == "0.3":
133+
from a2a.types import (
134+
AgentCard,
135+
TransportProtocol,
136+
Message,
137+
TaskIdParams,
138+
TaskQueryParams,
139+
)
140+
from a2a.client import ClientConfig, ClientFactory
141+
else:
142+
AgentCard = None
143+
TransportProtocol = None
144+
Message = None
145+
ClientConfig = None
146+
ClientFactory = None
147+
TaskIdParams = None
148+
TaskQueryParams = None
149+
SendMessageRequest = None
150+
GetTaskRequest = None
151+
CancelTaskRequest = None
152+
GetExtendedAgentCardRequest = None
137153

138154
_ACTIONS_KEY = "actions"
139155
_ACTION_APPEND = "append"
@@ -1737,7 +1753,7 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
17371753
return _method
17381754

17391755

1740-
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1756+
def _wrap_a2a_operation_v03(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
17411757
"""Wraps an Agent Engine method, creating a callable for A2A API.
17421758
17431759
Args:
@@ -1854,6 +1870,131 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18541870
return _method # type: ignore[return-value]
18551871

18561872

1873+
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1874+
"""Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+).
1875+
1876+
Args:
1877+
method_name: The name of the Agent Engine method to call.
1878+
agent_card: The agent card JSON string to use for the A2A API call.
1879+
Example:
1880+
{'name': 'Sample Agent',
1881+
'description': (
1882+
'A helpful assistant agent that can answer questions.'
1883+
),
1884+
'supportedInterfaces': [{
1885+
'url': 'http://localhost:8080/a2a/rest/',
1886+
'protocolBinding': 'HTTP+JSON',
1887+
'protocolVersion': '1.0',
1888+
}],
1889+
'version': '1.0.0',
1890+
'capabilities': {
1891+
'streaming': True,
1892+
'pushNotifications': False,
1893+
'extendedAgentCard': True,
1894+
},
1895+
'defaultInputModes': ['text'],
1896+
'defaultOutputModes': ['text'],
1897+
'skills': [{
1898+
'id': 'question_answer',
1899+
'name': 'Q&A Agent',
1900+
'description': (
1901+
'A helpful assistant agent that can answer questions.'
1902+
),
1903+
'tags': ['Question-Answer'],
1904+
'examples': [
1905+
'Who is leading 2025 F1 Standings?',
1906+
'Where can i find an active volcano?',
1907+
],
1908+
'inputModes': ['text'],
1909+
'outputModes': ['text'],
1910+
}]}
1911+
Returns:
1912+
A callable object that executes the method on the Agent Engine via
1913+
the A2A API.
1914+
"""
1915+
1916+
async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
1917+
if not self.api_client:
1918+
raise ValueError("api_client is not initialized.")
1919+
if not self.api_resource:
1920+
raise ValueError("api_resource is not initialized.")
1921+
1922+
a2a_agent_card = AgentCard()
1923+
json_format.ParseDict(
1924+
json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True
1925+
)
1926+
1927+
if a2a_agent_card.supported_interfaces:
1928+
interface = a2a_agent_card.supported_interfaces[0]
1929+
if interface.protocol_binding != TransportProtocol.HTTP_JSON:
1930+
raise ValueError(
1931+
"Only HTTP+JSON is supported for preferred transport on agent card"
1932+
)
1933+
else:
1934+
raise ValueError("Agent card does not define any supported interfaces.")
1935+
1936+
base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
1937+
api_version = self.api_client._api_client._http_options.api_version
1938+
a2a_agent_card.supported_interfaces[0].url = (
1939+
f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1940+
)
1941+
1942+
config = ClientConfig(
1943+
supported_protocol_bindings=[
1944+
TransportProtocol.HTTP_JSON,
1945+
],
1946+
use_client_preference=True,
1947+
httpx_client=httpx.AsyncClient(
1948+
headers={
1949+
"Authorization": (
1950+
f"Bearer {self.api_client._api_client._credentials.token}"
1951+
)
1952+
},
1953+
timeout=(
1954+
self.api_client._api_client._http_options.timeout / 1000.0
1955+
if self.api_client._api_client._http_options.timeout
1956+
else None
1957+
),
1958+
),
1959+
)
1960+
factory = ClientFactory(config)
1961+
client = factory.create(a2a_agent_card)
1962+
1963+
context = kwargs.pop("context", None)
1964+
if context is not None:
1965+
from a2a.client.client import ClientCallContext
1966+
1967+
if not isinstance(context, ClientCallContext):
1968+
actual_context = ClientCallContext()
1969+
if hasattr(context, "state"):
1970+
actual_context.state = context.state
1971+
elif isinstance(context, dict):
1972+
actual_context.state = context
1973+
context = actual_context
1974+
1975+
req = kwargs["request"]
1976+
if method_name == "on_message_send":
1977+
response = client.send_message(req, context=context)
1978+
chunks = []
1979+
async for chunk in response:
1980+
chunks.append(chunk)
1981+
return chunks
1982+
elif method_name == "on_get_task":
1983+
return await client.get_task(req, context=context)
1984+
elif method_name == "on_cancel_task":
1985+
return await client.cancel_task(req, context=context)
1986+
elif method_name == "on_get_extended_agent_card":
1987+
return await client.get_extended_agent_card(req, context=context)
1988+
else:
1989+
raise ValueError(f"Unknown method name: {method_name}")
1990+
1991+
return _method # type: ignore[return-value]
1992+
1993+
1994+
if _A2A_SDK_VERSION != "1.0":
1995+
_wrap_a2a_operation = _wrap_a2a_operation_v03
1996+
1997+
18571998
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
18581999
"""Converts the body of the HTTP Response message to JSON format.
18592000

0 commit comments

Comments
 (0)