Skip to content

Commit ca90cdd

Browse files
authored
fix(adk): Propagate access token to subagents and mcp tools (#1858)
**Overview** - Fixes an issue where setting the `KAGENT_PROPAGATE_TOKEN` env var for `Agent` CR's using the `python` runtime would only propagate the access token to mcp tools and not to sub agents. - Fixes an issue where the setting the `KAGENT_PROPAGATE_TOKEN` env var for the `go` run would not propagate the token to mcp or subagents as it was unimplented. **Testing** - Manually tested by deploying a coordinator agent and a subagent that uses the `kagent-tool-server` and making an a2a request to the coordinator agent then inspecting the logs. I added logs [here](https://github.com/kagent-dev/kagent/pull/1858/changes#diff-2c6ac47a132ff8be9df40e1902380810f3d2c8ead904caaf53d22161440903c5R56) and [here](https://github.com/kagent-dev/kagent/pull/1858/changes#diff-aa78af986baa8c0d72522820c58c6a60eb1d656aff48985d4ff3fd683a4312bfR267) for the go runtime and [here](https://github.com/kagent-dev/kagent/pull/1858/changes#diff-e2f83de8fb21d0c7da42da4ab5f1ad96d03fe03abbb40432b80ac2adaf483ba8R90) for the python runtime to validate the allowed header was being added by the interceptor closes #1745 --------- Signed-off-by: JM Huibonhoa <jm.huibonhoa@solo.io>
1 parent 6230807 commit ca90cdd

9 files changed

Lines changed: 232 additions & 41 deletions

File tree

go/adk/pkg/agent/agent.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
4848
return nil, nil, fmt.Errorf("agent config is required")
4949
}
5050

51-
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools)
51+
propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true"
52+
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken)
5253
subagentSessionIDs := make(map[string]string)
5354

5455
var remoteAgentTools []tool.Tool
@@ -57,7 +58,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
5758
log.Info("Skipping remote agent with empty URL", "name", remoteAgent.Name)
5859
continue
5960
}
60-
remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers)
61+
remoteTool, sessionID, err := tools.NewKAgentRemoteA2ATool(remoteAgent.Name, remoteAgent.Description, remoteAgent.Url, nil, remoteAgent.Headers, propagateToken)
6162
if err != nil {
6263
return nil, nil, fmt.Errorf("failed to create remote A2A tool for %s: %w", remoteAgent.Name, err)
6364
}

go/adk/pkg/constants/const.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package constants
2+
3+
const (
4+
// A2A call context's NewRequestMeta normalizes header names to lowercase.
5+
// This is why we use "authorization" instead of "Authorization".
6+
AuthorizationHeader = "authorization"
7+
)

go/adk/pkg/mcp/registry.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/a2aproject/a2a-go/a2asrv"
1313
"github.com/go-logr/logr"
14+
"github.com/kagent-dev/kagent/go/adk/pkg/constants"
1415
"github.com/kagent-dev/kagent/go/api/adk"
1516
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
1617
"google.golang.org/adk/tool"
@@ -62,6 +63,7 @@ type mcpServerParams struct {
6263
URL string
6364
Headers map[string]string
6465
AllowedHeaders []string // header names to forward from incoming request
66+
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
6567
ServerType string // "http" or "sse"
6668
Timeout *float64
6769
SseReadTimeout *float64
@@ -73,7 +75,11 @@ type mcpServerParams struct {
7375
// CreateToolsets creates toolsets from all configured HTTP and SSE MCP servers,
7476
// returning the accumulated toolsets. Errors on individual servers are logged
7577
// and skipped.
76-
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig) []tool.Toolset {
78+
//
79+
// When propagateToken is true, Authorization is forwarded to every MCP server
80+
// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
81+
// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
82+
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset {
7783
log := logr.FromContextOrDiscard(ctx)
7884
var toolsets []tool.Toolset
7985

@@ -83,6 +89,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
8389
URL: httpTool.Params.Url,
8490
Headers: httpTool.Params.Headers,
8591
AllowedHeaders: httpTool.AllowedHeaders,
92+
PropagateToken: propagateToken,
8693
ServerType: "http",
8794
Timeout: httpTool.Params.Timeout,
8895
SseReadTimeout: httpTool.Params.SseReadTimeout,
@@ -103,6 +110,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
103110
URL: sseTool.Params.Url,
104111
Headers: sseTool.Params.Headers,
105112
AllowedHeaders: sseTool.AllowedHeaders,
113+
PropagateToken: propagateToken,
106114
ServerType: "sse",
107115
Timeout: sseTool.Params.Timeout,
108116
SseReadTimeout: sseTool.Params.SseReadTimeout,
@@ -200,11 +208,12 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
200208
}
201209

202210
var httpTransport http.RoundTripper = baseTransport
203-
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 {
211+
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken {
204212
httpTransport = &headerRoundTripper{
205213
base: baseTransport,
206214
headers: params.Headers,
207215
allowedHeaders: params.AllowedHeaders,
216+
propagateToken: params.PropagateToken,
208217
}
209218
}
210219

@@ -230,30 +239,41 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
230239
}
231240

232241
// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
233-
// requests. It supports two sources of headers:
234-
// - headers: static key/value pairs configured on the MCP server spec
235-
// - allowedHeaders: header names to forward from the incoming A2A request;
236-
// values are read on each call via allowedRequestHeaders directly from the
237-
// A2A CallContext that is already present in the Go context.
238-
//
239-
// Static headers take precedence: if an allowed header has the same name as a
240-
// static header, the static value wins.
242+
// requests. It supports three sources of headers, applied in this order so that
243+
// higher-priority sources win on collision:
244+
// 1. propagateToken: when true, Authorization is read from the incoming A2A
245+
// CallContext and forwarded unconditionally (independent of allowedHeaders).
246+
// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
247+
// 3. headers: static key/value pairs configured on the MCP server spec (highest
248+
// priority — always wins).
241249
type headerRoundTripper struct {
242250
base http.RoundTripper
243251
headers map[string]string
244252
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
253+
propagateToken bool // when true, Authorization is forwarded independently
245254
}
246255

247256
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
248257
req = req.Clone(req.Context())
249258

250-
// Forward allowed headers from the incoming A2A request first so that
251-
// static headers can override them if there is a name collision.
259+
// When KAGENT_PROPAGATE_TOKEN is set, forward Authorization from the incoming
260+
// A2A request independently of allowedHeaders.
261+
if rt.propagateToken {
262+
if callCtx, ok := a2asrv.CallContextFrom(req.Context()); ok {
263+
if meta := callCtx.RequestMeta(); meta != nil {
264+
if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" {
265+
req.Header.Set(constants.AuthorizationHeader, vals[0])
266+
}
267+
}
268+
}
269+
}
270+
271+
// Forward explicitly allowed headers from the incoming A2A request.
252272
for k, v := range allowedRequestHeaders(req.Context(), rt.allowedHeaders) {
253273
req.Header.Set(k, v)
254274
}
255275

256-
// Apply static headers (override any dynamic ones with the same name).
276+
// Apply static headers last — they take precedence over all dynamic sources.
257277
for key, value := range rt.headers {
258278
req.Header.Set(key, value)
259279
}

go/adk/pkg/mcp/registry_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,73 @@ func TestAllowedRequestHeaders_MultiValueFirstWins(t *testing.T) {
239239
}
240240
}
241241

242+
// TestPropagateToken_ForwardsAuthorizationToMCP verifies that when propagateToken
243+
// is set on headerRoundTripper, the Authorization header from the incoming A2A
244+
// CallContext is forwarded to the outbound MCP request independently of allowedHeaders.
245+
func TestPropagateToken_ForwardsAuthorizationToMCP(t *testing.T) {
246+
t.Parallel()
247+
var capturedAuth string
248+
249+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
250+
capturedAuth = r.Header.Get("Authorization")
251+
w.WriteHeader(http.StatusOK)
252+
}))
253+
defer srv.Close()
254+
255+
ctx := a2aCtx(map[string][]string{
256+
"Authorization": {"Bearer propagated-token"},
257+
})
258+
259+
rt := &headerRoundTripper{
260+
base: http.DefaultTransport,
261+
propagateToken: true,
262+
}
263+
264+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
265+
resp, err := rt.RoundTrip(req)
266+
if err != nil {
267+
t.Fatalf("RoundTrip failed: %v", err)
268+
}
269+
resp.Body.Close()
270+
271+
if capturedAuth != "Bearer propagated-token" {
272+
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer propagated-token")
273+
}
274+
}
275+
276+
// TestPropagateToken_DoesNotForwardWhenDisabled verifies that when propagateToken
277+
// is false, the Authorization header is not forwarded unless listed in allowedHeaders.
278+
func TestPropagateToken_DoesNotForwardWhenDisabled(t *testing.T) {
279+
t.Parallel()
280+
var capturedAuth string
281+
282+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
283+
capturedAuth = r.Header.Get("Authorization")
284+
w.WriteHeader(http.StatusOK)
285+
}))
286+
defer srv.Close()
287+
288+
ctx := a2aCtx(map[string][]string{
289+
"Authorization": {"Bearer propagated-token"},
290+
})
291+
292+
rt := &headerRoundTripper{
293+
base: http.DefaultTransport,
294+
propagateToken: false,
295+
}
296+
297+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
298+
resp, err := rt.RoundTrip(req)
299+
if err != nil {
300+
t.Fatalf("RoundTrip failed: %v", err)
301+
}
302+
resp.Body.Close()
303+
304+
if capturedAuth != "" {
305+
t.Errorf("Authorization should not be forwarded when propagateToken=false, got %q", capturedAuth)
306+
}
307+
}
308+
242309
// TestAllowedRequestHeaders_ReturnsNilWhenNoMatches verifies that the helper returns
243310
// nil rather than an empty map when the allowed list has entries but none of them
244311
// appear in the request metadata.

go/adk/pkg/tools/remote_a2a_tool.go

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
a2atype "github.com/a2aproject/a2a-go/a2a"
1212
"github.com/a2aproject/a2a-go/a2aclient"
1313
"github.com/a2aproject/a2a-go/a2aclient/agentcard"
14+
"github.com/a2aproject/a2a-go/a2asrv"
1415
"github.com/kagent-dev/kagent/go/adk/pkg/a2a"
16+
"github.com/kagent-dev/kagent/go/adk/pkg/constants"
1517
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
1618
"google.golang.org/adk/tool"
1719
"google.golang.org/adk/tool/functiontool"
@@ -32,6 +34,30 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient
3234
return ctx, nil
3335
}
3436

37+
// authzForwardingInterceptor forwards the Authorization header from the
38+
// incoming A2A request context to outbound sub-agent A2A calls.
39+
type authzForwardingInterceptor struct {
40+
a2aclient.PassthroughInterceptor
41+
}
42+
43+
func (a *authzForwardingInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) {
44+
callCtx, ok := a2asrv.CallContextFrom(ctx)
45+
if !ok {
46+
return ctx, nil
47+
}
48+
meta := callCtx.RequestMeta()
49+
if meta == nil {
50+
return ctx, nil
51+
}
52+
if len(req.Meta.Get(constants.AuthorizationHeader)) > 0 {
53+
return ctx, nil
54+
}
55+
if vals, ok := meta.Get(constants.AuthorizationHeader); ok && len(vals) > 0 && vals[0] != "" {
56+
req.Meta.Append(constants.AuthorizationHeader, vals[0])
57+
}
58+
return ctx, nil
59+
}
60+
3561
// remoteA2AInput is the typed argument for the remote A2A function tool.
3662
type remoteA2AInput struct {
3763
Request string `json:"request"`
@@ -40,11 +66,12 @@ type remoteA2AInput struct {
4066
// remoteA2AState holds the mutable state for one remote A2A agent connection.
4167
// All external interaction goes through the tool.Tool returned by NewKAgentRemoteA2ATool.
4268
type remoteA2AState struct {
43-
name string
44-
description string
45-
baseURL string
46-
httpClient *http.Client
47-
extraHeaders map[string]string
69+
name string
70+
description string
71+
baseURL string
72+
httpClient *http.Client
73+
extraHeaders map[string]string
74+
propagateToken bool
4875

4976
a2aClient *a2aclient.Client
5077
agentCard *a2atype.AgentCard
@@ -62,18 +89,19 @@ type remoteA2AState struct {
6289
// The agent card is fetched lazily from baseURL/.well-known/agent.json.
6390
// If httpClient is nil, a default client is created. The client's transport is
6491
// wrapped with otelhttp to propagate W3C trace context to subagents.
65-
func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string) (tool.Tool, string, error) {
92+
func NewKAgentRemoteA2ATool(name, description, baseURL string, httpClient *http.Client, extraHeaders map[string]string, propagateToken bool) (tool.Tool, string, error) {
6693
if httpClient == nil {
6794
httpClient = &http.Client{}
6895
}
6996
httpClient = withOTelTransport(httpClient)
7097
state := &remoteA2AState{
71-
name: name,
72-
description: description,
73-
baseURL: baseURL,
74-
httpClient: httpClient,
75-
extraHeaders: extraHeaders,
76-
lastContextID: a2atype.NewContextID(),
98+
name: name,
99+
description: description,
100+
baseURL: baseURL,
101+
httpClient: httpClient,
102+
extraHeaders: extraHeaders,
103+
propagateToken: propagateToken,
104+
lastContextID: a2atype.NewContextID(),
77105
}
78106
ft, err := functiontool.New(functiontool.Config{
79107
Name: name,
@@ -119,10 +147,14 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e
119147
for k, v := range s.extraHeaders {
120148
meta.Append(k, v)
121149
}
122-
opts = append(opts, a2aclient.WithInterceptors(
150+
interceptors := []a2aclient.CallInterceptor{
123151
a2aclient.NewStaticCallMetaInjector(meta),
124152
&userIDForwardingInterceptor{},
125-
))
153+
}
154+
if s.propagateToken {
155+
interceptors = append(interceptors, &authzForwardingInterceptor{})
156+
}
157+
opts = append(opts, a2aclient.WithInterceptors(interceptors...))
126158

127159
client, err := a2aclient.NewFromCard(ctx, card, opts...)
128160
if err != nil {

0 commit comments

Comments
 (0)