Skip to content

Commit f85ad95

Browse files
onematchfoxsupreme-gg-ggEItanya
authored
fix: ensure user identity is propagated across A2A requests/sessions (#1775)
Ensures that caller identity correctly propagates from controller->agent->controller. Addresses #1293 (comment) and potentially also #1771 --------- Signed-off-by: Brian Fox <878612+onematchfox@users.noreply.github.com> Co-authored-by: Jet Chiang <pokyuen.jetchiang-ext@solo.io> Co-authored-by: Eitan Yarmush <eitan.yarmush@solo.io>
1 parent ca90cdd commit f85ad95

9 files changed

Lines changed: 128 additions & 128 deletions

File tree

go/adk/pkg/a2a/executor.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/a2aproject/a2a-go/a2asrv"
1212
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
1313
"github.com/go-logr/logr"
14+
"github.com/kagent-dev/kagent/go/adk/pkg/auth"
1415
"github.com/kagent-dev/kagent/go/adk/pkg/models"
1516
"github.com/kagent-dev/kagent/go/adk/pkg/session"
1617
"github.com/kagent-dev/kagent/go/adk/pkg/skills"
@@ -117,6 +118,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
117118
sessionID := reqCtx.ContextID
118119

119120
ctx = withBearerToken(ctx)
121+
ctx = auth.WithUserID(ctx, userID)
120122

121123
e.logger.Info("Execute",
122124
"taskID", reqCtx.TaskID,

go/adk/pkg/auth/token.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@ import (
88
"time"
99
)
1010

11+
type contextKey int
12+
13+
const userIDKey contextKey = iota
14+
15+
// WithUserID returns a copy of ctx that carries the user ID for injection into
16+
// outgoing HTTP requests by TokenRoundTripper.
17+
func WithUserID(ctx context.Context, userID string) context.Context {
18+
return context.WithValue(ctx, userIDKey, userID)
19+
}
20+
21+
func userIDFromContext(ctx context.Context) string {
22+
id, _ := ctx.Value(userIDKey).(string)
23+
return id
24+
}
25+
1126
const kagentTokenPath = "/var/run/secrets/tokens/kagent-token"
1227

1328
// KAgentTokenService reads a k8s token from a file and reloads it periodically
@@ -61,6 +76,9 @@ func (s *KAgentTokenService) AddHeaders(req *http.Request) {
6176
if token := s.GetToken(); token != "" {
6277
req.Header.Set("Authorization", "Bearer "+token)
6378
}
79+
if userID := userIDFromContext(req.Context()); userID != "" {
80+
req.Header.Set("X-User-Id", userID)
81+
}
6482
}
6583

6684
// readToken reads the token from the file

go/core/internal/httpserver/auth/proxy_authn.go

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,67 +27,74 @@ func NewProxyAuthenticator(userIDClaim string) *ProxyAuthenticator {
2727

2828
func (a *ProxyAuthenticator) Authenticate(ctx context.Context, reqHeaders http.Header, query url.Values) (auth.Session, error) {
2929
authHeader := reqHeaders.Get("Authorization")
30-
31-
// Always read agent identity from X-Agent-Name header (used by agents calling back)
3230
agentID := reqHeaders.Get("X-Agent-Name")
3331

34-
// If we have a Bearer token, parse JWT
35-
if tokenString, ok := strings.CutPrefix(authHeader, "Bearer "); ok {
36-
// Parse JWT without validation (oauth2-proxy or k8s service account already validated)
37-
rawClaims, err := parseJWTPayload(tokenString)
38-
if err != nil {
39-
return nil, ErrUnauthenticated
40-
}
32+
tokenString, ok := strings.CutPrefix(authHeader, "Bearer ")
33+
if !ok {
34+
return nil, ErrUnauthenticated
35+
}
36+
37+
// Parse JWT without validation (oauth2-proxy or k8s service account already validated)
38+
rawClaims, err := parseJWTPayload(tokenString)
39+
if err != nil {
40+
return nil, ErrUnauthenticated
41+
}
4142

42-
userID, _ := rawClaims[a.userIDClaim].(string)
43-
if userID == "" && a.userIDClaim != "sub" {
43+
if agentID != "" {
44+
// Agent call: the Bearer SA token authenticates the pod; the caller's
45+
// identity should be supplied explicitly via X-User-Id / user_id.
46+
// Fall back to the SA sub claim for direct calls to agent pods that
47+
// do not yet propagate the caller identity.
48+
userID := userIDFromRequest(reqHeaders, query)
49+
if userID == "" {
4450
userID, _ = rawClaims["sub"].(string)
4551
}
4652
if userID == "" {
4753
return nil, ErrUnauthenticated
4854
}
49-
5055
return &SimpleSession{
5156
P: auth.Principal{
52-
User: auth.User{ID: userID},
53-
Agent: auth.Agent{ID: agentID},
54-
Claims: rawClaims,
57+
User: auth.User{ID: userID},
58+
Agent: auth.Agent{ID: agentID},
5559
},
5660
authHeader: authHeader,
5761
}, nil
5862
}
5963

60-
// Fall back to service account auth for internal agent-to-controller calls.
61-
// Requires X-Agent-Name to identify the calling agent.
62-
if agentID == "" {
63-
return nil, ErrUnauthenticated
64-
}
65-
66-
// Agents authenticate via user_id query param or X-User-Id header
67-
userID := query.Get("user_id")
68-
if userID == "" {
69-
userID = reqHeaders.Get("X-User-Id")
64+
// Direct user call: identity comes from the OIDC JWT claims.
65+
userID, _ := rawClaims[a.userIDClaim].(string)
66+
if userID == "" && a.userIDClaim != "sub" {
67+
userID, _ = rawClaims["sub"].(string)
7068
}
7169
if userID == "" {
7270
return nil, ErrUnauthenticated
7371
}
74-
7572
return &SimpleSession{
7673
P: auth.Principal{
77-
User: auth.User{
78-
ID: userID,
79-
},
80-
Agent: auth.Agent{
81-
ID: agentID,
82-
},
74+
User: auth.User{ID: userID},
75+
Claims: rawClaims,
8376
},
8477
authHeader: authHeader,
8578
}, nil
8679
}
8780

81+
// userIDFromRequest returns the user identity from the user_id query param or
82+
// X-User-Id header, preferring the query param.
83+
func userIDFromRequest(headers http.Header, query url.Values) string {
84+
if v := query.Get("user_id"); v != "" {
85+
return v
86+
}
87+
return headers.Get("X-User-Id")
88+
}
89+
8890
func (a *ProxyAuthenticator) UpstreamAuth(r *http.Request, session auth.Session, upstreamPrincipal auth.Principal) error {
89-
if simpleSession, ok := session.(*SimpleSession); ok && simpleSession.authHeader != "" {
90-
r.Header.Set("Authorization", simpleSession.authHeader)
91+
if simpleSession, ok := session.(*SimpleSession); ok {
92+
if simpleSession.authHeader != "" {
93+
r.Header.Set("Authorization", simpleSession.authHeader)
94+
}
95+
if userID := simpleSession.P.User.ID; userID != "" {
96+
r.Header.Set("X-User-Id", userID)
97+
}
9198
}
9299
return nil
93100
}

go/core/internal/httpserver/auth/proxy_authn_test.go

Lines changed: 36 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -159,112 +159,63 @@ func TestProxyAuthenticator_Authenticate(t *testing.T) {
159159
}
160160
}
161161

162-
func TestProxyAuthenticator_JWTWithAgentHeader(t *testing.T) {
162+
func TestProxyAuthenticator_AgentCalls(t *testing.T) {
163163
tests := []struct {
164164
name string
165-
claims map[string]any
166-
agentName string
165+
headers map[string]string
166+
queryParams map[string]string
167167
wantUserID string
168168
wantAgentID string
169+
wantErr bool
169170
}{
170171
{
171-
name: "extracts agent identity from header when JWT is present",
172-
claims: map[string]any{
173-
"sub": "system:serviceaccount:kagent:kebab-agent",
174-
"iss": "https://kubernetes.default.svc.cluster.local",
175-
"aud": []any{"kagent"},
172+
name: "agent with SA Bearer token and X-User-Id header uses header identity",
173+
headers: map[string]string{
174+
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
175+
"X-Agent-Name": "kagent/test-agent",
176+
"X-User-Id": "user@example.com",
176177
},
177-
agentName: "kagent__NS__kebab_agent",
178-
wantUserID: "system:serviceaccount:kagent:kebab-agent",
179-
wantAgentID: "kagent__NS__kebab_agent",
178+
wantUserID: "user@example.com",
179+
wantAgentID: "kagent/test-agent",
180180
},
181181
{
182-
name: "works with OIDC JWT and agent header",
183-
claims: map[string]any{
184-
"sub": "user123",
185-
"email": "user@example.com",
182+
name: "agent with SA Bearer token and user_id query param uses query identity",
183+
headers: map[string]string{
184+
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
185+
"X-Agent-Name": "kagent/test-agent",
186186
},
187-
agentName: "kagent__NS__my_agent",
188-
wantUserID: "user123",
189-
wantAgentID: "kagent__NS__my_agent",
190-
},
191-
{
192-
name: "handles JWT without agent header",
193-
claims: map[string]any{
194-
"sub": "user123",
187+
queryParams: map[string]string{
188+
"user_id": "user@example.com",
195189
},
196-
agentName: "",
197-
wantUserID: "user123",
198-
wantAgentID: "",
190+
wantUserID: "user@example.com",
191+
wantAgentID: "kagent/test-agent",
199192
},
200-
}
201-
202-
for _, tt := range tests {
203-
t.Run(tt.name, func(t *testing.T) {
204-
auth := authimpl.NewProxyAuthenticator("")
205-
206-
headers := http.Header{}
207-
token := createTestJWT(tt.claims)
208-
headers.Set("Authorization", "Bearer "+token)
209-
if tt.agentName != "" {
210-
headers.Set("X-Agent-Name", tt.agentName)
211-
}
212-
213-
session, err := auth.Authenticate(context.Background(), headers, url.Values{})
214-
if err != nil {
215-
t.Fatalf("unexpected error: %v", err)
216-
}
217-
218-
principal := session.Principal()
219-
if principal.User.ID != tt.wantUserID {
220-
t.Errorf("User.ID = %q, want %q", principal.User.ID, tt.wantUserID)
221-
}
222-
if principal.Agent.ID != tt.wantAgentID {
223-
t.Errorf("Agent.ID = %q, want %q", principal.Agent.ID, tt.wantAgentID)
224-
}
225-
})
226-
}
227-
}
228-
229-
func TestProxyAuthenticator_ServiceAccountFallback(t *testing.T) {
230-
tests := []struct {
231-
name string
232-
headers map[string]string
233-
queryParams map[string]string
234-
wantUserID string
235-
wantAgentID string
236-
wantErr bool
237-
}{
238193
{
239-
name: "authenticates via user_id query param with agent name",
240-
queryParams: map[string]string{
241-
"user_id": "system:serviceaccount:kagent:kebab-agent",
242-
},
194+
name: "agent with no X-User-Id falls back to SA sub claim",
243195
headers: map[string]string{
244-
"X-Agent-Name": "kagent/kebab-agent",
196+
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
197+
"X-Agent-Name": "kagent/test-agent",
245198
},
246-
wantUserID: "system:serviceaccount:kagent:kebab-agent",
247-
wantAgentID: "kagent/kebab-agent",
248-
wantErr: false,
199+
wantUserID: "system:serviceaccount:kagent:test-agent",
200+
wantAgentID: "kagent/test-agent",
249201
},
202+
// Error cases.
250203
{
251-
name: "authenticates via X-User-Id header with agent name",
204+
name: "agent without Bearer token is rejected",
252205
headers: map[string]string{
253-
"X-User-Id": "system:serviceaccount:kagent:test-agent",
254206
"X-Agent-Name": "kagent/test-agent",
207+
"X-User-Id": "user@example.com",
255208
},
256-
wantUserID: "system:serviceaccount:kagent:test-agent",
257-
wantAgentID: "kagent/test-agent",
258-
wantErr: false,
209+
wantErr: true,
259210
},
260211
{
261-
name: "returns error when no auth method available",
212+
name: "no token and no X-Agent-Name is rejected",
262213
wantErr: true,
263214
},
264215
{
265-
name: "returns error when no X-Agent-Name header for fallback",
216+
name: "user_id without X-Agent-Name is rejected",
266217
queryParams: map[string]string{
267-
"user_id": "system:serviceaccount:kagent:kebab-agent",
218+
"user_id": "user@example.com",
268219
},
269220
wantErr: true,
270221
},
@@ -339,4 +290,9 @@ func TestProxyAuthenticator_UpstreamAuth(t *testing.T) {
339290
if got := req.Header.Get("Authorization"); got != authHeader {
340291
t.Errorf("Authorization header = %q, want %q", got, authHeader)
341292
}
293+
294+
// Verify X-User-Id is forwarded so downstream A2A runtimes receive the real user identity
295+
if got := req.Header.Get("X-User-Id"); got != "user123" {
296+
t.Errorf("X-User-Id header = %q, want %q", got, "user123")
297+
}
342298
}

python/packages/kagent-adk/src/kagent/adk/_session_service.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ async def create_session(
4949
response = await self.client.post(
5050
"/api/sessions",
5151
json=request_data,
52-
headers={"X-User-ID": user_id},
5352
)
5453
response.raise_for_status()
5554

@@ -88,10 +87,7 @@ async def get_session(
8887
url += "&limit=-1"
8988

9089
# Make API call to get session
91-
response: httpx.Response = await self.client.get(
92-
url,
93-
headers={"X-User-ID": user_id},
94-
)
90+
response: httpx.Response = await self.client.get(url)
9591
if response.status_code == 404:
9692
return None
9793
response.raise_for_status()
@@ -131,7 +127,7 @@ async def get_session(
131127
@override
132128
async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsResponse:
133129
# Make API call to list sessions
134-
response = await self.client.get(f"/api/sessions?user_id={user_id}", headers={"X-User-ID": user_id})
130+
response = await self.client.get(f"/api/sessions?user_id={user_id}")
135131
response.raise_for_status()
136132

137133
data = response.json()
@@ -151,10 +147,7 @@ def list_sessions_sync(self, *, app_name: str, user_id: str) -> ListSessionsResp
151147
@override
152148
async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None:
153149
# Make API call to delete session
154-
response = await self.client.delete(
155-
f"/api/sessions/{session_id}?user_id={user_id}",
156-
headers={"X-User-ID": user_id},
157-
)
150+
response = await self.client.delete(f"/api/sessions/{session_id}?user_id={user_id}")
158151
response.raise_for_status()
159152

160153
@override
@@ -172,7 +165,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
172165
response = await self.client.post(
173166
f"/api/sessions/{session.id}/events?user_id={session.user_id}",
174167
json=event_data,
175-
headers={"X-User-ID": session.user_id},
176168
)
177169
response.raise_for_status()
178170

python/packages/kagent-adk/src/kagent/adk/_token.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Optional
55

66
import httpx
7+
from kagent.core.a2a import get_request_user_id
78

89
KAGENT_TOKEN_PATH = "/var/run/secrets/tokens/kagent-token"
910
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def event_hooks(self):
3536
"""Returns a dictionary of event hooks for the application
3637
to use when creating the httpx.AsyncClient.
3738
"""
38-
return {"request": [self._add_bearer_token]}
39+
return {"request": [self._add_headers]}
3940

4041
async def _update_token_loop(self) -> None:
4142
self.token = await self._read_kagent_token()
@@ -61,12 +62,13 @@ async def _refresh_token(self):
6162
async with self.update_lock:
6263
self.token = token
6364

64-
async def _add_bearer_token(self, request: httpx.Request):
65-
# Your function to generate headers dynamically
65+
async def _add_headers(self, request: httpx.Request):
6666
token = await self._get_token()
6767
headers = {"X-Agent-Name": self.app_name}
6868
if token:
6969
headers["Authorization"] = f"Bearer {token}"
70+
if user_id := get_request_user_id():
71+
headers["X-User-Id"] = user_id
7072
request.headers.update(headers)
7173

7274

0 commit comments

Comments
 (0)