Skip to content

Commit 577e9f5

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: update agent engine utils to support python-a2a sdk 1.0
PiperOrigin-RevId: 913139425
1 parent 4ba222b commit 577e9f5

1 file changed

Lines changed: 147 additions & 25 deletions

File tree

vertexai/_genai/_agent_engines_utils.py

Lines changed: 147 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -110,30 +110,46 @@
110110

111111

112112
try:
113-
from a2a.types import (
114-
AgentCard,
115-
TransportProtocol,
116-
Message,
117-
TaskIdParams,
118-
TaskQueryParams,
119-
)
120-
from a2a.client import ClientConfig, ClientFactory
121-
122-
AgentCard = AgentCard
123-
TransportProtocol = TransportProtocol
124-
Message = Message
125-
ClientConfig = ClientConfig
126-
ClientFactory = ClientFactory
127-
TaskIdParams = TaskIdParams
128-
TaskQueryParams = TaskQueryParams
129-
except (ImportError, AttributeError):
130-
AgentCard = None
131-
TransportProtocol = None
132-
Message = None
133-
ClientConfig = None
134-
ClientFactory = None
135-
TaskIdParams = None
136-
TaskQueryParams = None
113+
from a2a.utils.constants import TransportProtocol as _TpTest
114+
115+
_A2A_SDK_VERSION: Optional[str] = "1.0"
116+
except ImportError:
117+
try:
118+
from a2a.types import TransportProtocol as _TpTest
119+
120+
_A2A_SDK_VERSION = "0.3"
121+
except ImportError:
122+
_A2A_SDK_VERSION = None
123+
124+
if _A2A_SDK_VERSION == "1.0":
125+
from a2a.types import (
126+
AgentCard,
127+
Message,
128+
)
129+
from a2a.client import ClientConfig, ClientFactory
130+
from a2a.utils.constants import TransportProtocol
131+
from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams
132+
elif _A2A_SDK_VERSION == "0.3":
133+
from a2a.types import (
134+
AgentCard,
135+
TransportProtocol,
136+
Message,
137+
TaskIdParams,
138+
TaskQueryParams,
139+
)
140+
from a2a.client import ClientConfig, ClientFactory
141+
else:
142+
AgentCard = None
143+
TransportProtocol = None
144+
Message = None
145+
ClientConfig = None
146+
ClientFactory = None
147+
TaskIdParams = None
148+
TaskQueryParams = None
149+
SendMessageRequest = None
150+
GetTaskRequest = None
151+
CancelTaskRequest = None
152+
GetExtendedAgentCardRequest = None
137153

138154
_ACTIONS_KEY = "actions"
139155
_ACTION_APPEND = "append"
@@ -1737,7 +1753,7 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
17371753
return _method
17381754

17391755

1740-
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1756+
def _wrap_a2a_operation_v03(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
17411757
"""Wraps an Agent Engine method, creating a callable for A2A API.
17421758
17431759
Args:
@@ -1854,6 +1870,112 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18541870
return _method # type: ignore[return-value]
18551871

18561872

1873+
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
1874+
"""Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+).
1875+
1876+
Args:
1877+
method_name: The name of the Agent Engine method to call.
1878+
agent_card: The agent card JSON string to use for the A2A API call.
1879+
Example: {'name': 'Sample Agent', 'description': ( 'A helpful
1880+
assistant agent that can answer questions.' ),
1881+
'supportedInterfaces': [{ 'url': 'http://localhost:8080/a2a/rest/',
1882+
'protocolBinding': 'HTTP+JSON', 'protocolVersion': '1.0', }],
1883+
'version': '1.0.0', 'capabilities': { 'streaming': True,
1884+
'pushNotifications': False, 'extendedAgentCard': True, },
1885+
'defaultInputModes': ['text'], 'defaultOutputModes': ['text'],
1886+
'skills': [{ 'id': 'question_answer', 'name': 'Q&A Agent',
1887+
'description': ( 'A helpful assistant agent that can answer
1888+
questions.' ), 'tags': ['Question-Answer'], 'examples': [ 'Who is
1889+
leading 2025 F1 Standings?', 'Where can i find an active volcano?',
1890+
], 'inputModes': ['text'], 'outputModes': ['text'], }]}
1891+
1892+
Returns:
1893+
A callable object that executes the method on the Agent Engine via
1894+
the A2A API.
1895+
"""
1896+
1897+
async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
1898+
if not self.api_client:
1899+
raise ValueError("api_client is not initialized.")
1900+
if not self.api_resource:
1901+
raise ValueError("api_resource is not initialized.")
1902+
1903+
a2a_agent_card = AgentCard()
1904+
json_format.ParseDict(
1905+
json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True
1906+
)
1907+
1908+
if a2a_agent_card.supported_interfaces:
1909+
interface = a2a_agent_card.supported_interfaces[0]
1910+
if interface.protocol_binding != TransportProtocol.HTTP_JSON:
1911+
raise ValueError(
1912+
"Only HTTP+JSON is supported for preferred transport on agent card"
1913+
)
1914+
else:
1915+
raise ValueError("Agent card does not define any supported interfaces.")
1916+
1917+
# base_url = self.api_client._api_client._http_options.base_url.rstrip("/")
1918+
# api_version = self.api_client._api_client._http_options.api_version
1919+
# a2a_agent_card.supported_interfaces[0].url = (
1920+
# f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1921+
# )
1922+
1923+
config = ClientConfig(
1924+
supported_protocol_bindings=[
1925+
TransportProtocol.HTTP_JSON,
1926+
],
1927+
use_client_preference=True,
1928+
httpx_client=httpx.AsyncClient(
1929+
headers={
1930+
"Authorization": (
1931+
f"Bearer {self.api_client._api_client._credentials.token}"
1932+
)
1933+
},
1934+
timeout=(
1935+
self.api_client._api_client._http_options.timeout / 1000.0
1936+
if self.api_client._api_client._http_options.timeout
1937+
else None
1938+
),
1939+
),
1940+
)
1941+
factory = ClientFactory(config)
1942+
client = factory.create(a2a_agent_card)
1943+
1944+
context = kwargs.pop("context", None)
1945+
if context is not None:
1946+
from a2a.client.client import ClientCallContext
1947+
1948+
if not isinstance(context, ClientCallContext):
1949+
actual_context = ClientCallContext()
1950+
if hasattr(context, "state"):
1951+
actual_context.state = context.state
1952+
elif isinstance(context, dict):
1953+
actual_context.state = context
1954+
context = actual_context
1955+
1956+
req = kwargs["request"]
1957+
if method_name == "on_message_send":
1958+
response = client.send_message(req, context=context)
1959+
chunks = []
1960+
async for chunk in response:
1961+
chunks.append(chunk)
1962+
return chunks
1963+
elif method_name == "on_get_task":
1964+
return await client.get_task(req, context=context)
1965+
elif method_name == "on_cancel_task":
1966+
return await client.cancel_task(req, context=context)
1967+
elif method_name == "on_get_extended_agent_card":
1968+
return await client.get_extended_agent_card(req, context=context)
1969+
else:
1970+
raise ValueError(f"Unknown method name: {method_name}")
1971+
1972+
return _method # type: ignore[return-value]
1973+
1974+
1975+
if _A2A_SDK_VERSION != "1.0":
1976+
_wrap_a2a_operation = _wrap_a2a_operation_v03
1977+
1978+
18571979
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
18581980
"""Converts the body of the HTTP Response message to JSON format.
18591981

0 commit comments

Comments
 (0)