|
11 | 11 | This is a BaseToolset wrapper around KAgentRemoteA2ATool for runner cleanup purposes. |
12 | 12 | """ |
13 | 13 |
|
| 14 | +# ruff: noqa: E402 |
| 15 | + |
14 | 16 | import logging |
15 | 17 | import uuid |
16 | 18 | from typing import Any, Callable, Optional, Protocol, runtime_checkable |
17 | 19 | from urllib.parse import urlparse |
18 | 20 |
|
19 | 21 | import httpx |
| 22 | +from kagent.core.a2a._compat import install_v03_type_aliases |
| 23 | + |
| 24 | +install_v03_type_aliases() |
| 25 | + |
20 | 26 | from a2a.client import Client as A2AClient |
21 | 27 | from a2a.client.card_resolver import A2ACardResolver |
22 | 28 | from a2a.client.client import ClientConfig as A2AClientConfig |
23 | 29 | from a2a.client.client_factory import ClientFactory as A2AClientFactory |
24 | 30 | from a2a.client.errors import A2AClientHTTPError |
25 | 31 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor |
26 | | -from a2a.types import ( |
| 32 | +from a2a.compat.v0_3.types import ( |
27 | 33 | AgentCard, |
28 | 34 | DataPart, |
29 | 35 | Role, |
30 | 36 | Task, |
31 | 37 | TaskState, |
32 | 38 | TextPart, |
33 | 39 | ) |
34 | | -from a2a.types import ( |
| 40 | +from a2a.compat.v0_3.types import ( |
35 | 41 | Message as A2AMessage, |
36 | 42 | ) |
37 | | -from a2a.types import ( |
| 43 | +from a2a.compat.v0_3.types import ( |
38 | 44 | Part as A2APart, |
39 | 45 | ) |
40 | | -from a2a.types import ( |
| 46 | +from a2a.compat.v0_3.types import ( |
41 | 47 | TransportProtocol as A2ATransport, |
42 | 48 | ) |
43 | 49 | from google.adk.agents.readonly_context import ReadonlyContext |
@@ -70,6 +76,24 @@ class _SubagentInterceptor(ClientCallInterceptor): |
70 | 76 | headers stored in the call context state under ``_EXTRA_HEADERS_CONTEXT_KEY``. |
71 | 77 | """ |
72 | 78 |
|
| 79 | + async def before(self, args) -> None: |
| 80 | + context = args.context |
| 81 | + if context is None: |
| 82 | + return |
| 83 | + headers = dict(context.service_parameters or {}) |
| 84 | + headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT |
| 85 | + |
| 86 | + if _USER_ID_CONTEXT_KEY in context.state: |
| 87 | + headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY] |
| 88 | + extra = context.state.get(_EXTRA_HEADERS_CONTEXT_KEY) |
| 89 | + if extra: |
| 90 | + headers.update(extra) |
| 91 | + |
| 92 | + context.service_parameters = headers |
| 93 | + |
| 94 | + async def after(self, args) -> None: |
| 95 | + return None |
| 96 | + |
73 | 97 | async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context): |
74 | 98 | headers = dict(http_kwargs.get("headers", {})) |
75 | 99 | headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT |
|
0 commit comments