Skip to content

Commit fcd5adf

Browse files
committed
feat: upgrade a2a-sdk dependency to v1.0.0-alpha.0 (A2A 1.0 spec)
- Bump a2a-sdk from 0.3.x to >=1.0.0a0 - Migrate from Pydantic models to Protocol Buffer messages - Replace TextPart/DataPart with unified Part proto - Update all enums to proto naming (Role.ROLE_USER, TaskState.TASK_STATE_SUBMITTED, etc.) - Replace deprecated AgentCard.url with supported_interfaces - Switch to MessageToDict for proto serialization - Remove deprecated request_metadata from send_message - Fix (update, task) tuple unpacking for streaming responses All 426 tests pass (315 a2a + 94 remote agent + 17 agent registry).
1 parent f973673 commit fcd5adf

35 files changed

Lines changed: 1346 additions & 1009 deletions

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ dev = [
9797

9898
a2a = [
9999
# go/keep-sorted start
100-
"a2a-sdk>=0.3.4,<0.4.0",
100+
"a2a-sdk>=1.0.0a0",
101101
# go/keep-sorted end
102102
]
103103

@@ -120,7 +120,7 @@ eval = [
120120

121121
test = [
122122
# go/keep-sorted start
123-
"a2a-sdk>=0.3.0,<0.4.0",
123+
"a2a-sdk>=1.0.0a0",
124124
"anthropic>=0.43.0", # For anthropic model tests
125125
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+
126126
"kubernetes>=29.0.0", # For GkeCodeExecutor

src/google/adk/a2a/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,42 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from a2a.types import Role
18+
from a2a.types import TaskState
19+
20+
21+
def _install_task_state_aliases() -> None:
22+
"""Adds pre-1.0 TaskState aliases expected by ADK code and tests."""
23+
alias_by_name = {
24+
"working": "TASK_STATE_WORKING",
25+
"failed": "TASK_STATE_FAILED",
26+
"input_required": "TASK_STATE_INPUT_REQUIRED",
27+
"auth_required": "TASK_STATE_AUTH_REQUIRED",
28+
"completed": "TASK_STATE_COMPLETED",
29+
"submitted": "TASK_STATE_SUBMITTED",
30+
"canceled": "TASK_STATE_CANCELED",
31+
"unknown": "TASK_STATE_UNKNOWN",
32+
}
33+
for alias, canonical in alias_by_name.items():
34+
if not hasattr(TaskState, alias) and hasattr(TaskState, canonical):
35+
setattr(TaskState, alias, getattr(TaskState, canonical))
36+
37+
38+
_install_task_state_aliases()
39+
40+
41+
def _install_role_aliases() -> None:
42+
"""Adds pre-1.0 Role aliases expected by ADK code and tests."""
43+
alias_by_name = {
44+
"user": "ROLE_USER",
45+
"agent": "ROLE_AGENT",
46+
}
47+
for alias, canonical in alias_by_name.items():
48+
if not hasattr(Role, alias) and hasattr(Role, canonical):
49+
setattr(Role, alias, getattr(Role, canonical))
50+
51+
52+
_install_role_aliases()

src/google/adk/a2a/agent/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Optional
2323
from typing import Union
2424

25-
from a2a.client.middleware import ClientCallContext
25+
from a2a.client import ClientCallContext
2626
from a2a.server.events import Event as A2AEvent
2727
from a2a.types import Message as A2AMessage
2828
from pydantic import BaseModel

src/google/adk/a2a/agent/interceptors/new_integration_extension.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from typing import Union
1919

20-
from a2a.client.middleware import ClientCallContext
20+
from a2a.client import ClientCallContext
2121
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2222
from a2a.types import Message as A2AMessage
2323
from google.adk.a2a.agent.config import ParametersConfig
@@ -39,15 +39,13 @@ async def _before_request(
3939
if params.client_call_context is None:
4040
params.client_call_context = ClientCallContext()
4141

42-
http_kwargs = params.client_call_context.state.get('http_kwargs', {})
43-
headers = http_kwargs.get('headers', {})
42+
headers = params.client_call_context.service_parameters or {}
4443
a2a_extensions = headers.get(HTTP_EXTENSION_HEADER, '').split(',')
4544
a2a_extensions = [ext for ext in a2a_extensions if ext]
4645
if _NEW_A2A_ADK_INTEGRATION_EXTENSION not in a2a_extensions:
4746
a2a_extensions.append(_NEW_A2A_ADK_INTEGRATION_EXTENSION)
4847
headers[HTTP_EXTENSION_HEADER] = ','.join(a2a_extensions)
49-
http_kwargs['headers'] = headers
50-
params.client_call_context.state['http_kwargs'] = http_kwargs
48+
params.client_call_context.service_parameters = headers
5149
return a2a_request, params
5250

5351

src/google/adk/a2a/agent/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Union
2121

2222
from a2a.client import ClientEvent as A2AClientEvent
23-
from a2a.client.middleware import ClientCallContext
23+
from a2a.client import ClientCallContext
2424
from a2a.types import Message as A2AMessage
2525

2626
from ...agents.invocation_context import InvocationContext

src/google/adk/a2a/converters/event_converter.py

Lines changed: 144 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@
2424
from typing import Optional
2525

2626
from a2a.server.events import Event as A2AEvent
27-
from a2a.types import DataPart
2827
from a2a.types import Message
2928
from a2a.types import Part as A2APart
3029
from a2a.types import Role
3130
from a2a.types import Task
3231
from a2a.types import TaskState
3332
from a2a.types import TaskStatus
3433
from a2a.types import TaskStatusUpdateEvent
35-
from a2a.types import TextPart
3634
from google.adk.platform import time as platform_time
3735
from google.adk.platform import uuid as platform_uuid
3836
from google.genai import types as genai_types
37+
from google.protobuf.json_format import MessageToDict
38+
from google.protobuf.timestamp_pb2 import Timestamp
3939

4040
from ...agents.invocation_context import InvocationContext
4141
from ...events.event import Event
@@ -105,6 +105,83 @@ def _serialize_metadata_value(value: Any) -> str:
105105
return str(value)
106106

107107

108+
def _get_part_metadata_value(part: A2APart, key: str) -> Any:
109+
"""Returns a metadata value from either proto Struct or dict-like metadata."""
110+
metadata = getattr(part, "metadata", None)
111+
if not metadata:
112+
return None
113+
try:
114+
return metadata.get(key)
115+
except AttributeError:
116+
try:
117+
return metadata[key]
118+
except Exception:
119+
return None
120+
121+
122+
def _get_part_data_dict(part: A2APart) -> Dict[str, Any]:
123+
"""Returns a part's data payload as a plain dict when possible."""
124+
data = getattr(part, "data", None)
125+
if data is None:
126+
return {}
127+
if isinstance(data, dict):
128+
return data
129+
get_method = getattr(data, "get", None)
130+
if callable(get_method):
131+
try:
132+
return {
133+
"id": get_method("id"),
134+
"name": get_method("name"),
135+
}
136+
except Exception:
137+
pass
138+
try:
139+
return MessageToDict(data)
140+
except Exception:
141+
return {}
142+
143+
144+
def _coerce_a2a_message(message: Message | Any) -> Message:
145+
"""Returns a proto Message, tolerating older mock/dict-style inputs in tests."""
146+
if (
147+
isinstance(message, Message)
148+
and type(message).__module__ != "unittest.mock"
149+
):
150+
return message
151+
152+
coerced_message = Message()
153+
for field_name in ("message_id", "task_id", "context_id"):
154+
field_value = getattr(message, field_name, None)
155+
if field_value:
156+
setattr(coerced_message, field_name, field_value)
157+
158+
role = getattr(message, "role", None)
159+
if role is not None:
160+
coerced_message.role = role
161+
else:
162+
coerced_message.role = Role.ROLE_AGENT
163+
164+
parts = getattr(message, "parts", None)
165+
if parts:
166+
for part in parts:
167+
if isinstance(part, A2APart):
168+
coerced_message.parts.append(part)
169+
170+
metadata = getattr(message, "metadata", None)
171+
if metadata:
172+
coerced_message.metadata.update(metadata)
173+
174+
return coerced_message
175+
176+
177+
def _create_timestamp() -> Timestamp:
178+
"""Creates a protobuf timestamp from the current platform time."""
179+
now = platform_time.get_time()
180+
seconds = int(now)
181+
nanos = int((now - seconds) * 1_000_000_000)
182+
return Timestamp(seconds=seconds, nanos=nanos)
183+
184+
108185
def _get_context_metadata(
109186
event: Event, invocation_context: InvocationContext
110187
) -> Dict[str, str]:
@@ -184,19 +261,30 @@ def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None:
184261
a2a_part: The A2A part to potentially mark as long-running.
185262
event: The ADK event containing long-running tool information.
186263
"""
187-
if (
188-
isinstance(a2a_part.root, DataPart)
189-
and event.long_running_tool_ids
190-
and a2a_part.root.metadata
191-
and a2a_part.root.metadata.get(
192-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
193-
)
194-
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
195-
and a2a_part.root.data.get("id") in event.long_running_tool_ids
196-
):
197-
a2a_part.root.metadata[
198-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
199-
] = True
264+
if not event.long_running_tool_ids or not getattr(a2a_part, "metadata", None):
265+
return
266+
has_data = getattr(a2a_part, "HasField", None)
267+
if callable(has_data):
268+
try:
269+
if not a2a_part.HasField("data"):
270+
return
271+
except Exception:
272+
pass
273+
274+
type_key = _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
275+
part_type = (
276+
_get_part_metadata_value(a2a_part, type_key)
277+
or _get_part_metadata_value(a2a_part, A2A_DATA_PART_METADATA_TYPE_KEY)
278+
or _get_part_metadata_value(a2a_part, "adk_type")
279+
)
280+
if part_type != A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL:
281+
return
282+
283+
data_dict = _get_part_data_dict(a2a_part)
284+
if data_dict.get("id") in event.long_running_tool_ids:
285+
a2a_part.metadata.update({
286+
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True
287+
})
200288

201289

202290
def convert_a2a_task_to_event(
@@ -229,7 +317,7 @@ def convert_a2a_task_to_event(
229317
message = None
230318
if a2a_task.artifacts:
231319
message = Message(
232-
message_id="", role=Role.agent, parts=a2a_task.artifacts[-1].parts
320+
message_id="", role=Role.ROLE_AGENT, parts=a2a_task.artifacts[-1].parts
233321
)
234322
elif (
235323
a2a_task.status
@@ -321,15 +409,10 @@ def convert_a2a_message_to_event(
321409
continue
322410

323411
# Check for long-running tools
324-
if (
325-
a2a_part.root.metadata
326-
and a2a_part.root.metadata.get(
327-
_get_adk_metadata_key(
328-
A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
329-
)
330-
)
331-
is True
332-
):
412+
if _get_part_metadata_value(
413+
a2a_part,
414+
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
415+
) is True:
333416
for part in parts:
334417
if part.function_call:
335418
long_running_tool_ids.add(part.function_call.id)
@@ -372,7 +455,7 @@ def convert_a2a_message_to_event(
372455
def convert_event_to_a2a_message(
373456
event: Event,
374457
invocation_context: InvocationContext | None = None,
375-
role: Role = Role.agent,
458+
role: Role = Role.ROLE_AGENT,
376459
part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
377460
) -> Optional[Message]:
378461
"""Converts an ADK event to an A2A message.
@@ -446,22 +529,19 @@ def _create_error_status_event(
446529
context_id=context_id,
447530
metadata=event_metadata,
448531
status=TaskStatus(
449-
state=TaskState.failed,
532+
state=TaskState.TASK_STATE_FAILED,
450533
message=Message(
451534
message_id=platform_uuid.new_uuid(),
452-
role=Role.agent,
453-
parts=[TextPart(text=error_message)],
535+
role=Role.ROLE_AGENT,
536+
parts=[A2APart(text=error_message)],
454537
metadata={
455538
_get_adk_metadata_key("error_code"): str(event.error_code)
456539
}
457540
if event.error_code
458541
else {},
459542
),
460-
timestamp=datetime.fromtimestamp(
461-
platform_time.get_time(), tz=timezone.utc
462-
).isoformat(),
463-
),
464-
final=False,
543+
timestamp=_create_timestamp(),
544+
)
465545
)
466546

467547

@@ -484,48 +564,45 @@ def _create_status_update_event(
484564
Returns:
485565
A TaskStatusUpdateEvent with RUNNING state.
486566
"""
567+
proto_message = _coerce_a2a_message(message)
568+
487569
status = TaskStatus(
488-
state=TaskState.working,
489-
message=message,
490-
timestamp=datetime.fromtimestamp(
491-
platform_time.get_time(), tz=timezone.utc
492-
).isoformat(),
570+
state=TaskState.TASK_STATE_WORKING,
571+
message=proto_message,
572+
timestamp=_create_timestamp(),
493573
)
494574

495-
if any(
496-
part.root.metadata.get(
497-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
498-
)
499-
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
500-
and part.root.metadata.get(
501-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
502-
)
503-
is True
504-
and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME
505-
for part in message.parts
506-
if part.root.metadata
507-
):
508-
status.state = TaskState.auth_required
509-
elif any(
510-
part.root.metadata.get(
511-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
512-
)
513-
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
514-
and part.root.metadata.get(
515-
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
516-
)
517-
is True
518-
for part in message.parts
519-
if part.root.metadata
520-
):
521-
status.state = TaskState.input_required
575+
def _is_long_running(part: A2APart) -> bool:
576+
val = _get_part_metadata_value(
577+
part,
578+
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
579+
)
580+
return str(val).lower() == "true" or val is True
581+
582+
for part in message.parts:
583+
part_type = (
584+
_get_part_metadata_value(
585+
part, _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
586+
)
587+
or _get_part_metadata_value(part, A2A_DATA_PART_METADATA_TYPE_KEY)
588+
or _get_part_metadata_value(part, "adk_type")
589+
)
590+
if (
591+
part_type == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
592+
and _is_long_running(part)
593+
):
594+
data_dict = _get_part_data_dict(part)
595+
if data_dict.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME:
596+
status.state = TaskState.TASK_STATE_AUTH_REQUIRED
597+
break
598+
status.state = TaskState.TASK_STATE_INPUT_REQUIRED
599+
break
522600

523601
return TaskStatusUpdateEvent(
524602
task_id=task_id,
525603
context_id=context_id,
526604
status=status,
527605
metadata=_get_context_metadata(event, invocation_context),
528-
final=False,
529606
)
530607

531608

0 commit comments

Comments
 (0)