From eaaa883f535b5360085853b2f88fcee6a3471158 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 18 May 2026 07:34:43 +0100 Subject: [PATCH 1/4] Add pkg/vmcp/session/binding leaf package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The vMCP session-binding format needs to be parsed and produced in two places: pkg/vmcp/session/types (from ShouldAllowAnonymous) and pkg/vmcp/session/internal/security (from BindSession / validateCaller). A private helper in each would drift over time, so this commit introduces a single-owner leaf package. Format encodes a bound identity as iss + "\x00" + sub, rejecting empty halves and stray NULs. Parse mirrors Format strictly — including rejecting trailing NULs in the sub half so values that did not pass through Format (e.g. direct writes to Redis) fail loudly. A literal "unauthenticated" sentinel covers sessions created without an authenticated identity. No callers wired yet — those land in the next commits. Co-Authored-By: Claude Opus 4.7 --- pkg/vmcp/session/binding/binding.go | 78 ++++++++++++++ pkg/vmcp/session/binding/binding_test.go | 129 +++++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 pkg/vmcp/session/binding/binding.go create mode 100644 pkg/vmcp/session/binding/binding_test.go diff --git a/pkg/vmcp/session/binding/binding.go b/pkg/vmcp/session/binding/binding.go new file mode 100644 index 0000000000..8ba3346a8f --- /dev/null +++ b/pkg/vmcp/session/binding/binding.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package binding is the single owner of the identity-binding format used by +// vMCP session storage. An identity binding encodes the OIDC principal that +// created a session as a single opaque string suitable for storage in a +// key/value store such as Redis or Valkey. +// +// # Format +// +// A bound identity is encoded as iss + "\x00" + sub. NUL is rejected from +// either half by Format and Parse: OIDC Core does not formally forbid NUL in +// sub, but no real-world issuer emits one, and accepting it would let a +// corrupted or adversarial value re-split during Parse. +// +// Sessions created before any auth middleware ran use the literal +// "unauthenticated" sentinel. The sentinel must not contain '\x00' so it +// cannot collide with any bound form; any future format change must preserve +// that disjointness. +// +// # Trust boundary +// +// Bindings are stored plaintext at rest. They are PII but not credentials — +// they identify but do not authenticate. They carry no freshness signal (no +// exp, no nonce) and are NOT a substitute for token validation: callers must +// only compare a stored binding against a freshly-validated (iss, sub) pair +// from the current request's token. +package binding + +import ( + "errors" + "strings" +) + +// UnauthenticatedSentinel is the binding value stored for sessions that were +// created without an authenticated identity (auth middleware not present or +// identity nil). +const UnauthenticatedSentinel = "unauthenticated" + +// ErrInvalidBinding is returned by Format when either input is empty or +// contains a NUL byte. +var ErrInvalidBinding = errors.New("invalid identity binding") + +// Format returns the canonical on-the-wire form of an identity binding: +// iss + "\x00" + sub. Returns ErrInvalidBinding when either input is empty +// or contains a NUL byte. +func Format(iss, sub string) (string, error) { + if iss == "" || sub == "" { + return "", ErrInvalidBinding + } + if strings.ContainsRune(iss, '\x00') || strings.ContainsRune(sub, '\x00') { + return "", ErrInvalidBinding + } + return iss + "\x00" + sub, nil +} + +// Parse splits an on-the-wire binding into its (iss, sub) components. +// Returns ok=true only when s contains exactly one NUL and both halves are +// non-empty. Returns ok=false for the unauthenticated sentinel, for malformed +// input, and for empty strings. Callers must check ok; the empty-string +// return values are not meaningful when ok=false. +func Parse(s string) (iss, sub string, ok bool) { + iss, sub, found := strings.Cut(s, "\x00") + if !found || iss == "" || sub == "" { + return "", "", false + } + // strings.Cut splits on the first NUL, so iss cannot contain one. Sub may + // still carry trailing NULs from a malformed input; reject those. + if strings.ContainsRune(sub, '\x00') { + return "", "", false + } + return iss, sub, true +} + +// IsUnauthenticated reports whether s is the literal unauthenticated sentinel. +func IsUnauthenticated(s string) bool { + return s == UnauthenticatedSentinel +} diff --git a/pkg/vmcp/session/binding/binding_test.go b/pkg/vmcp/session/binding/binding_test.go new file mode 100644 index 0000000000..67048dd1c7 --- /dev/null +++ b/pkg/vmcp/session/binding/binding_test.go @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package binding_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" +) + +func TestFormat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + iss string + sub string + want string + wantErr bool + }{ + { + name: "typical OIDC issuer and subject", + iss: "https://idp.example", + sub: "user-42", + want: "https://idp.example\x00user-42", + }, + { + name: "issuer with path", + iss: "https://accounts.google.com/o/oauth2", + sub: "109876543210987654321", + want: "https://accounts.google.com/o/oauth2\x00109876543210987654321", + }, + { + name: "minimal non-empty inputs", + iss: "i", + sub: "s", + want: "i\x00s", + }, + { + name: "email-like subject", + iss: "https://auth.example.com", + sub: "alice@example.com", + want: "https://auth.example.com\x00alice@example.com", + }, + {name: "empty iss", iss: "", sub: "sub", wantErr: true}, + {name: "empty sub", iss: "iss", sub: "", wantErr: true}, + {name: "NUL in iss", iss: "a\x00b", sub: "sub", wantErr: true}, + {name: "NUL in sub", iss: "iss", sub: "a\x00b", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := binding.Format(tt.iss, tt.sub) + if tt.wantErr { + assert.Empty(t, got) + require.ErrorIs(t, err, binding.ErrInvalidBinding) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + + gotIss, gotSub, ok := binding.Parse(got) + require.True(t, ok, "Parse must return ok=true for a value produced by Format") + assert.Equal(t, tt.iss, gotIss) + assert.Equal(t, tt.sub, gotSub) + }) + } +} + +func TestParse_RejectsInvalid(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + }{ + {name: "empty string", input: ""}, + {name: "no NUL separator", input: "nothing-here"}, + {name: "empty iss half", input: "\x00sub"}, + {name: "empty sub half", input: "iss\x00"}, + // Defense-in-depth: a value that splits into three parts on NUL must be + // rejected so a corrupt or adversarial store write cannot smuggle extra + // data past Parse. + {name: "trailing NUL with extra data", input: "iss\x00sub\x00trailer"}, + // Callers must call IsUnauthenticated before Parse on the restore path; + // the sentinel must not accidentally parse as a bound binding. + {name: "unauthenticated sentinel", input: binding.UnauthenticatedSentinel}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + iss, sub, ok := binding.Parse(tt.input) + assert.False(t, ok) + assert.Empty(t, iss) + assert.Empty(t, sub) + }) + } +} + +func TestIsUnauthenticated(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + {name: "sentinel", input: binding.UnauthenticatedSentinel, want: true}, + {name: "empty string", input: "", want: false}, + {name: "arbitrary string", input: "foo", want: false}, + {name: "bound binding", input: "iss\x00sub", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, binding.IsUnauthenticated(tt.input)) + }) + } +} From 2a4dc07ac7467f138ed9584917766f32253bff72 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 18 May 2026 07:45:37 +0100 Subject: [PATCH 2/4] Add MetadataKeyIdentityBinding key and tighten ShouldAllowAnonymous MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The hijack-prevention decorator needs a stable per-identity key in session metadata. Add MetadataKeyIdentityBinding alongside the existing token-hash keys (the legacy keys stay temporarily so already-running sessions can be invalidated cleanly during the migration; final removal happens in the operator-side follow-up). ShouldAllowAnonymous treated every empty-token identity as anonymous, which lumped LocalUserMiddleware identities into the same equivalence class even when they carried distinct (iss, sub) claims — letting one local user reuse another's session ID. Tighten the rule so any identity with a valid (iss, sub) pair from Claims goes through the bound path, even when Token is empty. Pull iss and sub from Identity.Claims rather than Identity.Subject so the introspection path and the JWT path canonicalize against the same source. Fail-closed on non-string iss/sub claims: a misbehaving validator that stores a numeric or array value is treated as bound (not anonymous), with a WARN logged so the misbehavior surfaces to operators. The new key is unused until the security and factory layers are wired in the next commit. Co-Authored-By: Claude Opus 4.7 --- pkg/vmcp/session/session.go | 17 ++--- pkg/vmcp/session/types/session.go | 77 +++++++++++++++----- pkg/vmcp/session/types/session_test.go | 97 ++++++++++++++++++++++++++ 3 files changed, 162 insertions(+), 29 deletions(-) create mode 100644 pkg/vmcp/session/types/session_test.go diff --git a/pkg/vmcp/session/session.go b/pkg/vmcp/session/session.go index c8bc492cac..0b615603b2 100644 --- a/pkg/vmcp/session/session.go +++ b/pkg/vmcp/session/session.go @@ -11,18 +11,13 @@ import ( // backward compatibility and convenience. type MultiSession = sessiontypes.MultiSession +// Re-exports from the types package for convenience. See the types package for +// authoritative documentation. const ( - // MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256 - // hash of the bearer token used to create the session. For authenticated sessions - // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty - // string sentinel. The raw token is never stored — only the hash. - // - // Re-exported from types package for convenience. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); invalidated on read. MetadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash - - // MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded - // random salt used for HMAC-SHA256 token hashing. Omitted for anonymous sessions. - // - // Re-exported from types package for convenience. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306). MetadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt + + MetadataKeyIdentityBinding = sessiontypes.MetadataKeyIdentityBinding ) diff --git a/pkg/vmcp/session/types/session.go b/pkg/vmcp/session/types/session.go index f5752b48f1..ae8dc7c867 100644 --- a/pkg/vmcp/session/types/session.go +++ b/pkg/vmcp/session/types/session.go @@ -12,10 +12,12 @@ package types import ( "context" "errors" + "log/slog" "github.com/stacklok/toolhive/pkg/auth" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" ) // Caller represents the ability to invoke MCP protocol operations against a @@ -139,37 +141,76 @@ type MultiSession interface { } const ( - // MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256 - // hash of the bearer token used to create the session. For authenticated sessions - // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty - // string sentinel. The raw token is never stored — only the hash. + // MetadataKeyTokenHash held hex(HMAC-SHA256(bearerToken)) for authenticated + // sessions and "" for anonymous sessions. // - // This constant is the single source of truth used by the session factory and - // security layer to store and validate token binding metadata. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); sessions written + // under this key are invalidated on read. Constant removed in the follow-up PR. MetadataKeyTokenHash = "vmcp.token.hash" //nolint:gosec // This is a metadata key name, not a credential. - // MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded - // random salt used for HMAC-SHA256 token hashing. Each authenticated session has a - // unique salt to prevent attacks across multiple sessions. Anonymous sessions do not - // generate a salt and this key is omitted from their metadata. + // MetadataKeyTokenSalt held the hex-encoded per-session salt used for + // HMAC-SHA256 token hashing. // - // This constant is the single source of truth used by the session factory and - // security layer to store and validate token binding metadata. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); removed in the follow-up PR. MetadataKeyTokenSalt = "vmcp.token.salt" //nolint:gosec // This is a metadata key name, not a credential. + + // MetadataKeyIdentityBinding is the session metadata key that holds the + // identity-binding string in the format defined by + // [pkg/vmcp/session/binding]. Bound sessions store binding.Format(iss, sub); + // unauthenticated sessions store binding.UnauthenticatedSentinel. + // + // Storage is plaintext PII — the Redis/Valkey instance must be access-controlled. + MetadataKeyIdentityBinding = "vmcp.identity.binding" ) -// ShouldAllowAnonymous determines if a session should allow anonymous access -// based on the creator's identity. Sessions without an identity (nil) or with -// an empty token are treated as anonymous. +// ShouldAllowAnonymous reports whether a session has no per-user binding and +// should be treated as anonymous. The identity is bound when Token is set, or +// when Claims yields a (iss, sub) pair accepted by binding.Format. +// +// Fail-closed: a Claims["iss"] or Claims["sub"] that is present but not a +// string (a misbehaving validator) is treated as bound, with a WARN logged. func ShouldAllowAnonymous(identity *auth.Identity) bool { - return identity == nil || identity.Token == "" + if identity == nil { + return true + } + if identity.Token != "" { + return false + } + iss, issOK := claimString(identity.Claims, "iss") + sub, subOK := claimString(identity.Claims, "sub") + if !issOK || !subOK { + slog.Warn("auth identity has present-but-non-string iss/sub claim; treating as bound") + return false + } + if _, err := binding.Format(iss, sub); err == nil { + return false + } + return true +} + +// claimString returns (value, ok) for a string claim. ok is true when the +// claim is missing entirely (absence is benign — the caller proceeds to +// binding.Format which rejects empty halves) or when the claim is present +// and a string (including the empty string). ok is false only when the +// claim is present but not a string — the caller must treat that as a +// misconfigured validator and fail closed. +func claimString(claims map[string]any, key string) (string, bool) { + v, present := claims[key] + if !present { + return "", true + } + s, isString := v.(string) + if !isString { + return "", false + } + return s, true } -// Token binding errors returned by Caller methods when caller identity +// Identity binding errors returned by Caller methods when caller identity // validation fails. var ( // ErrUnauthorizedCaller is returned when the caller identity does not - // match the session owner's identity (token hash mismatch). + // match the session owner's identity (identity binding mismatch). ErrUnauthorizedCaller = errors.New("caller identity does not match session owner") // ErrNilCaller is returned when a bound session receives a nil caller. diff --git a/pkg/vmcp/session/types/session_test.go b/pkg/vmcp/session/types/session_test.go new file mode 100644 index 0000000000..3e6b19ce12 --- /dev/null +++ b/pkg/vmcp/session/types/session_test.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stacklok/toolhive/pkg/auth" +) + +func TestShouldAllowAnonymous(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identity *auth.Identity + want bool + }{ + { + name: "nil identity", + identity: nil, + want: true, + }, + { + name: "token present", + identity: &auth.Identity{ + Token: "x", + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{"iss": "https://idp.example", "sub": "alice"}, + }, + }, + want: false, + }, + { + name: "empty token with valid iss+sub claims", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{"iss": "toolhive-local", "sub": "alice"}, + }, + }, + want: false, + }, + { + name: "empty token and nil claims", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: nil}, + }, + want: true, + }, + { + name: "empty token and missing iss", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"sub": "x"}}, + }, + want: true, + }, + { + name: "empty token and missing sub", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "x"}}, + }, + want: true, + }, + { + name: "empty token and empty string iss and sub", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "", "sub": ""}}, + }, + want: true, + }, + { + // Fail-closed: non-string claim from a misbehaving validator → bound, not anonymous. + name: "non-string iss fails closed", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": 42, "sub": "x"}}, + }, + want: false, + }, + { + name: "non-string sub fails closed", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "x", "sub": []string{"foo"}}}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ShouldAllowAnonymous(tt.identity)) + }) + } +} From a9f0e46aab5ef909b9d97ae537a00a5407129287 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 18 May 2026 11:21:42 +0100 Subject: [PATCH 3/4] Bind sessions to (iss, sub), not raw token bytes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The hijack-prevention decorator hashed the incoming bearer token at session creation and rejected any subsequent request whose token hashed differently. Every legitimate OAuth refresh produces a new access token with different bytes — same identity, same iss, same sub — so the decorator misclassified the refresh as a hijack and terminated the session with Unauthorized: caller identity does not match session owner. Drop the HMAC plumbing entirely and bind to a stable (iss, sub) tuple extracted from the OIDC identity's Claims. The binding lives in session metadata under MetadataKeyIdentityBinding, written exclusively by the new BindSession constructor (renamed from PreventSessionHijacking). Validation reads the caller's claims through the same path so the JWT and introspection code paths canonicalize against the same source. The session-upgrade defense (anonymous session, caller presents a token) moves to the unauthenticated-sentinel branch and works exactly as before. RestoreSession reconstructs identity from the stored binding rather than from a separate identity-subject key. Sessions written under the legacy token-hash schema return the bare transportsession.ErrSessionNotFound sentinel so the client receives the standard "re-initialise" signal one forced re-auth at deploy is preferable to a rebind-on-first-use window that would re-introduce the hijack the check exists to block. Downstream callers of the removed WithHMACSecret option (cli/serve.go) and the Phase-2 marker switch (sessionmanager) follow in the next commits; tests catch up after that. Co-Authored-By: Claude Opus 4.7 Use MetadataKeyIdentityBinding as the Phase-2 marker BindSession writes MetadataKeyIdentityBinding on every successful session creation (sentinel for anonymous, real binding for authenticated), so its presence is the new way to distinguish a fully-initialised Phase-2 session from a Generate()-only placeholder. Without this swap, the now-unwritten MetadataKeyTokenHash would always read as absent and Terminate would take the placeholder path on every session, breaking termination semantics. The behaviour for legacy sessions still in Redis from before the migration is intentional and documented in the plan: Terminate writes the placeholder-terminated marker (Update with MetadataKeyTerminated) which sits for the TTL; loadSession returns transportsession.ErrSessionNotFound so the client transparently re-initialises. Co-Authored-By: Claude Opus 4.7 --- pkg/vmcp/server/integration_test.go | 4 +- .../session_management_integration_test.go | 17 +- .../horizontal_scaling_integration_test.go | 30 +- .../server/sessionmanager/session_manager.go | 22 +- .../sessionmanager/session_manager_test.go | 343 +++++++++++---- pkg/vmcp/server/telemetry_integration_test.go | 2 +- pkg/vmcp/server/testfactory_test.go | 2 +- .../session/connector_integration_test.go | 235 ++++------ pkg/vmcp/session/decorating_factory.go | 3 +- pkg/vmcp/session/decorating_factory_test.go | 24 +- pkg/vmcp/session/default_session_test.go | 108 ++--- pkg/vmcp/session/factory.go | 180 +++----- pkg/vmcp/session/factory_metadata_test.go | 12 +- pkg/vmcp/session/identity_binding_test.go | 270 ++++++++++++ .../security/hijack_prevention_test.go | 364 +++++++++------- .../session/internal/security/restore_test.go | 183 ++++---- .../session/internal/security/security.go | 407 ++++++++---------- pkg/vmcp/session/mocks/mock_factory.go | 8 +- pkg/vmcp/session/token_binding_test.go | 343 --------------- pkg/vmcp/session/types/session.go | 10 + 20 files changed, 1277 insertions(+), 1290 deletions(-) create mode 100644 pkg/vmcp/session/identity_binding_test.go delete mode 100644 pkg/vmcp/session/token_binding_test.go diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go index b3d2256d71..d6c4325d45 100644 --- a/pkg/vmcp/server/integration_test.go +++ b/pkg/vmcp/server/integration_test.go @@ -506,8 +506,8 @@ func TestIntegration_AuditLogging(t *testing.T) { // table needed for tool calls and resource reads to be audit-logged correctly. auditSessionFactory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) auditSessionFactory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() diff --git a/pkg/vmcp/server/session_management_integration_test.go b/pkg/vmcp/server/session_management_integration_test.go index 4a8c930d81..3b1eb6eba4 100644 --- a/pkg/vmcp/server/session_management_integration_test.go +++ b/pkg/vmcp/server/session_management_integration_test.go @@ -48,8 +48,8 @@ func newNoopMockFactory(t *testing.T) *sessionfactorymocks.MockMultiSessionFacto t.Helper() ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() @@ -90,13 +90,12 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* t.Helper() state := &mockFactoryState{} factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, identity *auth.Identity, allowAnonymous bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { state.makeWithIDCalled.Store(true) - tokenHash := "" - if identity != nil && identity.Token != "" && !allowAnonymous { - tokenHash = "fake-hash-for-testing" - } + // All sessions carry MetadataKeyIdentityBinding so Terminate takes the + // Phase 2 (storage.Delete) path. The sentinel value is sufficient for + // tests that don't validate the binding content. mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() @@ -105,7 +104,7 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* mock.EXPECT().GetData().Return(nil).AnyTimes() mock.EXPECT().SetData(gomock.Any()).AnyTimes() mock.EXPECT().GetMetadata().Return(map[string]string{ - vmcpsession.MetadataKeyTokenHash: tokenHash, + vmcpsession.MetadataKeyIdentityBinding: "unauthenticated", }).AnyTimes() mock.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes() toolsCopy := make([]vmcp.Tool, len(tools)) diff --git a/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go b/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go index 3b6395ff72..b131fd10a2 100644 --- a/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go +++ b/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go @@ -28,9 +28,6 @@ import ( sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -// hmacSecret is a fixed 32-byte secret used across all integration tests. -var hmacSecret = []byte("test-hmac-secret-32bytes-exactly") - // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- @@ -63,9 +60,8 @@ func newSharedRedisStorage(t *testing.T, mr *miniredis.Miniredis) transportsessi } // newTestManagerWithSharedStorage creates a Manager backed by the given -// DataStorage, a real session factory with the package-level hmacSecret, and -// an ImmutableRegistry containing backends. Cleanup is registered via -// t.Cleanup. +// DataStorage, a real session factory, and an ImmutableRegistry containing +// backends. Cleanup is registered via t.Cleanup. func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.DataStorage, backends []*vmcp.Backend) *Manager { t.Helper() backendList := make([]vmcp.Backend, len(backends)) @@ -75,7 +71,6 @@ func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.Data registry := vmcp.NewImmutableRegistry(backendList) factory := vmcpsession.NewSessionFactory( newUnauthenticatedAuthRegistry(t), - vmcpsession.WithHMACSecret(hmacSecret), ) sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) require.NoError(t, err) @@ -215,13 +210,26 @@ func TestHorizontalScaling_CrossPodHijackPrevention(t *testing.T) { storage := newSharedRedisStorage(t, mr) backend := startMCPBackend(t, "backend-alpha", "echo") + // Both alice and eve need Claims with iss+sub so the identity-binding + // decorator can extract their (iss, sub) pairs (Token is not used for binding + // in the #5306 model; Claims are the canonical source). identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, - Token: "alice-bearer-token", + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, } wrongCaller := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, - Token: "eve-bearer-token", + PrincipalInfo: auth.PrincipalInfo{ + Subject: "eve", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "eve", + }, + }, } // Pod A: create session bound to alice. diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index 0b01e9b945..1164a8dbc2 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -304,18 +304,11 @@ func (sm *Manager) CreateSession( // Resolve the caller identity (may be nil for anonymous access). identity, _ := auth.IdentityFromContext(ctx) - // Note: Token hash and salt are computed and stored by the session factory - // (MakeSessionWithID below). Token binding enforcement happens at the session - // level via validateCaller(), which uses HMAC-SHA256 with a per-session salt. - // List all available backends from the registry. backends := sm.listAllBackends(ctx) // Build the fully-formed MultiSession using the SDK-assigned session ID. - // Sessions created with an identity are bound to that identity (allowAnonymous=false). - // Sessions created without an identity allow anonymous access (allowAnonymous=true). - allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, allowAnonymous, backends) + sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, backends) if err != nil { sm.cleanupFailedPlaceholder(sessionID, placeholder) return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err) @@ -482,7 +475,7 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, fmt.Errorf("Manager.Terminate: failed to load session %q: %w", sessionID, loadErr) } - if _, isFullSession := metadata[sessiontypes.MetadataKeyTokenHash]; isFullSession { + if _, isFullSession := metadata[sessiontypes.MetadataKeyIdentityBinding]; isFullSession { // Phase 2 (full MultiSession): delete from storage. The cache entry will be // evicted lazily on the next Get when checkSession finds the session gone. if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { @@ -701,16 +694,17 @@ func (sm *Manager) loadSession(sessionID string) (vmcpsession.MultiSession, erro } // Don't restore placeholder sessions (Phase 2 never ran). - // PreventSessionHijacking always writes MetadataKeyTokenHash during Phase 2 - // (empty sentinel for anonymous, non-empty hash for authenticated). Its - // absence means Generate() stored this record but CreateSession() never - // completed — treat it as "not found" rather than "corrupted". + // BindSession always writes MetadataKeyIdentityBinding during Phase 2 + // (the unauthenticated sentinel for anonymous sessions, a bound (iss, sub) + // binding for authenticated ones). Its absence means Generate() stored + // this record but CreateSession() never completed — treat it as "not + // found" rather than "corrupted". // // Note: this is intentionally different from RestoreSession's fail-closed // check (absent key → error). Here we know a placeholder's empty metadata // is valid storage state produced by Generate(), so we return the // SDK-standard ErrSessionNotFound instead of an error. - if _, hashPresent := metadata[sessiontypes.MetadataKeyTokenHash]; !hashPresent { + if _, bindingPresent := metadata[sessiontypes.MetadataKeyIdentityBinding]; !bindingPresent { return nil, transportsession.ErrSessionNotFound } diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index bffce554f9..8a8fb70cd8 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -65,7 +65,7 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, sess vmcpsession.Mult t.Helper() factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil).AnyTimes() return factory } @@ -75,7 +75,7 @@ func newMockFactoryWithError(t *testing.T, ctrl *gomock.Controller, err error) * t.Helper() factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, err).AnyTimes() return factory } @@ -300,8 +300,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) var createdSess *sessionmocks.MockMultiSession factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = newMockSession(t, ctrl, id, tools) return createdSess, nil }).AnyTimes() @@ -367,8 +367,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factoryCalled := false factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { factoryCalled = true sess := newMockSession(t, ctrl, id, tools) return sess, nil @@ -402,8 +402,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factoryCalled := false factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { factoryCalled = true sess := newMockSession(t, ctrl, id, tools) return sess, nil @@ -440,8 +440,8 @@ func TestSessionManager_CreateSession(t *testing.T) { var createdSess *sessionmocks.MockMultiSession factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { // Sleep to simulate slow backend initialization, creating a window // where the client can terminate the session after the first check passes. time.Sleep(50 * time.Millisecond) @@ -605,18 +605,18 @@ func TestSessionManager_Terminate(t *testing.T) { tools := []vmcp.Tool{{Name: "t1", Description: "tool 1"}} ctrl := gomock.NewController(t) - // tokenHashMeta is carried by the session so CreateSession writes it to + // bindingMeta is carried by the session so CreateSession writes it to // storage and Terminate takes the Phase 2 (storage.Delete) path. - tokenHashMeta := map[string]string{sessiontypes.MetadataKeyTokenHash: ""} + bindingMeta := map[string]string{sessiontypes.MetadataKeyIdentityBinding: "unauthenticated"} var createdSess *sessionmocks.MockMultiSession factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = sessionmocks.NewMockMultiSession(ctrl) createdSess.EXPECT().ID().Return(id).AnyTimes() - createdSess.EXPECT().GetMetadata().Return(tokenHashMeta).AnyTimes() + createdSess.EXPECT().GetMetadata().Return(bindingMeta).AnyTimes() createdSess.EXPECT().Tools().Return(tools).AnyTimes() createdSess.EXPECT().Type().Return(transportsession.SessionType("")).AnyTimes() createdSess.EXPECT().CreatedAt().Return(time.Time{}).AnyTimes() @@ -666,8 +666,8 @@ func TestSessionManager_Terminate(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) // Close is called by onEvict when Terminate removes the cache entry. sess.EXPECT().Close().Return(nil).AnyTimes() @@ -683,10 +683,10 @@ func TestSessionManager_Terminate(t *testing.T) { _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // Seed MetadataKeyIdentityBinding into storage so Terminate recognises this // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. _, err = storage.Update(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -847,8 +847,8 @@ func TestSessionManager_GetMultiSession(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) return sess, nil }).Times(1) @@ -873,7 +873,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Cross-pod restore path: session is in storage but not in the in-memory // cache (simulates pod restart or eviction). loadSession is called on Get. - t.Run("restore path: placeholder in storage (absent token hash) is treated as not found", func(t *testing.T) { + t.Run("restore path: placeholder in storage (absent identity binding) is treated as not found", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -885,24 +885,24 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sessionID := "restore-placeholder-session" // Write placeholder metadata directly to storage, bypassing the cache. - // Generate() stores an empty map with no token hash. + // Generate() stores an empty map with no identity binding. _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{}) require.NoError(t, err) - // loadSession detects absent MetadataKeyTokenHash → ErrSessionNotFound. + // loadSession detects absent MetadataKeyIdentityBinding → ErrSessionNotFound. multiSess, ok := sm.GetMultiSession(sessionID) assert.False(t, ok, "placeholder should not be restorable") assert.Nil(t, multiSess) }) - t.Run("restore path: fully-initialized zero-backend session (has token hash) is restored", func(t *testing.T) { + t.Run("restore path: fully-initialized zero-backend session (has identity binding) is restored", func(t *testing.T) { t.Parallel() tools := []vmcp.Tool{{Name: "zero-backend-tool", Description: "tool with no backends"}} ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) // MakeSessionWithID is only for Phase 2; unused in the restore path. - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-zero-backend-session" @@ -914,12 +914,13 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) - // Metadata matching what populateBackendMetadata now writes for a - // Phase-2-complete session with zero backends: MetadataKeyBackendIDs - // is always written (empty string for zero backends). + // Metadata matching what BindSession and populateBackendMetadata write + // for a Phase-2-complete anonymous session with zero backends: + // MetadataKeyIdentityBinding holds the unauthenticated sentinel; + // MetadataKeyBackendIDs is always written (empty string for zero backends). initializedMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", // anonymous sentinel — present but empty - vmcpsession.MetadataKeyBackendIDs: "", // always written; empty = zero backends + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", // anonymous sentinel + vmcpsession.MetadataKeyBackendIDs: "", // always written; empty = zero backends } _, err := sm.storage.Create(context.Background(), sessionID, initializedMeta) require.NoError(t, err) @@ -941,7 +942,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // RestoreSession. This test documents that backward-compat behaviour. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-legacy-session" @@ -953,9 +954,10 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) - // Legacy metadata: token hash present but MetadataKeyBackendIDs absent. + // Metadata with identity binding but MetadataKeyBackendIDs absent + // (sessions written before populateBackendMetadata always wrote the key). legacyMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", // Phase 2 completion marker + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", // Phase 2 completion marker // MetadataKeyBackendIDs intentionally absent (legacy record) } _, err := sm.storage.Create(context.Background(), sessionID, legacyMeta) @@ -976,14 +978,14 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // that stale per-backend session IDs do not persist indefinitely in storage. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-metadata-persist-session" // The restored session returns fresh per-backend session metadata. freshMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", vmcpsession.MetadataKeyBackendIDs: "backend-a", vmcpsession.MetadataKeyBackendSessionPrefix + "backend-a": "fresh-session-id", } @@ -999,7 +1001,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Seed storage with stale per-backend session ID. staleMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", vmcpsession.MetadataKeyBackendIDs: "backend-a", vmcpsession.MetadataKeyBackendSessionPrefix + "backend-a": "stale-session-id", } @@ -1027,14 +1029,14 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // restored session. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-concurrent-delete-session" restored := sessionmocks.NewMockMultiSession(ctrl) restored.EXPECT().ID().Return(sessionID).AnyTimes() restored.EXPECT().GetMetadata().Return(map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }).AnyTimes() factory.EXPECT(). @@ -1053,7 +1055,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Seed the inner storage with a valid session record. _, err = innerStorage.Create(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -1073,7 +1075,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // metadata drift on the next liveness check and evict if necessary. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-update-error-session" @@ -1090,7 +1092,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { t.Cleanup(func() { _ = cleanup(context.Background()) }) _, err = innerStorage.Create(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -1142,8 +1144,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1202,8 +1204,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1255,8 +1257,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { } factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(callToolResult, nil).Times(1) @@ -1296,8 +1298,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, errors.New("backend exploded")).Times(1) @@ -1328,8 +1330,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1364,8 +1366,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { var capturedMeta map[string]any factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, _ *auth.Identity, _ string, _ map[string]any, meta map[string]any) (*vmcp.ToolCallResult, error) { @@ -1430,8 +1432,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { authErr := tc.authError factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, authErr).Times(1) @@ -1515,8 +1517,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) // Override default Resources() AnyTimes with a specific return sess.EXPECT().Resources().Return(resources).AnyTimes() @@ -1567,8 +1569,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///data.txt"). @@ -1615,8 +1617,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///broken.txt"). @@ -1662,8 +1664,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///binary.bin"). @@ -1725,8 +1727,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///protected.txt"). @@ -1805,8 +1807,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { // Create mock directly (without newMockSession) so there is no // pre-existing Prompts().Return(nil).AnyTimes() that would win // the FIFO expectation race over our specific prompts list. @@ -1866,8 +1868,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -1908,8 +1910,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -1959,8 +1961,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -2013,8 +2015,8 @@ func TestSessionManager_DecorateSession(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -2076,17 +2078,17 @@ func TestSessionManager_DecorateSession(t *testing.T) { // decorator fn, so the re-check that follows fn() sees it is gone. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - // The mock session carries MetadataKeyTokenHash so that: + // The mock session carries MetadataKeyIdentityBinding so that: // 1. CreateSession stores it in storage (via sess.GetMetadata()), keeping // cache and storage in sync for checkSession's maps.Equal comparison. // 2. Terminate sees the key and takes the Phase 2 path (storage.Delete). - tokenHashMeta := map[string]string{sessiontypes.MetadataKeyTokenHash: ""} + bindingMeta := map[string]string{sessiontypes.MetadataKeyIdentityBinding: "unauthenticated"} factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() - sess.EXPECT().GetMetadata().Return(tokenHashMeta).AnyTimes() + sess.EXPECT().GetMetadata().Return(bindingMeta).AnyTimes() sess.EXPECT().Close().Return(nil).AnyTimes() // Other methods called by the session manager infrastructure. sess.EXPECT().Type().Return(transportsession.SessionType("")).AnyTimes() @@ -2136,7 +2138,7 @@ func TestSessionManager_CheckSession(t *testing.T) { t.Helper() ctrl := gomock.NewController(t) f := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - f.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + f.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes().Return(nil, nil) f.EXPECT().RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes().Return(nil, nil) @@ -2450,10 +2452,10 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // Seed MetadataKeyIdentityBinding into storage so Terminate recognises this // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. _, err = storage.Update(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -2577,6 +2579,179 @@ func TestNotifyBackendExpired(t *testing.T) { }) } +// --------------------------------------------------------------------------- +// Tests: Phase 2 marker migration (#5306) +// --------------------------------------------------------------------------- + +// TestLoadSession_Phase2Marker_UsesIdentityBindingKey documents that the +// Phase-2 detection key for the restore path is MetadataKeyIdentityBinding, +// not the legacy MetadataKeyTokenHash. Sessions stored with only the legacy +// key are treated as not found (ErrSessionNotFound) and the client must +// re-initialize. Sessions stored with MetadataKeyIdentityBinding are restored +// normally. +func TestLoadSession_Phase2Marker_UsesIdentityBindingKey(t *testing.T) { + t.Parallel() + + t.Run("legacy session (only MetadataKeyTokenHash) returns not found", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + // RestoreSession must NOT be called for legacy sessions on the restore path. + factory.EXPECT().RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + sessionID := "legacy-only-token-hash-session" + // Seed storage with only the legacy key — no MetadataKeyIdentityBinding. + _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + }) + require.NoError(t, err) + + // loadSession must treat absent MetadataKeyIdentityBinding as a legacy session + // and return (nil, false) — not attempting RestoreSession. + multiSess, ok := sm.GetMultiSession(sessionID) + assert.False(t, ok, "legacy session with only MetadataKeyTokenHash must not be restored") + assert.Nil(t, multiSess) + }) + + t.Run("session with MetadataKeyIdentityBinding is restored normally", func(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{{Name: "restored-tool", Description: "a restored tool"}} + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + + sessionID := "identity-binding-session" + restored := newMockSession(t, ctrl, sessionID, tools) + + factory.EXPECT(). + RestoreSession(gomock.Any(), sessionID, gomock.Any(), gomock.Any()). + Return(restored, nil).Times(1) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", + vmcpsession.MetadataKeyBackendIDs: "", + }) + require.NoError(t, err) + + multiSess, ok := sm.GetMultiSession(sessionID) + require.True(t, ok, "session with MetadataKeyIdentityBinding must be restorable") + require.NotNil(t, multiSess) + assert.Equal(t, sessionID, multiSess.ID()) + }) +} + +// TestTerminate_Phase2DetectionUsesIdentityBindingKey verifies that Terminate +// uses MetadataKeyIdentityBinding (not the legacy MetadataKeyTokenHash) to +// distinguish Phase 2 sessions (full MultiSession → Delete) from Phase 1 +// placeholders (→ mark terminated). +func TestTerminate_Phase2DetectionUsesIdentityBindingKey(t *testing.T) { + t.Parallel() + + t.Run("session with MetadataKeyIdentityBinding takes delete path", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "s", nil) + sess.EXPECT().Close().Return(nil).AnyTimes() + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // Write MetadataKeyIdentityBinding into storage to simulate a Phase 2 session. + _, err := storage.Update(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", + }) + require.NoError(t, err) + + // Terminate must take the Phase 2 path: storage.Delete (not marked terminated). + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must be deleted from storage, not just marked terminated. + _, loadErr := storage.Load(context.Background(), sessionID) + assert.ErrorIs(t, loadErr, transportsession.ErrSessionNotFound, + "Phase 2 Terminate must delete the session from storage") + }) + + t.Run("placeholder without MetadataKeyIdentityBinding takes mark-terminated path", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "", nil) + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // No MetadataKeyIdentityBinding in storage — this is a Phase 1 placeholder. + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must remain in storage but marked as terminated (not deleted). + metadata, loadErr := storage.Load(context.Background(), sessionID) + require.NoError(t, loadErr, "placeholder must remain in storage (TTL will clean it)") + assert.Equal(t, MetadataValTrue, metadata[MetadataKeyTerminated], + "Phase 1 Terminate must mark the session terminated, not delete it") + }) +} + +// TestTerminate_LegacyFormatSession_TakesPlaceholderPath is a B5 documentation +// test. It verifies that a legacy session stored in Redis with only the old +// MetadataKeyTokenHash key (no MetadataKeyIdentityBinding) causes Terminate to +// take the placeholder (mark-terminated) path rather than deleting the session. +// +// This is intentional: without the identity binding key, the Manager cannot +// tell whether the record is a real Phase-2 session or a corrupted/partial +// record. Treating it as a placeholder (soft termination) is safe — the TTL +// will eventually clean it up. The comment at session_manager.go line 485 +// references this test. +func TestTerminate_LegacyFormatSession_TakesPlaceholderPath(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "", nil) + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // Seed storage with only the legacy key — simulates a pre-#5306 session in Redis. + // MetadataKeyIdentityBinding is absent. + _, err := storage.Update(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + }) + require.NoError(t, err) + + // Terminate must take the placeholder path: mark as terminated, NOT delete. + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must still exist in storage — marked terminated, not deleted. + // (Storage cleanup happens via TTL or the next GET → checkSession → eviction.) + metadata, loadErr := storage.Load(context.Background(), sessionID) + require.NoError(t, loadErr, + "legacy session must remain in storage after Terminate (not deleted)") + assert.Equal(t, MetadataValTrue, metadata[MetadataKeyTerminated], + "legacy session Terminate must set MetadataKeyTerminated rather than deleting") +} + // --------------------------------------------------------------------------- // Helper // --------------------------------------------------------------------------- diff --git a/pkg/vmcp/server/telemetry_integration_test.go b/pkg/vmcp/server/telemetry_integration_test.go index 850a350457..70c7a76f8f 100644 --- a/pkg/vmcp/server/telemetry_integration_test.go +++ b/pkg/vmcp/server/telemetry_integration_test.go @@ -111,7 +111,7 @@ func newBackendAwareTestFactory(tools []vmcp.Tool, rt *vmcp.RoutingTable) (*back } func (f *backendAwareTestFactory) MakeSessionWithID( - _ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend, + _ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend, ) (vmcpsession.MultiSession, error) { return &backendAwareTestSession{ Session: transportsession.NewStreamableSession(id), diff --git a/pkg/vmcp/server/testfactory_test.go b/pkg/vmcp/server/testfactory_test.go index aba748c299..e9b7e62145 100644 --- a/pkg/vmcp/server/testfactory_test.go +++ b/pkg/vmcp/server/testfactory_test.go @@ -27,7 +27,7 @@ type minimalTestFactory struct{} var _ vmcpsession.MultiSessionFactory = (*minimalTestFactory)(nil) func (*minimalTestFactory) MakeSessionWithID( - _ context.Context, _ string, _ *auth.Identity, _ bool, _ []*vmcp.Backend, + _ context.Context, _ string, _ *auth.Identity, _ []*vmcp.Backend, ) (vmcpsession.MultiSession, error) { return nil, fmt.Errorf("minimalTestFactory: MakeSessionWithID not implemented in test helper") } diff --git a/pkg/vmcp/session/connector_integration_test.go b/pkg/vmcp/session/connector_integration_test.go index 3fc0e2d58f..c66b1bdeaf 100644 --- a/pkg/vmcp/session/connector_integration_test.go +++ b/pkg/vmcp/session/connector_integration_test.go @@ -5,7 +5,6 @@ package session import ( "context" - "encoding/hex" "net/http" "net/http/httptest" "sync" @@ -22,7 +21,7 @@ import ( vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/session/internal/security" + internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) @@ -115,7 +114,7 @@ func TestSessionFactory_Integration_CapabilityDiscovery(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -144,7 +143,7 @@ func TestSessionFactory_Integration_CallTool(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -167,7 +166,7 @@ func TestSessionFactory_Integration_ReadResource(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -190,7 +189,7 @@ func TestSessionFactory_Integration_GetPrompt(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -219,7 +218,7 @@ func TestSessionFactory_Integration_MultipleBackends(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -229,21 +228,46 @@ func TestSessionFactory_Integration_MultipleBackends(t *testing.T) { } // --------------------------------------------------------------------------- -// Token-binding integration tests — HMAC rejection for ReadResource / GetPrompt +// Identity-binding integration tests — CallerRejection for ReadResource / GetPrompt // --------------------------------------------------------------------------- // TestTokenBinding_CallerRejection verifies that the hijack-prevention decorator // is applied to all three protected methods (CallTool, ReadResource, GetPrompt): -// each rejects a wrong token (ErrUnauthorizedCaller) and a nil caller -// (ErrNilCaller) before any backend routing occurs, so nilBackendConnector suffices. -func TestTokenBinding_CallerRejection(t *testing.T) { +// each rejects a wrong caller (ErrUnauthorizedCaller) and a nil caller +// (ErrNilCaller) before any backend routing occurs. +// +// The identity binding is derived from Claims["iss"] and Claims["sub"] per the +// new #5306 model (no HMAC secret required). +func TestIdentityBinding_CallerRejection(t *testing.T) { t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: "alice-token"} - wrongCaller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "bob"}, Token: "wrong-token"} + // Both alice and bob need valid Claims so the binding decorator can extract + // their (iss, sub) pairs. alice creates the session; bob is the wrong caller. + alice := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } + wrongCaller := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "bob", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "bob", + }, + }, + } - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) + // No backend connector needed: auth validation fires before any routing. + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), alice, nil) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) @@ -266,7 +290,7 @@ func TestTokenBinding_CallerRejection(t *testing.T) { } for _, fn := range callFns { - t.Run(fn.name+"/wrong token", func(t *testing.T) { + t.Run(fn.name+"/wrong caller", func(t *testing.T) { t.Parallel() assert.ErrorIs(t, fn.call(wrongCaller), sessiontypes.ErrUnauthorizedCaller) }) @@ -280,7 +304,7 @@ func TestTokenBinding_CallerRejection(t *testing.T) { // TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend verifies that a // bound session accepts ReadResource and GetPrompt calls from the correct caller // when a real backend is connected. -func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { +func TestIdentityBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { t.Parallel() baseURL := startInProcessMCPServer(t) @@ -291,15 +315,24 @@ func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { TransportType: "streamable-http", } - const rawToken = "alice-real-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} + // Identity with Claims so binding.Format(iss, sub) succeeds and the session + // is bound to the caller identity. + identity := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } - factory := NewSessionFactory(newUnauthenticatedRegistry(t), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, []*vmcp.Backend{backend}) + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) - t.Run("allows ReadResource with correct token", func(t *testing.T) { + t.Run("allows ReadResource with correct caller", func(t *testing.T) { t.Parallel() result, err := sess.ReadResource(context.Background(), identity, "test://data") require.NoError(t, err) @@ -308,7 +341,7 @@ func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { assert.Equal(t, "hello", result.Contents[0].Text) }) - t.Run("allows GetPrompt with correct token", func(t *testing.T) { + t.Run("allows GetPrompt with correct caller", func(t *testing.T) { t.Parallel() result, err := sess.GetPrompt(context.Background(), identity, "greet", nil) require.NoError(t, err) @@ -319,71 +352,43 @@ func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { }) } -// TestTokenBinding_DifferentSecretsProduceDifferentHashes verifies that two -// session factories configured with different HMAC secrets store different token -// hashes for the same raw bearer token. This is the key isolation property that -// prevents sessions from one secret epoch from being validated against another. -func TestTokenBinding_DifferentSecretsProduceDifferentHashes(t *testing.T) { - t.Parallel() - - const rawToken = "shared-token-same-for-both" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: rawToken} - - factoryA := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("secret-A-exactly-32-bytes-long!!"))) - factoryB := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("secret-B-exactly-32-bytes-long!!"))) - - sessA, err := factoryA.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessA.Close() }) - - sessB, err := factoryB.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessB.Close() }) - - hashA := sessA.GetMetadata()[MetadataKeyTokenHash] - hashB := sessB.GetMetadata()[MetadataKeyTokenHash] - - assert.NotEmpty(t, hashA) - assert.NotEmpty(t, hashB) - assert.NotEqual(t, hashA, hashB, - "different HMAC secrets must produce different token hashes for the same input token") -} - -// TestRestoreHijackPrevention_Integration_RoundTrip verifies the full +// TestTokenBinding_RestoreSession_RoundTrip verifies the full // store-then-restore flow across a real factory-created session: // -// 1. Create a session via the factory (writes tokenHash + tokenSalt to metadata). -// 2. Extract the persisted values. -// 3. Wrap a fresh base session with RestoreHijackPrevention using those values. -// 4. Confirm the restored decorator accepts the original token and rejects others. -func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { +// 1. Create a session via the factory (writes MetadataKeyIdentityBinding to metadata). +// 2. Extract the persisted binding. +// 3. Restore the session via RestoreSession using the persisted metadata. +// 4. Confirm the restored decorator accepts the original caller and rejects others. +func TestIdentityBinding_RestoreSession_RoundTrip(t *testing.T) { t.Parallel() - const rawToken = "integration-token" - hmacSecret := []byte("test-hmac-secret-exactly-32bytes") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} + identity := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(hmacSecret)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, nil) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) // Extract persisted values — these simulate what would be read back from Redis. meta := sess.GetMetadata() - persistedHash := meta[MetadataKeyTokenHash] - persistedSalt := meta[sessiontypes.MetadataKeyTokenSalt] - require.NotEmpty(t, persistedHash, "factory must write tokenHash to metadata") - require.NotEmpty(t, persistedSalt, "factory must write tokenSalt to metadata") - - // Simulate "Pod B": restore the decorator from persisted metadata. - // We use a nil-connector session as the inner session (no real backend needed - // to test auth path). - innerSess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = innerSess.Close() }) + storedBinding := meta[MetadataKeyIdentityBinding] + require.NotEmpty(t, storedBinding, "factory must write MetadataKeyIdentityBinding to metadata") - restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, hmacSecret) + // Simulate "Pod B": restore the session from persisted metadata. + restored, err := factory.RestoreSession(context.Background(), uuid.New().String(), meta, nil) require.NoError(t, err) + t.Cleanup(func() { _ = restored.Close() }) ctx := context.Background() @@ -393,8 +398,16 @@ func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { require.NotErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) require.NotErrorIs(t, err, sessiontypes.ErrNilCaller) - // A different caller is rejected at the auth layer — before any backend routing. - wrongCaller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, Token: "eve-token"} + // A different caller is rejected at the auth layer. + wrongCaller := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "eve", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "eve", + }, + }, + } _, err = restored.CallTool(ctx, wrongCaller, "any-tool", nil, nil) require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) @@ -403,70 +416,6 @@ func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { require.ErrorIs(t, err, sessiontypes.ErrNilCaller) } -// TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch verifies -// that a session restored on a replica with a different HMAC secret rejects -// the original caller's token, documenting the operational requirement that -// all replicas must share the same secret. -func TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch(t *testing.T) { - t.Parallel() - - secretA := []byte("secret-A-exactly-32-bytes-long!!") - secretB := []byte("secret-B-exactly-32-bytes-long!!") - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: "alice-token"} - - // Pod A creates the session with secretA, persisting the hash. - factoryA := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretA)) - sessA, err := factoryA.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessA.Close() }) - - persistedHash := sessA.GetMetadata()[MetadataKeyTokenHash] - persistedSalt := sessA.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - - // Pod B restores with secretB — the persisted hash was computed with secretA, - // so validation will produce a different HMAC and reject the caller. - factoryB := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretB)) - innerSess, err := factoryB.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = innerSess.Close() }) - - restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, secretB) - require.NoError(t, err) - - _, err = restored.CallTool(context.Background(), identity, "any-tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, - "cross-replica secret mismatch must reject the original caller") -} - -// TestTokenBinding_MetadataEncoding verifies that the token hash and salt stored -// in session metadata are valid hex strings of the expected lengths: -// - token hash: 64 hex chars (32-byte HMAC-SHA256) -// - token salt: 32 hex chars (16-byte random salt) -func TestTokenBinding_MetadataEncoding(t *testing.T) { - t.Parallel() - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token-123"} - - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sess.Close() }) - - tokenHash := sess.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, tokenHash) - assert.Len(t, tokenHash, 64, "HMAC-SHA256 hex-encoded hash must be 64 characters") - hashBytes, err := hex.DecodeString(tokenHash) - require.NoError(t, err, "token hash must be valid hex") - assert.Len(t, hashBytes, 32, "decoded token hash must be 32 bytes") - - tokenSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.NotEmpty(t, tokenSalt) - saltBytes, err := hex.DecodeString(tokenSalt) - require.NoError(t, err, "token salt must be valid hex") - assert.Len(t, saltBytes, 16, "decoded token salt must be 16 bytes") -} - // startInProcessMCPServerWithHeaderCapture starts an in-process MCP server and // returns the base URL along with a function that returns all Mcp-Session-Id // header values received by the server from clients. @@ -529,7 +478,7 @@ func TestSessionFactory_Integration_RestoreSession_SendsStoredSessionHintToBacke // Create the original session — the backend assigns a session ID over // streamable-HTTP and we store it in metadata. - orig, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + orig, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { _ = orig.Close() }) diff --git a/pkg/vmcp/session/decorating_factory.go b/pkg/vmcp/session/decorating_factory.go index f536e9ef0f..2254634a23 100644 --- a/pkg/vmcp/session/decorating_factory.go +++ b/pkg/vmcp/session/decorating_factory.go @@ -51,10 +51,9 @@ func (f *decoratingMultiSessionFactory) MakeSessionWithID( ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) { - sess, err := f.base.MakeSessionWithID(ctx, id, identity, allowAnonymous, backends) + sess, err := f.base.MakeSessionWithID(ctx, id, identity, backends) if err != nil { return nil, err } diff --git a/pkg/vmcp/session/decorating_factory_test.go b/pkg/vmcp/session/decorating_factory_test.go index 361ad261e9..ae4776ed85 100644 --- a/pkg/vmcp/session/decorating_factory_test.go +++ b/pkg/vmcp/session/decorating_factory_test.go @@ -34,7 +34,7 @@ func TestNewDecoratingFactory_DecoratorsAppliedInOrder(t *testing.T) { sess := sessionmocks.NewMockMultiSession(ctrl) base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) var order []int @@ -48,7 +48,7 @@ func TestNewDecoratingFactory_DecoratorsAppliedInOrder(t *testing.T) { } factory := session.NewDecoratingFactory(base, dec1, dec2) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.NoError(t, err) assert.Equal(t, []int{1, 2}, order) } @@ -62,7 +62,7 @@ func TestNewDecoratingFactory_DecoratorError_ClosesSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) decErr := errors.New("decorator boom") @@ -70,7 +70,7 @@ func TestNewDecoratingFactory_DecoratorError_ClosesSession(t *testing.T) { return nil, decErr }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.ErrorIs(t, err, decErr) } @@ -85,7 +85,7 @@ func TestNewDecoratingFactory_SecondDecoratorError_ClosesCurrentSession(t *testi base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) decErr := errors.New("second decorator boom") @@ -93,7 +93,7 @@ func TestNewDecoratingFactory_SecondDecoratorError_ClosesCurrentSession(t *testi dec2 := func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return nil, decErr } factory := session.NewDecoratingFactory(base, dec1, dec2) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.ErrorIs(t, err, decErr) } @@ -106,14 +106,14 @@ func TestNewDecoratingFactory_NilReturnWithNoError_ClosesSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return nil, nil // buggy decorator }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "nil session") } @@ -127,7 +127,7 @@ func TestNewDecoratingFactory_CloseErrorOnDecoratorFailure_DoesNotSuppressOrigin base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) decErr := errors.New("decorator error") @@ -135,7 +135,7 @@ func TestNewDecoratingFactory_CloseErrorOnDecoratorFailure_DoesNotSuppressOrigin return nil, decErr }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) // The original decorator error, not the close error, is returned. require.ErrorIs(t, err, decErr) } @@ -149,14 +149,14 @@ func TestNewDecoratingFactory_HappyPath_ReturnsFinalSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return finalSess, nil }, ) - got, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + got, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.NoError(t, err) assert.Equal(t, finalSess, got) } diff --git a/pkg/vmcp/session/default_session_test.go b/pkg/vmcp/session/default_session_test.go index a0055f57e9..a2848748e8 100644 --- a/pkg/vmcp/session/default_session_test.go +++ b/pkg/vmcp/session/default_session_test.go @@ -549,7 +549,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) @@ -567,9 +567,9 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - s1, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + s1, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) - s2, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + s2, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) assert.NotEqual(t, s1.ID(), s2.ID()) @@ -582,7 +582,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, nil) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, nil) require.NoError(t, err) require.NotNil(t, sess) @@ -598,7 +598,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { factory := newSessionFactoryWithConnector(successConnector) // Mix of valid and nil entries; nil must not cause a panic. backends := []*vmcp.Backend{nil, backend, nil} - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) require.NotNil(t, sess) @@ -626,7 +626,7 @@ func TestNewSessionFactory_PartialInitialisation(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err, "partial init must not return an error") require.NotNil(t, sess) @@ -685,7 +685,7 @@ func TestNewSessionFactory_ConnectorReturnsNilWithoutError(t *testing.T) { } factory := newSessionFactoryWithConnector(wrappedConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) assert.Empty(t, sess.Tools()) @@ -712,7 +712,7 @@ func TestNewSessionFactory_ConnectorReturnsConnWithError(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "partial failure must not abort the session") require.NotNil(t, sess) assert.Empty(t, sess.Tools()) @@ -741,7 +741,7 @@ func TestNewSessionFactory_CapabilityNameConflictIsResolvedDeterministically(t * } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) require.NotNil(t, sess) defer func() { require.NoError(t, sess.Close()) }() @@ -771,7 +771,7 @@ func TestNewSessionFactory_AllBackendsFail(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "all-fail must still return a valid (empty) session") require.NotNil(t, sess) @@ -795,7 +795,7 @@ func TestNewSessionFactory_BackendInitTimeout(t *testing.T) { } factory := newSessionFactoryWithConnector(connector, WithBackendInitTimeout(50*time.Millisecond)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "timeout is a partial failure, not a hard error") require.NotNil(t, sess) @@ -844,7 +844,7 @@ func TestNewSessionFactory_ParallelInit(t *testing.T) { } factory := newSessionFactoryWithConnector(connector, WithMaxBackendInitConcurrency(3)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) // All backends must have been initialised. @@ -872,46 +872,54 @@ func TestNewSessionFactory_MakeSession_Metadata(t *testing.T) { } tests := []struct { - name string - connector backendConnector - identity *auth.Identity - backends []*vmcp.Backend - wantSubject string // non-empty → assert equal; empty → assert key absent - wantBackendIDs string // always asserted equal (key is always written, "" for zero backends) + name string + connector backendConnector + identity *auth.Identity + backends []*vmcp.Backend + wantIdentityBinding string // expected MetadataKeyIdentityBinding value + wantBackendIDs string // always asserted equal (key is always written, "" for zero backends) }{ { - name: "sets identity subject and backend IDs", - connector: successConnector, - identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user-123"}}, - backends: []*vmcp.Backend{backend1}, - wantSubject: "user-123", - wantBackendIDs: "b1", + // Identity with Subject but no Claims — no extractable (iss, sub) + // pair → binding is the unauthenticated sentinel. + name: "writes unauthenticated sentinel when identity has no iss/sub claims", + connector: successConnector, + identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user-123"}}, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "omits subject when identity is nil", - connector: successConnector, - identity: nil, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "b1", + // Nil identity → unauthenticated sentinel. + name: "writes unauthenticated sentinel when identity is nil", + connector: successConnector, + identity: nil, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "omits subject when subject is empty", - connector: successConnector, - identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: ""}}, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "b1", + // Empty Subject, no Claims → unauthenticated sentinel. + name: "writes unauthenticated sentinel when identity is empty", + connector: successConnector, + identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: ""}}, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "backend IDs are sorted", - connector: successConnector, - backends: []*vmcp.Backend{backend2, backend1}, // intentionally reversed - wantBackendIDs: "b1,b2", + name: "backend IDs are sorted", + connector: successConnector, + backends: []*vmcp.Backend{backend2, backend1}, // intentionally reversed + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1,b2", }, { - name: "writes empty backend IDs when no backends connect", - connector: failConnector, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "", // key present, value empty — explicit zero-backend sentinel + name: "writes empty backend IDs when no backends connect", + connector: failConnector, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "", // key present, value empty — explicit zero-backend sentinel }, } @@ -920,19 +928,17 @@ func TestNewSessionFactory_MakeSession_Metadata(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(tt.connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), tt.identity, true, tt.backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), tt.identity, tt.backends) require.NoError(t, err) require.NotNil(t, sess) defer func() { require.NoError(t, sess.Close()) }() meta := sess.GetMetadata() - if tt.wantSubject != "" { - assert.Equal(t, tt.wantSubject, meta[MetadataKeyIdentitySubject]) - } else { - _, ok := meta[MetadataKeyIdentitySubject] - assert.False(t, ok, "identity subject key should be absent") - } + // MetadataKeyIdentityBinding is always written by BindSession. + bindingVal, bindingPresent := meta[MetadataKeyIdentityBinding] + assert.True(t, bindingPresent, "MetadataKeyIdentityBinding must always be written") + assert.Equal(t, tt.wantIdentityBinding, bindingVal) // MetadataKeyBackendIDs is always written (even "" for zero backends). backendIDsVal, backendIDsPresent := meta[MetadataKeyBackendIDs] @@ -1136,11 +1142,11 @@ func TestMakeSessionWithID_InvalidIDReturnsError(t *testing.T) { return nil, nil, nil }) - _, err := f.MakeSessionWithID(context.Background(), "", nil, true, nil) + _, err := f.MakeSessionWithID(context.Background(), "", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "must not be empty") - _, err = f.MakeSessionWithID(context.Background(), "bad id", nil, true, nil) + _, err = f.MakeSessionWithID(context.Background(), "bad id", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid character") } diff --git a/pkg/vmcp/session/factory.go b/pkg/vmcp/session/factory.go index a5e115eacd..a7ca9803f2 100644 --- a/pkg/vmcp/session/factory.go +++ b/pkg/vmcp/session/factory.go @@ -19,6 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" "github.com/stacklok/toolhive/pkg/vmcp/session/internal/security" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" @@ -28,11 +29,6 @@ const ( defaultMaxBackendInitConcurrency = 10 defaultBackendInitTimeout = 30 * time.Second - // MetadataKeyIdentitySubject is the transport-session metadata key that - // holds the subject claim of the authenticated caller (identity.Subject). - // Set at session creation; empty for anonymous callers. - MetadataKeyIdentitySubject = "vmcp.identity.subject" - // MetadataKeyBackendIDs is the transport-session metadata key that holds // a comma-separated, sorted list of successfully-connected backend IDs. // The key is always written, even as an empty string for zero-backend @@ -46,18 +42,6 @@ const ( MetadataKeyBackendSessionPrefix = "vmcp.backend.session." ) -var ( - // defaultHMACSecret is the fallback HMAC secret used when WithHMACSecret is not provided. - // WARNING: This is INSECURE and should ONLY be used for testing/development. - // Production deployments MUST provide a secure secret via WithHMACSecret option. - // - // NOTE: In multi-replica deployments, all replicas must use the same HMAC secret, - // injected via the VMCP_SESSION_HMAC_SECRET environment variable. If replicas use - // different secrets, cross-pod token validation will silently reject legitimate - // callers. The default insecure secret must NOT be used in production. - defaultHMACSecret = []byte("insecure-default-for-testing-only-change-in-production") -) - // MultiSessionFactory creates new MultiSessions for connecting clients. type MultiSessionFactory interface { // MakeSessionWithID creates a new MultiSession with a specific session ID. @@ -67,9 +51,8 @@ type MultiSessionFactory interface { // The id parameter must be non-empty and should be a valid MCP session ID // (visible ASCII characters, 0x21 to 0x7E per the MCP specification). // - // The allowAnonymous parameter controls whether the session allows nil caller - // identity. If false, all session method calls must provide a valid caller - // that matches the session creator's identity. + // Whether the session allows anonymous (nil) caller identity is derived + // internally from identity via ShouldAllowAnonymous. // // All other behaviour (partial initialisation, bounded concurrency, etc.) // is identical to MakeSession. @@ -77,18 +60,19 @@ type MultiSessionFactory interface { ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) // RestoreSession reconstructs a live MultiSession from persisted metadata. // It reconnects to the backends whose IDs are listed in storedMetadata under // MetadataKeyBackendIDs, rebuilds the routing table, and reapplies the - // hijack-prevention decorator using the stored token hash and salt. + // session-binding decorator from the stored identity binding. // // Use this when the node-local session cache misses — for example after a // pod restart or when a request is routed to a different pod. It is more // expensive than a cache hit because it opens new backend connections. + // Because MCP clients cannot be serialised, sticky sessions (session affinity + // at the load balancer) minimise how often this path is taken. // // allBackends is the current backend list from the registry; RestoreSession // filters it to the subset originally included in this session. @@ -131,7 +115,6 @@ type defaultMultiSessionFactory struct { connector backendConnector maxConcurrency int backendInitTimeout time.Duration - hmacSecret []byte // Server-managed secret for HMAC-SHA256 token hashing aggregator aggregator.Aggregator // Optional: applies tool transforms (overrides, conflict resolution, filter) } @@ -158,27 +141,6 @@ func WithBackendInitTimeout(d time.Duration) MultiSessionFactoryOption { } } -// WithHMACSecret sets the server-managed secret used for HMAC-SHA256 token hashing. -// The secret should be 32+ bytes and loaded from secure configuration (e.g., environment -// variable, secret management system). -// -// The secret is defensively copied to prevent external modification after assignment. -// Empty or nil secrets are rejected (function is a no-op) to prevent accidental security downgrades. -// -// If not set, a default insecure secret is used (NOT RECOMMENDED for production). -func WithHMACSecret(secret []byte) MultiSessionFactoryOption { - return func(f *defaultMultiSessionFactory) { - // Reject empty/nil secrets to prevent silent security downgrade - if len(secret) == 0 { - slog.Warn("WithHMACSecret: empty or nil secret rejected, falling back to default insecure secret", - "recommendation", "provide a secure secret via VMCP_SESSION_HMAC_SECRET environment variable") - return - } - // Make a defensive copy to prevent external modification - f.hmacSecret = append([]byte(nil), secret...) - } -} - // WithAggregator configures the factory to apply per-backend tool overrides, // conflict resolution, and advertising filters when building sessions. // If not set, raw backend tool names are used unchanged. @@ -202,7 +164,6 @@ func newSessionFactoryWithConnector(connector backendConnector, opts ...MultiSes connector: connector, maxConcurrency: defaultMaxBackendInitConcurrency, backendInitTimeout: defaultBackendInitTimeout, - hmacSecret: defaultHMACSecret, // Initialize with default (insecure) secret } for _, opt := range opts { opt(f) @@ -349,29 +310,11 @@ func (f *defaultMultiSessionFactory) MakeSessionWithID( ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) { if err := validateSessionID(id); err != nil { return nil, err } - - // Validate allowAnonymous is consistent with identity to prevent security footguns. - // If identity has a token, allowAnonymous must be false (caller wants a bound session). - // If identity is nil or has no token, allowAnonymous should be true (anonymous session). - if identity != nil && identity.Token != "" && allowAnonymous { - return nil, fmt.Errorf( - "invalid session configuration: cannot create anonymous session " + - "(allowAnonymous=true) with bearer token (identity.Token is non-empty)", - ) - } - if (identity == nil || identity.Token == "") && !allowAnonymous { - return nil, fmt.Errorf( - "invalid session configuration: cannot create bound session " + - "(allowAnonymous=false) without bearer token (identity is nil or has empty token)", - ) - } - return f.makeSession(ctx, id, identity, backends) } @@ -406,17 +349,15 @@ func populateBackendMetadata(transportSess transportsession.Session, results []i transportSess.SetMetadata(MetadataKeyBackendSessionPrefix+r.target.WorkloadID, sessID) } } - // Always write MetadataKeyBackendIDs, even for zero-backend sessions (""). - // This distinguishes an explicit zero-backend state from absent/corrupted metadata - // in RestoreSession, preventing filterBackendsByStoredIDs from silently - // falling back to all backends when the key is missing. + // Always write MetadataKeyBackendIDs — key presence distinguishes explicit + // zero-backend from absent/corrupted metadata (see const doc). transportSess.SetMetadata(MetadataKeyBackendIDs, strings.Join(ids, ",")) } // makeBaseSession initialises backends and assembles a defaultMultiSession -// WITHOUT applying the hijack-prevention security wrapper. +// WITHOUT applying the session-binding security wrapper. // Callers are responsible for wrapping the result with the appropriate decorator -// (PreventSessionHijacking for new sessions, RestoreHijackPrevention for restored ones). +// (BindSession for new sessions, RestoreSessionBinding for restored ones). func (f *defaultMultiSessionFactory) makeBaseSession( ctx context.Context, sessID string, @@ -488,9 +429,6 @@ func (f *defaultMultiSessionFactory) makeBaseSession( } transportSess := transportsession.NewStreamableSession(sessID) - if identity != nil && identity.Subject != "" { - transportSess.SetMetadata(MetadataKeyIdentitySubject, identity.Subject) - } populateBackendMetadata(transportSess, results) return &defaultMultiSession{ @@ -507,7 +445,7 @@ func (f *defaultMultiSessionFactory) makeBaseSession( } // makeSession is the shared implementation for MakeSession and MakeSessionWithID. -// It builds the base session via makeBaseSession, then applies the hijack-prevention +// It builds the base session via makeBaseSession, then applies the session-binding // security wrapper using the caller's identity. func (f *defaultMultiSessionFactory) makeSession( ctx context.Context, @@ -520,9 +458,10 @@ func (f *defaultMultiSessionFactory) makeSession( return nil, err } - // Apply hijack prevention: computes token binding, stores metadata, and wraps - // the session with validation logic. - decorated, err := security.PreventSessionHijacking(baseSession, f.hmacSecret, identity) + // Apply session binding: extracts the (iss, sub) identity tuple, stores it in + // session metadata under MetadataKeyIdentityBinding, and wraps the session with + // validation logic that checks every subsequent caller against that binding. + decorated, err := security.BindSession(baseSession, identity) if err != nil { _ = baseSession.Close() return nil, err @@ -532,8 +471,10 @@ func (f *defaultMultiSessionFactory) makeSession( // RestoreSession implements MultiSessionFactory. // It reconnects to the backends whose IDs are listed in storedMetadata, rebuilds -// the routing table, and reapplies the hijack-prevention decorator from the stored -// token hash and salt — without recomputing them from a (unavailable) token. +// the routing table, and reapplies the session-binding decorator from the stored +// identity binding. Because the original bearer token is not available at restore +// time, identity is reconstructed from the (iss, sub) tuple in +// MetadataKeyIdentityBinding. func (f *defaultMultiSessionFactory) RestoreSession( ctx context.Context, id string, @@ -557,13 +498,44 @@ func (f *defaultMultiSessionFactory) RestoreSession( // Filter allBackends to the subset originally connected in this session. filteredBackends := filterBackendsByStoredIDs(allBackends, storedBackendIDs) - // Reconstruct a minimal identity from stored metadata. The original bearer - // token is never persisted (only its HMAC-SHA256 hash is), so Token is empty. - // The security decorator is restored from the stored hash/salt below. + // Reconstruct identity from the stored identity binding so makeBaseSession + // can pass it to backend connectors (e.g. for outgoing auth injection). + // The original bearer token is not available at restore time — only the + // (iss, sub) tuple stored in MetadataKeyIdentityBinding is used. + storedBinding, hasBinding := storedMetadata[sessiontypes.MetadataKeyIdentityBinding] + if !hasBinding { + // Legacy token-hash key present confirms not corrupted — safe to invalidate. + if _, hasLegacy := storedMetadata[sessiontypes.MetadataKeyTokenHash]; hasLegacy { + slog.Warn("RestoreSession: legacy session missing identity binding; invalidating", + "reason", "legacy_session_missing_identity_binding", + ) + return nil, transportsession.ErrSessionNotFound + } + return nil, fmt.Errorf("RestoreSession: %q metadata key absent (corrupted session metadata)", + sessiontypes.MetadataKeyIdentityBinding) + } + + // Restore identity from the stored (iss, sub) binding. Token is intentionally + // empty — the original bearer token is not persisted. Outgoing-auth strategies + // must derive credentials from Claims or a token store keyed by tsid; a + // strategy that reads identity.Token will reject the request with an + // identity-has-no-token error. The anonymous sentinel (identity == nil) is + // handled by the IsUnauthenticated guard below. var identity *auth.Identity - if subject := storedMetadata[MetadataKeyIdentitySubject]; subject != "" { - identity = &auth.Identity{} - identity.Subject = subject + if !binding.IsUnauthenticated(storedBinding) { + iss, sub, ok := binding.Parse(storedBinding) + if !ok { + return nil, fmt.Errorf("RestoreSession: stored identity binding is malformed: %q", storedBinding) + } + identity = &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: sub, + Claims: map[string]any{ + "iss": iss, + "sub": sub, + }, + }, + } } // Extract stored per-backend session IDs as hints so each backend can @@ -576,45 +548,23 @@ func (f *defaultMultiSessionFactory) RestoreSession( } // Build the base session (backend connections + routing table) without the - // security wrapper. The wrapper is applied separately using stored hash/salt. + // security wrapper. The wrapper is applied separately below. baseSession, err := f.makeBaseSession(ctx, id, identity, filteredBackends, sessionHints) if err != nil { return nil, fmt.Errorf("RestoreSession: failed to rebuild backend connections: %w", err) } - // Restore only the security keys (token hash and salt) from stored metadata. - // MetadataKeyIdentitySubject is already set by makeBaseSession via the - // reconstructed identity. MetadataKeyBackendIDs and the per-backend session - // keys (MetadataKeyBackendSessionPrefix.*) are freshly computed by - // makeBaseSession from the actual reconnected backends; overwriting them with - // stored values would make metadata inconsistent if any backend failed to - // reconnect during restore. - for _, key := range []string{ - sessiontypes.MetadataKeyTokenHash, - sessiontypes.MetadataKeyTokenSalt, - } { - if v, ok := storedMetadata[key]; ok { - baseSession.SetMetadata(key, v) - } - } + // Restore only the identity-binding key from stored metadata. The other + // keys (MetadataKeyBackendIDs, MetadataKeyBackendSessionPrefix.*) are + // freshly computed by makeBaseSession from the actual reconnected backends; + // overwriting them with stored values would make metadata inconsistent if + // any backend failed to reconnect during restore. + baseSession.SetMetadata(sessiontypes.MetadataKeyIdentityBinding, storedBinding) - // Recreate the hijack-prevention decorator using the stored hash and salt, - // not by recomputing from identity.Token (which is unavailable at restore time). - // - // Fail closed if the token-hash key is entirely absent from stored metadata: - // PreventSessionHijacking always writes the key (empty string for anonymous, - // non-empty for authenticated), so an absent key indicates corrupted or - // truncated metadata — not a legitimately anonymous session. - storedHash, hashKeyPresent := storedMetadata[sessiontypes.MetadataKeyTokenHash] - if !hashKeyPresent { - _ = baseSession.Close() - return nil, fmt.Errorf("RestoreSession: token hash metadata key absent (corrupted session metadata)") - } - storedSalt := storedMetadata[sessiontypes.MetadataKeyTokenSalt] - restored, err := security.RestoreHijackPrevention(baseSession, storedHash, storedSalt, f.hmacSecret) + restored, err := security.RestoreSessionBinding(baseSession, storedBinding) if err != nil { _ = baseSession.Close() - return nil, fmt.Errorf("RestoreSession: failed to restore hijack prevention: %w", err) + return nil, fmt.Errorf("RestoreSession: failed to restore session binding: %w", err) } return restored, nil } diff --git a/pkg/vmcp/session/factory_metadata_test.go b/pkg/vmcp/session/factory_metadata_test.go index 81f1ed50fc..4d26808cda 100644 --- a/pkg/vmcp/session/factory_metadata_test.go +++ b/pkg/vmcp/session/factory_metadata_test.go @@ -41,7 +41,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { {ID: "backend-a"}, {ID: "backend-b"}, } - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, backends) require.NoError(t, err) meta := sess.GetMetadata() @@ -56,7 +56,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, nil) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, nil) require.NoError(t, err) meta := sess.GetMetadata() @@ -85,7 +85,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { {ID: "backend-ok"}, {ID: "backend-fail"}, } - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, backends) require.NoError(t, err) meta := sess.GetMetadata() @@ -123,7 +123,7 @@ func TestRestoreSession_FreshlyPopulatesMetadataKeyBackendIDs(t *testing.T) { sessionID := "restore-test-session" // Create the initial session so we have a real token hash in metadata. - original, err := factory.MakeSessionWithID(t.Context(), sessionID, nil, true, backends) + original, err := factory.MakeSessionWithID(t.Context(), sessionID, nil, backends) require.NoError(t, err) t.Cleanup(func() { _ = original.Close() }) @@ -195,7 +195,7 @@ func TestRestoreSession_PassesStoredSessionHintToConnector(t *testing.T) { } // Create the original session — connector receives empty hints. - original, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + original, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, backends) require.NoError(t, err) t.Cleanup(func() { _ = original.Close() }) @@ -241,7 +241,7 @@ func TestMakeSession_PassesEmptySessionHintToConnector(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, []*vmcp.Backend{{ID: "backend-a"}}) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, []*vmcp.Backend{{ID: "backend-a"}}) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) diff --git a/pkg/vmcp/session/identity_binding_test.go b/pkg/vmcp/session/identity_binding_test.go new file mode 100644 index 0000000000..0363e0bf59 --- /dev/null +++ b/pkg/vmcp/session/identity_binding_test.go @@ -0,0 +1,270 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "errors" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" + internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +// nilBackendConnector is a connector that returns (nil, nil, nil), causing the +// backend to be skipped during init. This lets us exercise session-metadata +// logic without real backend connections. +func nilBackendConnector() backendConnector { + return func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } +} + +// identityWithClaims builds an *auth.Identity whose Claims map is set verbatim +// from claims. Used in tests that need specific claim values without setting +// the Subject field (binding extraction reads only Claims["iss"] and Claims["sub"]). +func identityWithClaims(token string, claims map[string]any) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: claims}, + Token: token, + } +} + +func TestMakeSession_StoresIdentityBinding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identity *auth.Identity + wantBinding string + }{ + { + name: "authenticated_oidc", + identity: identityWithClaims("bearer-token", map[string]any{"iss": "https://idp.example", "sub": "user-42"}), + wantBinding: "https://idp.example\x00user-42", + }, + { + name: "nil_identity_anonymous", + identity: nil, + wantBinding: binding.UnauthenticatedSentinel, + }, + { + // LocalUserMiddleware sets Token="" and populates Claims with + // iss="toolhive-local" and sub=. + name: "local_user_shape", + identity: identityWithClaims("", map[string]any{"iss": "toolhive-local", "sub": "alice"}), + wantBinding: "toolhive-local\x00alice", + }, + { + // AnonymousMiddleware (dev-only) sets Token="" with iss="toolhive-local" + // and sub="anonymous". All such sessions share one binding — intentional. + name: "anonymous_middleware_shape", + identity: identityWithClaims("", map[string]any{"iss": "toolhive-local", "sub": "anonymous"}), + wantBinding: "toolhive-local\x00anonymous", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + sess, err := factory.MakeSessionWithID( + t.Context(), uuid.New().String(), tt.identity, nil, + ) + require.NoError(t, err) + require.NotNil(t, sess) + t.Cleanup(func() { _ = sess.Close() }) + + meta := sess.GetMetadata() + assert.Equal(t, tt.wantBinding, meta[MetadataKeyIdentityBinding], + "MetadataKeyIdentityBinding must equal expected binding") + }) + } +} + +// TestMakeSession_RejectsBoundSessionWithoutIdentifyingClaims verifies the +// ordering invariant from the factory through to BindSession: creating a session +// with an identity that carries no valid (iss, sub) pair returns an error from +// MakeSessionWithID. +func TestMakeSession_RejectsBoundSessionWithoutIdentifyingClaims(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + + // Token is present but Claims are empty, so BindSession's extractBindingID fails. + identity := identityWithClaims("x", map[string]any{}) + + _, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, nil) + require.Error(t, err, "session creation must fail when bound identity lacks identifying claims") +} + +func TestRestoreSession_ErrorCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storedMetadata map[string]string + wantNotFoundErr bool // true → must be ErrSessionNotFound; false → must NOT be + wantErrContains string + }{ + { + // A session carrying the legacy MetadataKeyTokenHash but no + // MetadataKeyIdentityBinding is invalidated with ErrSessionNotFound + // so the MCP client can re-initialize. + name: "legacy_token_hash_only", + storedMetadata: map[string]string{ + MetadataKeyBackendIDs: "", + sessiontypes.MetadataKeyTokenHash: "deadbeefdeadbeef", + }, + wantNotFoundErr: true, + }, + { + // Genuinely corrupted metadata (no binding, no legacy key) must + // NOT masquerade as a session-not-found; it is a distinct error. + name: "absent_identity_binding_key", + storedMetadata: map[string]string{ + MetadataKeyBackendIDs: "", + }, + wantNotFoundErr: false, + wantErrContains: "absent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + _, err := factory.RestoreSession(t.Context(), uuid.New().String(), tt.storedMetadata, nil) + require.Error(t, err) + + if tt.wantNotFoundErr { + require.True(t, errors.Is(err, transportsession.ErrSessionNotFound), + "legacy token-hash-only session must return ErrSessionNotFound") + } else { + assert.False(t, errors.Is(err, transportsession.ErrSessionNotFound), + "corrupted metadata must not return ErrSessionNotFound") + } + if tt.wantErrContains != "" { + assert.Contains(t, err.Error(), tt.wantErrContains) + } + }) + } +} + +// TestRestoreSession_PopulatesBothSubjectFieldAndClaims verifies that after +// RestoreSession the reconstructed identity's binding is stored and the +// decorator accepts a matching caller. +func TestRestoreSession_PopulatesBothSubjectFieldAndClaims(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + + const storedBinding = "https://idp.example\x00alice" + storedMetadata := map[string]string{ + MetadataKeyBackendIDs: "", + sessiontypes.MetadataKeyIdentityBinding: storedBinding, + } + + sess, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) + require.NoError(t, err) + require.NotNil(t, sess) + t.Cleanup(func() { _ = sess.Close() }) + + meta := sess.GetMetadata() + assert.Equal(t, storedBinding, meta[sessiontypes.MetadataKeyIdentityBinding]) + + // Call a tool with the expected identity to verify the decorator accepts it. + caller := identityWithClaims("any-token", map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }) + _, err = sess.CallTool(t.Context(), caller, "nonexistent", nil, nil) + // ErrToolNotFound is acceptable — it means auth passed. + if err != nil { + assert.ErrorIs(t, err, ErrToolNotFound, + "restored session must accept matching caller; any error must be ErrToolNotFound (not auth)") + } +} + +// TestRestoreSession_ReconstructsIdentityWithEmptyTokenButPopulatedClaims pins +// the contract that RestoreSession passes a reconstructed identity to the +// backendConnector whose Token field is intentionally empty (the bearer token is +// not persisted), but whose Claims["iss"] and Claims["sub"] are populated from +// the stored identity binding. +func TestRestoreSession_ReconstructsIdentityWithEmptyTokenButPopulatedClaims(t *testing.T) { + t.Parallel() + + const ( + origIss = "https://idp.example" + origSub = "carol" + ) + + var capturedIdentity *auth.Identity + capturingConnector := func( + _ context.Context, + _ *vmcp.BackendTarget, + id *auth.Identity, + _ string, + ) (internalbk.Session, *vmcp.CapabilityList, error) { + capturedIdentity = id + return nil, nil, nil + } + + // Step 1: create the original session with an authenticated identity. + originalIdentity := identityWithClaims("bearer-AT1", map[string]any{ + "iss": origIss, + "sub": origSub, + }) + + factory := newSessionFactoryWithConnector(capturingConnector) + multiSess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), originalIdentity, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = multiSess.Close() }) + + // Step 2: capture the persisted metadata (simulates what Redis would hold). + meta := multiSess.GetMetadata() + require.NotEmpty(t, meta[sessiontypes.MetadataKeyIdentityBinding], + "factory must write MetadataKeyIdentityBinding to metadata") + + capturedIdentity = nil + + // Step 3: restore the session on "Pod B" with a backend present so the + // connector is actually invoked. + backend := &vmcp.Backend{ + ID: "test-backend", + Name: "test-backend", + } + storedMeta := make(map[string]string, len(meta)+1) + for k, v := range meta { + storedMeta[k] = v + } + storedMeta[MetadataKeyBackendIDs] = backend.ID + + restored, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMeta, []*vmcp.Backend{backend}) + require.NoError(t, err) + t.Cleanup(func() { _ = restored.Close() }) + + // Step 4: the connector must receive an identity with an empty Token but + // populated Claims. Any outgoing-auth strategy that reads identity.Token + // will silently produce unauthenticated backend requests after a pod restart. + require.NotNil(t, capturedIdentity, "connector must be called with a non-nil identity for an authenticated session") + assert.Empty(t, capturedIdentity.Token, + "restored identity.Token must be empty — bearer token is not persisted across pod restarts") + assert.Equal(t, origIss, capturedIdentity.Claims["iss"], + "restored identity.Claims[iss] must match original issuer") + assert.Equal(t, origSub, capturedIdentity.Claims["sub"], + "restored identity.Claims[sub] must match original subject") + assert.Equal(t, origSub, capturedIdentity.Subject, + "restored identity.Subject must match original subject") +} diff --git a/pkg/vmcp/session/internal/security/hijack_prevention_test.go b/pkg/vmcp/session/internal/security/hijack_prevention_test.go index d0746ee030..4f5ad9b67c 100644 --- a/pkg/vmcp/session/internal/security/hijack_prevention_test.go +++ b/pkg/vmcp/session/internal/security/hijack_prevention_test.go @@ -5,22 +5,19 @@ package security import ( "context" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -var ( - // Test HMAC secret and salt for consistent test results - testSecret = []byte("test-secret") - testTokenSalt = []byte("test-salt-123456") // 16 bytes -) - // mockSession is a minimal implementation of MultiSession for testing. // It embeds the interface so only the methods exercised by tests need to be defined. type mockSession struct { @@ -56,204 +53,241 @@ func (*mockSession) GetPrompt(_ context.Context, _ *auth.Identity, _ string, _ m func (*mockSession) Close() error { return nil } -// TestValidateCaller_EdgeCases tests edge cases in caller validation logic. -func TestValidateCaller_EdgeCases(t *testing.T) { +// newDecoratedSession creates a mockSession wrapped with BindSession using the given identity. +func newDecoratedSession(t *testing.T, identity *auth.Identity) sessiontypes.MultiSession { + t.Helper() + base := newMockSession("test-session") + decorated, err := BindSession(base, identity) + require.NoError(t, err) + require.NotNil(t, decorated) + return decorated +} + +// authedIdentity builds an authenticated identity with the given issuer and subject. +// token is used as the raw bearer token (may be any non-empty string). +func authedIdentity(token, iss, sub string) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{ + "iss": iss, + "sub": sub, + }, + }, + Token: token, + } +} + +// identityWithClaims builds an *auth.Identity whose Claims map is set verbatim +// from claims. Used in tests that need malformed claim values (missing keys, +// non-string values, NUL bytes) that authedIdentity cannot express. +func identityWithClaims(token string, claims map[string]any) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: claims}, + Token: token, + } +} + +// TestBindSession_AcceptsRefreshedTokenWithSameIdentity is the regression test +// for issue #5306. A caller presenting a new bearer token (refreshed access +// token) but the same (iss, sub) identity must be accepted because validation +// is now identity-based, not token-hash-based. +func TestBindSession_AcceptsRefreshedTokenWithSameIdentity(t *testing.T) { + t.Parallel() + + const iss = "https://idp.example" + const sub = "user-42" + + // Session created with AT1. + creator := authedIdentity("AT1", iss, sub) + decorated := newDecoratedSession(t, creator) + + // Subsequent call arrives with AT2 (refreshed token) but same (iss, sub). + refreshed := authedIdentity("AT2", iss, sub) + _, err := decorated.CallTool(context.Background(), refreshed, "tool", nil, nil) + require.NoError(t, err, "refreshed token with same identity must be accepted") +} + +func TestBindSession_RejectsInvalidIdentityAtCreation(t *testing.T) { t.Parallel() tests := []struct { - name string - allowAnonymous bool - boundTokenHash string - caller *auth.Identity - wantErr error + name string + identity *auth.Identity + }{ + {name: "missing_sub_claim", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example"})}, + {name: "missing_iss_claim", identity: identityWithClaims("tok", map[string]any{"sub": "alice"})}, + {name: "both_claims_empty", identity: identityWithClaims("tok", map[string]any{})}, + {name: "non_string_iss", identity: identityWithClaims("tok", map[string]any{"iss": 42, "sub": "alice"})}, + {name: "non_string_sub", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example", "sub": true})}, + {name: "nul_byte_in_iss", identity: identityWithClaims("tok", map[string]any{"iss": "bad\x00iss", "sub": "alice"})}, + {name: "nul_byte_in_sub", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example", "sub": "bad\x00sub"})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + base := &metadataObservingSession{mockSession: newMockSession("test-session")} + decorated, err := BindSession(base, tt.identity) + + require.Error(t, err) + assert.Nil(t, decorated) + assert.False(t, base.setMetadataCalled, "SetMetadata must not be called if binding extraction fails") + }) + } +} + +func TestBindSession_RejectsMismatchedCaller(t *testing.T) { + t.Parallel() + + const boundIss = "https://idp.example" + const boundSub = "alice" + + tests := []struct { + name string + caller *auth.Identity + }{ + {name: "different_sub", caller: authedIdentity("tok2", boundIss, "bob")}, + {name: "different_iss", caller: authedIdentity("tok2", "https://idp-b.example", boundSub)}, + {name: "missing_iss_claim", caller: identityWithClaims("tok2", map[string]any{"sub": boundSub})}, + {name: "missing_sub_claim", caller: identityWithClaims("tok2", map[string]any{"iss": boundIss})}, + {name: "both_claims_empty", caller: identityWithClaims("tok2", map[string]any{})}, + {name: "non_string_iss_claim", caller: identityWithClaims("tok2", map[string]any{"iss": []string{boundIss}, "sub": boundSub})}, + {name: "non_string_sub_claim", caller: identityWithClaims("tok2", map[string]any{"iss": boundIss, "sub": 12345})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + decorated := newDecoratedSession(t, authedIdentity("tok", boundIss, boundSub)) + _, err := decorated.CallTool(context.Background(), tt.caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }) + } +} + +func TestBindSession_AnonymousSession(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + assertFn func(t *testing.T, decorated sessiontypes.MultiSession) }{ { - name: "anonymous session with nil caller", - allowAnonymous: true, - boundTokenHash: "", - caller: nil, - wantErr: nil, // Should succeed - }, - { - name: "anonymous session rejects caller with token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, // Prevent session upgrade attack - }, - { - name: "bound session with nil caller", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: nil, - wantErr: sessiontypes.ErrNilCaller, - }, - { - name: "bound session with matching token", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "correct-token"}, - wantErr: nil, // Should succeed - }, - { - name: "bound session with wrong token", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "wrong-token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, - }, - { - name: "bound session with empty token in identity", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: sessiontypes.ErrUnauthorizedCaller, + name: "nil_identity_stores_sentinel", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + meta := decorated.GetMetadata() + assert.Equal(t, binding.UnauthenticatedSentinel, meta[sessiontypes.MetadataKeyIdentityBinding], + "anonymous session must store UnauthenticatedSentinel in metadata") + }, }, { - name: "anonymous session accepts caller with empty token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: nil, // Empty token is equivalent to no token + name: "rejects_caller_presenting_token", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + caller := &auth.Identity{Token: "some-token"} + _, err := decorated.CallTool(context.Background(), caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }, }, { - name: "misconfigured bound session with empty hash rejects empty token", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: sessiontypes.ErrSessionOwnerUnknown, // Fail closed + name: "accepts_nil_caller", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + _, err := decorated.CallTool(context.Background(), nil, "tool", nil, nil) + require.NoError(t, err) + }, }, { - name: "misconfigured bound session with empty hash rejects nil caller", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: nil, - wantErr: sessiontypes.ErrNilCaller, // Nil check happens first + name: "accepts_caller_with_empty_token", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + caller := &auth.Identity{Token: ""} + _, err := decorated.CallTool(context.Background(), caller, "tool", nil, nil) + require.NoError(t, err) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - - // Create a base session - baseSession := newMockSession("test-session") - - // Wrap with decorator that has the test configuration - decorator := &hijackPreventionDecorator{ - MultiSession: baseSession, - allowAnonymous: tt.allowAnonymous, - boundTokenHash: tt.boundTokenHash, - tokenSalt: testTokenSalt, - hmacSecret: testSecret, - } - - ctx := context.Background() - - // Test all three decorated methods to verify validation is integrated correctly - toolResult, errCallTool := decorator.CallTool(ctx, tt.caller, "test-tool", nil, nil) - resourceResult, errReadResource := decorator.ReadResource(ctx, tt.caller, "test://uri") - promptResult, errGetPrompt := decorator.GetPrompt(ctx, tt.caller, "test-prompt", nil) - - if tt.wantErr != nil { - require.ErrorIs(t, errCallTool, tt.wantErr) - require.ErrorIs(t, errReadResource, tt.wantErr) - require.ErrorIs(t, errGetPrompt, tt.wantErr) - assert.Nil(t, toolResult) - assert.Nil(t, resourceResult) - assert.Nil(t, promptResult) - } else { - require.NoError(t, errCallTool) - require.NoError(t, errReadResource) - require.NoError(t, errGetPrompt) - assert.NotNil(t, toolResult) - assert.NotNil(t, resourceResult) - assert.NotNil(t, promptResult) - } + base := newMockSession("test-session") + decorated, err := BindSession(base, nil) + require.NoError(t, err) + require.NotNil(t, decorated) + tt.assertFn(t, decorated) }) } } -// TestPreventSessionHijacking_NilSession tests that a nil session is rejected before any method call. -func TestPreventSessionHijacking_NilSession(t *testing.T) { +func TestBindSession_NilSession(t *testing.T) { t.Parallel() - decorated, err := PreventSessionHijacking(nil, testSecret, &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"}) + identity := authedIdentity("tok", "https://idp.example", "alice") + decorated, err := BindSession(nil, identity) require.Error(t, err) assert.Nil(t, decorated) } -// TestPreventSessionHijacking_BasicFunctionality tests the main entry point. -func TestPreventSessionHijacking_BasicFunctionality(t *testing.T) { +func TestBindSession_BoundRejectsNilCaller(t *testing.T) { t.Parallel() - t.Run("authenticated session", func(t *testing.T) { - t.Parallel() - - baseSession := newMockSession("test-session") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - - decorated, err := PreventSessionHijacking(baseSession, testSecret, identity) - require.NoError(t, err) - require.NotNil(t, decorated) - - // Verify metadata was set (no cast needed - returns concrete type) - metadata := decorated.GetMetadata() - assert.NotEmpty(t, metadata[metadataKeyTokenHash]) - assert.NotEmpty(t, metadata[metadataKeyTokenSalt]) - }) + decorated := newDecoratedSession(t, authedIdentity("tok", "https://idp.example", "alice")) - t.Run("anonymous session", func(t *testing.T) { - t.Parallel() - - baseSession := newMockSession("test-session") - - decorated, err := PreventSessionHijacking(baseSession, testSecret, nil) - require.NoError(t, err) - require.NotNil(t, decorated) - - // Verify metadata was set (empty for anonymous, no cast needed) - metadata := decorated.GetMetadata() - assert.Empty(t, metadata[metadataKeyTokenHash]) - assert.Empty(t, metadata[metadataKeyTokenSalt]) - }) + _, err := decorated.CallTool(context.Background(), nil, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrNilCaller) } -// TestRestoreHijackPrevention tests restoration of the hijack-prevention decorator. -func TestRestoreHijackPrevention(t *testing.T) { +// TestBindSession_ConcurrentRefreshRace verifies that two goroutines calling +// CallTool concurrently with different bearer tokens but the same (iss, sub) +// both succeed. This tests the core fix for issue #5306. +func TestBindSession_ConcurrentRefreshRace(t *testing.T) { t.Parallel() - t.Run("anonymous session (empty hash and salt)", func(t *testing.T) { - t.Parallel() + const iss = "https://idp.example" + const sub = "alice" - base := newMockSession("s1") - restored, err := RestoreHijackPrevention(base, "", "", testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - }) + decorated := newDecoratedSession(t, authedIdentity("AT1", iss, sub)) - t.Run("hash present but salt absent is rejected", func(t *testing.T) { - t.Parallel() + const goroutines = 20 - base := newMockSession("s2") - _, err := RestoreHijackPrevention(base, "somehash", "", testSecret) - require.Error(t, err) - assert.Contains(t, err.Error(), "salt is missing") - }) + errs := make([]error, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + for i := range goroutines { + go func(i int) { + defer wg.Done() + // Each goroutine uses a distinct token string but the same identity. + caller := authedIdentity("refreshed-token-"+string(rune('A'+i%26)), iss, sub) + _, errs[i] = decorated.CallTool(context.Background(), caller, "tool", nil, nil) + }(i) + } - t.Run("salt present but hash absent is rejected", func(t *testing.T) { - t.Parallel() + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for concurrent CallTool goroutines") + } - base := newMockSession("s3") - _, err := RestoreHijackPrevention(base, "", "deadbeef", testSecret) - require.Error(t, err) - assert.Contains(t, err.Error(), "hash is missing") - }) + for i, err := range errs { + assert.NoError(t, err, "goroutine %d must succeed with same identity and different token", i) + } +} - t.Run("nil session is rejected", func(t *testing.T) { - t.Parallel() +// metadataObservingSession wraps mockSession and records whether SetMetadata +// was ever called. Used in tests that assert SetMetadata is NOT called before +// a binding is validated. +type metadataObservingSession struct { + *mockSession + setMetadataCalled bool +} - _, err := RestoreHijackPrevention(nil, "", "", testSecret) - require.Error(t, err) - }) +func (m *metadataObservingSession) SetMetadata(key, value string) { + m.setMetadataCalled = true + m.mockSession.SetMetadata(key, value) } diff --git a/pkg/vmcp/session/internal/security/restore_test.go b/pkg/vmcp/session/internal/security/restore_test.go index 03de0f5a62..c9f98854ea 100644 --- a/pkg/vmcp/session/internal/security/restore_test.go +++ b/pkg/vmcp/session/internal/security/restore_test.go @@ -5,126 +5,99 @@ package security import ( "context" - "encoding/hex" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -func TestRestoreHijackPrevention_NilSession(t *testing.T) { +func TestRestoreSessionBinding(t *testing.T) { t.Parallel() - restored, err := RestoreHijackPrevention(nil, "somehash", hex.EncodeToString(testTokenSalt), testSecret) - require.Error(t, err) - assert.Nil(t, restored) + tests := []struct { + name string + storedBinding string + wantErr bool + // checkFn is called with the restored session when wantErr is false. + // Leave nil to skip behavioral assertions. + checkFn func(t *testing.T, restored sessiontypes.MultiSession) + }{ + { + name: "bound_round_trip_accepts_matching_caller", + storedBinding: "https://idp.example\x00alice", + checkFn: func(t *testing.T, restored sessiontypes.MultiSession) { + t.Helper() + matchingCaller := identityWithClaims("some-token", map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }) + _, err := restored.CallTool(context.Background(), matchingCaller, "tool", nil, nil) + require.NoError(t, err, "matching (iss, sub) caller must be accepted after restore") + + intruder := identityWithClaims("other-token", map[string]any{ + "iss": "https://idp.example", + "sub": "bob", + }) + _, err = restored.CallTool(context.Background(), intruder, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }, + }, + { + name: "unauthenticated_sentinel_accepts_nil_rejects_token", + storedBinding: binding.UnauthenticatedSentinel, + checkFn: func(t *testing.T, restored sessiontypes.MultiSession) { + t.Helper() + _, err := restored.CallTool(context.Background(), nil, "tool", nil, nil) + require.NoError(t, err, "nil caller must be accepted for anonymous sessions") + + caller := &auth.Identity{Token: "some-token"} + _, err = restored.CallTool(context.Background(), caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, + "caller presenting a token must be rejected (session-upgrade attack prevention)") + }, + }, + { + name: "corrupted_binding_no_nul", + storedBinding: "no-nul-here", + wantErr: true, + }, + { + name: "empty_string_rejected", + storedBinding: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + base := newMockSession("sess") + restored, err := RestoreSessionBinding(base, tt.storedBinding) + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, restored) + return + } + + require.NoError(t, err) + require.NotNil(t, restored) + if tt.checkFn != nil { + tt.checkFn(t, restored) + } + }) + } } -func TestRestoreHijackPrevention_MissingSalt(t *testing.T) { +func TestRestoreSessionBinding_NilSession(t *testing.T) { t.Parallel() - // Non-empty tokenHash with empty tokenSaltHex is malformed state. - base := newMockSession("sess") - restored, err := RestoreHijackPrevention(base, "nonemptyhash", "", testSecret) + restored, err := RestoreSessionBinding(nil, binding.UnauthenticatedSentinel) require.Error(t, err) assert.Nil(t, restored) } - -func TestRestoreHijackPrevention_InvalidSaltHex(t *testing.T) { - t.Parallel() - - base := newMockSession("sess") - restored, err := RestoreHijackPrevention(base, "nonemptyhash", "gg", testSecret) - require.Error(t, err) - assert.Nil(t, restored) -} - -func TestRestoreHijackPrevention_AnonymousSession(t *testing.T) { - t.Parallel() - - base := newMockSession("sess") - // tokenHash="" and tokenSaltHex="" → anonymous. - restored, err := RestoreHijackPrevention(base, "", "", testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - - ctx := context.Background() - - // Nil caller is accepted. - _, err = restored.CallTool(ctx, nil, "tool", nil, nil) - require.NoError(t, err) - - // Caller presenting a token is rejected (session upgrade attack prevention). - caller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "u"}, Token: "t"} - _, err = restored.CallTool(ctx, caller, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) -} - -func TestRestoreHijackPrevention_AuthenticatedRoundTrip(t *testing.T) { - t.Parallel() - - // --- "Pod A": create a session, persist hash+salt from metadata. --- - base := newMockSession("sess") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "bearer-token"} - - created, err := PreventSessionHijacking(base, testSecret, identity) - require.NoError(t, err) - - meta := created.GetMetadata() - persistedHash := meta[metadataKeyTokenHash] - persistedSalt := meta[metadataKeyTokenSalt] - require.NotEmpty(t, persistedHash, "tokenHash must be persisted") - require.NotEmpty(t, persistedSalt, "tokenSalt must be persisted") - - // --- "Pod B": restore decorator from persisted values. --- - base2 := newMockSession("sess") - restored, err := RestoreHijackPrevention(base2, persistedHash, persistedSalt, testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - - ctx := context.Background() - - // Original token is accepted. - _, err = restored.CallTool(ctx, identity, "tool", nil, nil) - require.NoError(t, err) - - // A different token is rejected. - other := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "wrong-token"} - _, err = restored.CallTool(ctx, other, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) - - // Nil caller is rejected for a bound session. - _, err = restored.CallTool(ctx, nil, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrNilCaller) -} - -func TestRestoreHijackPrevention_CrossReplicaSecretMismatch(t *testing.T) { - t.Parallel() - - // Pod A creates with secretA. - secretA := []byte("secret-A") - secretB := []byte("secret-B") - - base := newMockSession("sess") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "token"} - - created, err := PreventSessionHijacking(base, secretA, identity) - require.NoError(t, err) - - meta := created.GetMetadata() - persistedHash := meta[metadataKeyTokenHash] - persistedSalt := meta[metadataKeyTokenSalt] - - // Pod B restores with a different secretB — token validation must fail. - base2 := newMockSession("sess") - restored, err := RestoreHijackPrevention(base2, persistedHash, persistedSalt, secretB) - require.NoError(t, err) // Construction succeeds; mismatch only shows at validation time. - - ctx := context.Background() - _, err = restored.CallTool(ctx, identity, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, - "cross-replica secret mismatch must reject the original token") -} diff --git a/pkg/vmcp/session/internal/security/security.go b/pkg/vmcp/session/internal/security/security.go index 358c4a0bba..881e172a0d 100644 --- a/pkg/vmcp/session/internal/security/security.go +++ b/pkg/vmcp/session/internal/security/security.go @@ -1,117 +1,119 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -// Package security provides cryptographic utilities for session token binding -// and hijacking prevention. It handles HMAC-SHA256 token hashing, salt generation, -// and constant-time comparison to prevent timing attacks. +// Package security provides the session-hijack-prevention decorator for +// vMCP sessions. It binds a session to a stable identity tuple (iss, sub) +// extracted from the OIDC identity that created the session, and validates +// that every subsequent request comes from the same identity. +// +// Session bindings are stored as plaintext at rest in the session metadata +// (see pkg/vmcp/session/binding for the format and trust-boundary statement). +// The binding is NOT a credential; it identifies but does not authenticate. +// Callers must validate the request's token independently before passing the +// resulting *auth.Identity to validateCaller. package security import ( "context" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/hex" + "crypto/subtle" "fmt" "log/slog" "github.com/stacklok/toolhive/pkg/auth" - pkgsecurity "github.com/stacklok/toolhive/pkg/security" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -const ( - // SHA256HexLen is the length of a hex-encoded SHA256 hash (32 bytes = 64 hex characters) - SHA256HexLen = 64 - - // metadataKeyTokenHash is the session metadata key for the token hash. - // Imported from types package to ensure consistency across all packages. - metadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash - - // metadataKeyTokenSalt is the session metadata key for the token salt. - // Imported from types package to ensure consistency across all packages. - metadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt -) - -// generateSalt generates a cryptographically secure random salt for token hashing. -// Returns 16 bytes of random data from crypto/rand. +// sessionBindingDecorator wraps a session and adds identity-binding validation +// to prevent session hijacking attacks. It validates that all requests come from +// the same identity that created the session. // -// Each session should have a unique salt to provide additional entropy and prevent -// attacks that work across multiple sessions. -func generateSalt() ([]byte, error) { - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return nil, fmt.Errorf("failed to generate salt: %w", err) - } - return salt, nil +// The decorator is applied by BindSession to ALL sessions (both authenticated +// and anonymous). For authenticated sessions, it validates the caller's (iss, sub) +// identity binding matches the creator's binding. For anonymous sessions +// (allowAnonymous=true), it allows nil callers and prevents session upgrade +// attacks by rejecting any token presentation. +// +// The decorator embeds MultiSession and only overrides the methods that require +// validation (CallTool, ReadResource, GetPrompt). All other methods are +// automatically delegated to the embedded session. +type sessionBindingDecorator struct { + sessiontypes.MultiSession // embedded — automatic delegation for unwrapped methods + + // boundIdentity is the canonical identity binding written at session + // creation. Immutable after construction. + // + // For sessions allowed to be anonymous, boundIdentity is + // binding.UnauthenticatedSentinel. + // For sessions bound to an authenticated identity, boundIdentity is the + // output of binding.Format(iss, sub). + boundIdentity string + + // allowAnonymous tracks whether the session was created without a bound + // identity. Used to reject session-upgrade attacks (caller presents a + // token on an anonymous session). + allowAnonymous bool } -// hashToken returns the hex-encoded HMAC-SHA256 hash of a raw bearer token string. -// Uses HMAC with a server-managed secret and per-session salt to prevent offline -// attacks if session storage is compromised. -// -// For empty tokens (anonymous sessions) it returns the empty string, which is -// the sentinel value used to identify sessions created without credentials. -// The raw token is never stored — only the hash. +// extractBindingID derives the canonical identity-binding string from the +// given auth identity's OIDC claims. It reads "iss" and "sub" from +// identity.Claims (not identity.Subject) so that JWT-validation and +// introspection paths canonicalize against the same source. // -// Parameters: -// - token: The bearer token to hash -// - secret: Server-managed HMAC secret (should be 32+ bytes) -// - salt: Per-session random salt (typically 16 bytes) +// Returns ("", error) when: +// - identity is nil. +// - identity.Claims is missing "iss" or "sub". +// - either claim is present but not a string. +// - binding.Format rejects the (iss, sub) pair (empty halves or stray NULs). // -// Security: Uses HMAC-SHA256 instead of plain SHA256 to prevent rainbow table -// attacks and offline brute force if session state leaks from Redis/Valkey. -func hashToken(token string, secret, salt []byte) string { - if token == "" { - return "" +// Callers MUST treat a non-nil error as "no identifying claims available" +// and fail closed (do not silently fall through to anonymous). +func extractBindingID(identity *auth.Identity) (string, error) { + if identity == nil { + return "", fmt.Errorf("auth identity is nil") } - h := hmac.New(sha256.New, secret) - h.Write(salt) - h.Write([]byte(token)) - return hex.EncodeToString(h.Sum(nil)) -} -// hijackPreventionDecorator wraps a session and adds token binding validation -// to prevent session hijacking attacks. It validates that all requests come from -// the same identity that created the session. -// -// The decorator is applied by PreventSessionHijacking to ALL sessions (both authenticated -// and anonymous). For authenticated sessions, it validates the caller's token matches -// the creator's token. For anonymous sessions (allowAnonymous=true), it allows nil -// callers and prevents session upgrade attacks by rejecting any token presentation. -// -// The decorator embeds MultiSession and only overrides the methods that require -// validation (CallTool, ReadResource, GetPrompt). All other methods are automatically -// delegated to the embedded session. -type hijackPreventionDecorator struct { - sessiontypes.MultiSession // Embedded interface - provides automatic delegation for most methods + issRaw, issPresent := identity.Claims["iss"] + if !issPresent { + return "", fmt.Errorf("auth identity is missing iss claim") + } + iss, issIsString := issRaw.(string) + if !issIsString { + return "", fmt.Errorf("auth identity has non-string iss claim") + } - // Token binding fields: enforce that subsequent requests come from the same - // identity that created the session. - // These fields are immutable after decorator creation (no mutex needed). - boundTokenHash string // HMAC-SHA256 hash of creator's token (empty for anonymous) - tokenSalt []byte // Random salt used for HMAC (empty for anonymous) - hmacSecret []byte // Server-managed secret for HMAC-SHA256 - allowAnonymous bool // Whether to allow nil caller + subRaw, subPresent := identity.Claims["sub"] + if !subPresent { + return "", fmt.Errorf("auth identity is missing sub claim") + } + sub, subIsString := subRaw.(string) + if !subIsString { + return "", fmt.Errorf("auth identity has non-string sub claim") + } + + b, err := binding.Format(iss, sub) + if err != nil { + return "", fmt.Errorf("auth identity (iss, sub) pair is invalid: %w", err) + } + return b, nil } -// validateCaller checks if the provided caller identity matches the session owner. -// Returns nil if validation succeeds, or an error if: -// - The session requires a bound identity but caller is nil (ErrNilCaller) -// - The caller's token hash doesn't match the session owner (ErrUnauthorizedCaller) -// - An anonymous session receives a caller with a non-empty token (ErrUnauthorizedCaller) +// validateCaller checks the caller against the session's bound identity. // -// For anonymous sessions (allowAnonymous=true, boundTokenHash=""), validation succeeds -// only when the caller is nil or has an empty token (prevents session upgrade attacks). -func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { - // No lock needed - token binding fields are immutable after decorator creation - - // Anonymous sessions: reject callers that present tokens - if d.allowAnonymous && d.boundTokenHash == "" { - // Prevent session upgrade attack: anonymous sessions cannot accept tokens +// Returns: +// - ErrSessionOwnerUnknown when the decorator was constructed with neither +// an anonymous marker nor a real binding (programming error). +// - ErrNilCaller when a bound session receives nil. +// - ErrUnauthorizedCaller when: +// - an anonymous session receives a caller presenting a token (upgrade attack), +// - the caller's identity binding does not match the session's bound identity. +func (d sessionBindingDecorator) validateCaller(caller *auth.Identity) error { + // Anonymous path: sessions that were created without a bound identity. + if d.allowAnonymous && binding.IsUnauthenticated(d.boundIdentity) { + // Prevent session upgrade attack: anonymous sessions cannot accept tokens. if caller != nil && caller.Token != "" { - slog.Warn("token validation failed: session upgrade attack prevented", + slog.Warn("identity binding validation failed: session upgrade attack prevented", "reason", "token_presented_to_anonymous_session", ) return sessiontypes.ErrUnauthorizedCaller @@ -119,32 +121,43 @@ func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { return nil } - // Bound sessions require a caller + // Bound sessions require a non-nil caller. if caller == nil { - slog.Warn("token validation failed: nil caller for bound session", + slog.Warn("identity binding validation failed: nil caller for bound session", "reason", "nil_caller", ) return sessiontypes.ErrNilCaller } - // Defensive check: bound sessions must have a non-empty token hash. - // This prevents misconfigured sessions from accepting empty tokens. - // Scenario: if boundTokenHash="" and caller.Token="", both would hash to "", - // and ConstantTimeHashCompare would return true (both empty case). - if d.boundTokenHash == "" { - slog.Error("token validation failed: bound session has empty token hash", - "reason", "misconfigured_session", - ) - return sessiontypes.ErrSessionOwnerUnknown + // Defensive check: the stored binding must be parseable. An unparseable value + // means the session was misconfigured at construction time — fail closed + // rather than accepting or rejecting based on garbage state. + if !binding.IsUnauthenticated(d.boundIdentity) { + if _, _, ok := binding.Parse(d.boundIdentity); !ok { + slog.Error("identity binding validation failed: stored binding is not parseable", + "reason", "misconfigured_session", + ) + return sessiontypes.ErrSessionOwnerUnknown + } } - // Compute caller's token hash using the same HMAC secret and salt - callerHash := hashToken(caller.Token, d.hmacSecret, d.tokenSalt) + // Compute the caller's binding from their identity claims. + callerBinding, err := extractBindingID(caller) + if err != nil { + slog.Warn("identity binding validation failed: could not extract caller binding", + "reason", "caller_binding_extraction_failed", + "error", err, + ) + return sessiontypes.ErrUnauthorizedCaller + } - // Constant-time comparison to prevent timing attacks - if !pkgsecurity.ConstantTimeHashCompare(d.boundTokenHash, callerHash, SHA256HexLen) { - slog.Warn("token validation failed: token hash mismatch", - "reason", "token_hash_mismatch", + // ConstantTimeCompare is constant-time over content but short-circuits on + // length mismatch. Leaking binding length is acceptable: iss is the OIDC + // issuer (public, in the discovery document) and sub is an opaque + // identifier whose length is typically per-issuer. Neither is secret. + if subtle.ConstantTimeCompare([]byte(d.boundIdentity), []byte(callerBinding)) != 1 { + slog.Warn("identity binding validation failed: identity binding mismatch", + "reason", "identity_binding_mismatch", ) return sessiontypes.ErrUnauthorizedCaller } @@ -153,14 +166,13 @@ func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { } // CallTool validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) CallTool( +func (d sessionBindingDecorator) CallTool( ctx context.Context, caller *auth.Identity, toolName string, arguments map[string]any, meta map[string]any, ) (*vmcp.ToolCallResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -169,12 +181,11 @@ func (d hijackPreventionDecorator) CallTool( } // ReadResource validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) ReadResource( +func (d sessionBindingDecorator) ReadResource( ctx context.Context, caller *auth.Identity, uri string, ) (*vmcp.ResourceReadResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -183,13 +194,12 @@ func (d hijackPreventionDecorator) ReadResource( } // GetPrompt validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) GetPrompt( +func (d sessionBindingDecorator) GetPrompt( ctx context.Context, caller *auth.Identity, name string, arguments map[string]any, ) (*vmcp.PromptGetResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -197,148 +207,101 @@ func (d hijackPreventionDecorator) GetPrompt( return d.MultiSession.GetPrompt(ctx, caller, name, arguments) } -// RestoreHijackPrevention recreates the hijack-prevention decorator from persisted -// metadata, rather than recomputing token binding from an identity. Use this when -// reconstructing a MultiSession after a pod restart or cross-pod failover where the -// original bearer token is no longer available but the stored hash and salt are. +// BindSession wraps a session with identity-binding validation. It writes +// the canonical binding into session metadata under MetadataKeyIdentityBinding +// and returns a decorator that validates the caller on every operation. +// +// Whether the session is anonymous is derived from the identity via +// types.ShouldAllowAnonymous. // -// If tokenHash is empty the session is treated as anonymous (allowAnonymous=true). -// The hmacSecret must be the same server-managed secret used at creation time. -func RestoreHijackPrevention( +// For bound sessions, the (iss, sub) tuple is extracted from identity.Claims. +// If the tuple cannot be extracted (missing claims, non-string claims, or +// invalid format), BindSession returns an error BEFORE writing anything to +// the session metadata — preserving the invariant that a session always has +// a valid MetadataKeyIdentityBinding value once written. +// +// Returns an error if session is nil, or if a bound identity is required +// but no valid (iss, sub) binding can be produced. +func BindSession( session sessiontypes.MultiSession, - tokenHash string, - tokenSaltHex string, - hmacSecret []byte, + identity *auth.Identity, ) (sessiontypes.MultiSession, error) { if session == nil { return nil, fmt.Errorf("session must not be nil") } - // Both fields must be either both present or both absent. Any other - // combination indicates corrupted or incomplete metadata and must be - // rejected to fail closed: - // - hash present, salt absent: HMAC comparison will always fail, - // producing a silently broken (always-rejecting) decorator. - // - hash absent, salt present: session would be treated as anonymous, - // silently downgrading a bound session and bypassing token validation. - if tokenHash != "" && tokenSaltHex == "" { - return nil, fmt.Errorf("RestoreHijackPrevention: stored token hash is present but salt is missing " + - "(incomplete session metadata)") - } - if tokenHash == "" && tokenSaltHex != "" { - return nil, fmt.Errorf("RestoreHijackPrevention: stored token salt is present but hash is missing " + - "(incomplete session metadata)") - } - - allowAnonymous := tokenHash == "" + allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - var tokenSalt []byte - if tokenSaltHex != "" { - var decErr error - tokenSalt, decErr = hex.DecodeString(tokenSaltHex) - if decErr != nil { - return nil, fmt.Errorf("failed to decode stored token salt: %w", decErr) + // Determine the binding value to write. For bound sessions, extract the + // (iss, sub) binding before touching session metadata — if extraction fails, + // we must not leave a partial or sentinel value in the session. + bid, err := func() (string, error) { + if allowAnonymous { + return binding.UnauthenticatedSentinel, nil } + b, err := extractBindingID(identity) + if err != nil { + return "", fmt.Errorf("BindSession: cannot derive identity binding: %w", err) + } + return b, nil + }() + if err != nil { + return nil, err } - // Make defensive copies to prevent external mutation after construction. - var hmacSecretCopy, tokenSaltCopy []byte - if len(hmacSecret) > 0 { - hmacSecretCopy = append([]byte(nil), hmacSecret...) - } - if len(tokenSalt) > 0 { - tokenSaltCopy = append([]byte(nil), tokenSalt...) - } + // Write the resolved binding. This is the only metadata key written here; + // backend IDs and per-backend session keys are written by makeBaseSession. + session.SetMetadata(sessiontypes.MetadataKeyIdentityBinding, bid) - return &hijackPreventionDecorator{ + return &sessionBindingDecorator{ MultiSession: session, + boundIdentity: bid, allowAnonymous: allowAnonymous, - hmacSecret: hmacSecretCopy, - boundTokenHash: tokenHash, - tokenSalt: tokenSaltCopy, }, nil } -// PreventSessionHijacking wraps a session with hijack prevention security measures. -// It computes token binding hashes, stores them in session metadata, and returns -// a decorated session that validates caller identity on every operation. -// -// Whether the session is anonymous is derived from the identity: nil identity or -// empty token means anonymous, a non-empty token means bound/authenticated. +// RestoreSessionBinding recreates the session-binding decorator from a +// persisted binding string read out of session metadata. Use this when +// reconstructing a MultiSession after a pod restart or cross-pod failover. // -// For authenticated sessions (identity.Token != ""): -// - Generates a unique random salt -// - Computes HMAC-SHA256 hash of the bearer token -// - Stores hash and salt in session metadata -// - Returns decorator that validates every request against the creator's token +// This function is the symmetric counterpart of BindSession for the restore +// path. It is invoked by the session factory after RestoreSession deserializes +// the binding from session metadata. Unlike BindSession, it does NOT write +// metadata — the factory has already restored the metadata layer separately. // -// For anonymous sessions (identity == nil or identity.Token == ""): -// - Stores an empty string sentinel for the token hash metadata key -// - Omits the salt metadata key entirely (no salt is generated for anonymous sessions) -// - Returns decorator that allows nil callers and rejects token presentation +// storedBinding must be either: +// - binding.UnauthenticatedSentinel (the session was anonymous), or +// - a valid bound binding (binding.Parse returns ok). // -// Security: -// - Makes defensive copies of secret and salt to prevent external mutation -// - Uses constant-time comparison to prevent timing attacks -// - Prevents session upgrade attacks (anonymous → authenticated) -// - Raw tokens are never stored, only HMAC-SHA256 hashes -// -// Returns an error if: -// - session is nil -// - salt generation fails -func PreventSessionHijacking( +// Anything else (empty string, malformed value) is rejected as corrupted +// metadata. +func RestoreSessionBinding( session sessiontypes.MultiSession, - hmacSecret []byte, - identity *auth.Identity, + storedBinding string, ) (sessiontypes.MultiSession, error) { if session == nil { return nil, fmt.Errorf("session must not be nil") } - allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - - // Note: Pass-through methods (ID, Type, CreatedAt, etc.) are validated by the - // type system when the decorator is used. We don't validate them here to keep - // the constructor simple and allow minimal mocks for testing. - - var boundTokenHash string - var tokenSalt []byte - var err error - // Compute token binding for authenticated sessions - if !allowAnonymous && identity != nil && identity.Token != "" { - // Generate unique salt for this session - tokenSalt, err = generateSalt() - if err != nil { - return nil, fmt.Errorf("failed to generate token salt: %w", err) - } - // Compute HMAC-SHA256 hash with server secret and per-session salt - boundTokenHash = hashToken(identity.Token, hmacSecret, tokenSalt) - } - - // Store hash and salt in session metadata for persistence, auditing, - // and backward compatibility - session.SetMetadata(metadataKeyTokenHash, boundTokenHash) - if len(tokenSalt) > 0 { - session.SetMetadata(metadataKeyTokenSalt, hex.EncodeToString(tokenSalt)) + if binding.IsUnauthenticated(storedBinding) { + return &sessionBindingDecorator{ + MultiSession: session, + boundIdentity: storedBinding, + allowAnonymous: true, + }, nil } - // Make defensive copies of slices to prevent external mutation - var hmacSecretCopy, tokenSaltCopy []byte - if len(hmacSecret) > 0 { - hmacSecretCopy = append([]byte(nil), hmacSecret...) - } - if len(tokenSalt) > 0 { - tokenSaltCopy = append([]byte(nil), tokenSalt...) + // Validate the stored binding is parseable. We do not use iss/sub here — + // the factory calls binding.Parse separately when it needs them to + // reconstruct identity. This call is purely a validation gate. + if _, _, ok := binding.Parse(storedBinding); !ok { + return nil, fmt.Errorf("RestoreSessionBinding: stored binding is neither the unauthenticated sentinel " + + "nor a valid bound binding (corrupted metadata)") } - // Wrap with hijackPreventionDecorator for runtime validation. - // The decorator embeds the MultiSession interface, so all methods are automatically - // delegated except for the three we override (CallTool, ReadResource, GetPrompt). - return &hijackPreventionDecorator{ + return &sessionBindingDecorator{ MultiSession: session, - allowAnonymous: allowAnonymous, - hmacSecret: hmacSecretCopy, - boundTokenHash: boundTokenHash, - tokenSalt: tokenSaltCopy, + boundIdentity: storedBinding, + allowAnonymous: false, }, nil } diff --git a/pkg/vmcp/session/mocks/mock_factory.go b/pkg/vmcp/session/mocks/mock_factory.go index e548519ea0..3d850a1eb4 100644 --- a/pkg/vmcp/session/mocks/mock_factory.go +++ b/pkg/vmcp/session/mocks/mock_factory.go @@ -44,18 +44,18 @@ func (m *MockMultiSessionFactory) EXPECT() *MockMultiSessionFactoryMockRecorder } // MakeSessionWithID mocks base method. -func (m *MockMultiSessionFactory) MakeSessionWithID(ctx context.Context, id string, identity *auth.Identity, allowAnonymous bool, backends []*vmcp.Backend) (session.MultiSession, error) { +func (m *MockMultiSessionFactory) MakeSessionWithID(ctx context.Context, id string, identity *auth.Identity, backends []*vmcp.Backend) (session.MultiSession, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeSessionWithID", ctx, id, identity, allowAnonymous, backends) + ret := m.ctrl.Call(m, "MakeSessionWithID", ctx, id, identity, backends) ret0, _ := ret[0].(session.MultiSession) ret1, _ := ret[1].(error) return ret0, ret1 } // MakeSessionWithID indicates an expected call of MakeSessionWithID. -func (mr *MockMultiSessionFactoryMockRecorder) MakeSessionWithID(ctx, id, identity, allowAnonymous, backends any) *gomock.Call { +func (mr *MockMultiSessionFactoryMockRecorder) MakeSessionWithID(ctx, id, identity, backends any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeSessionWithID", reflect.TypeOf((*MockMultiSessionFactory)(nil).MakeSessionWithID), ctx, id, identity, allowAnonymous, backends) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeSessionWithID", reflect.TypeOf((*MockMultiSessionFactory)(nil).MakeSessionWithID), ctx, id, identity, backends) } // RestoreSession mocks base method. diff --git a/pkg/vmcp/session/token_binding_test.go b/pkg/vmcp/session/token_binding_test.go deleted file mode 100644 index db8fdf412e..0000000000 --- a/pkg/vmcp/session/token_binding_test.go +++ /dev/null @@ -1,343 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package session - -import ( - "context" - "errors" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/vmcp" - internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" - sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" -) - -// --------------------------------------------------------------------------- -// makeSession stores token hash in metadata -// --------------------------------------------------------------------------- - -// nilBackendConnector is a connector that returns (nil, nil, nil), causing the -// backend to be skipped during init. This lets us exercise session-metadata -// logic without real backend connections. -func nilBackendConnector() backendConnector { - return func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { - return nil, nil, nil - } -} - -func TestMakeSession_StoresTokenHash(t *testing.T) { - t.Parallel() - - t.Run("authenticated session stores HMAC-SHA256 hash", func(t *testing.T) { - t.Parallel() - - const rawToken = "test-bearer-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - // Verify token hash is stored - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set") - assert.NotEmpty(t, storedHash, "Token hash must be non-empty for authenticated session") - assert.Len(t, storedHash, 64, "HMAC-SHA256 hex-encoded hash should be 64 characters") - // Raw token must never appear in metadata. - assert.NotEqual(t, rawToken, storedHash) - - // Verify salt is stored for authenticated sessions - storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") - assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") - }) - - t.Run("anonymous session stores empty sentinel", func(t *testing.T) { - t.Parallel() - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set even for anonymous sessions") - assert.Empty(t, storedHash, "anonymous session must store empty sentinel") - - // Salt must not be present for anonymous sessions - storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "anonymous session must not store a salt") - }) - - t.Run("identity with empty token stores empty sentinel", func(t *testing.T) { - t.Parallel() - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, true, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - storedHash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.Empty(t, storedHash, "empty-token identity must store empty sentinel") - - // Salt must not be present for empty-token (anonymous) sessions - storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "empty-token identity must not store a salt") - }) - - t.Run("MakeSessionWithID also stores token hash", func(t *testing.T) { - t.Parallel() - - const rawToken = "id-specific-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "bob"}, Token: rawToken} - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), "explicit-session-id", identity, false, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - // Verify token hash - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set") - assert.NotEmpty(t, storedHash, "Token hash must be non-empty") - assert.Len(t, storedHash, 64, "HMAC-SHA256 hex-encoded hash should be 64 characters") - - // Verify salt is stored for authenticated sessions - storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") - assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") - }) -} - -// --------------------------------------------------------------------------- -// MakeSessionWithID validation -// --------------------------------------------------------------------------- - -// TestMakeSessionWithID_ValidationOfAllowAnonymous tests that MakeSessionWithID -// validates consistency between identity and allowAnonymous parameters. -func TestMakeSessionWithID_ValidationOfAllowAnonymous(t *testing.T) { - t.Parallel() - - factory := NewSessionFactory(nil) - - t.Run("rejects anonymous session with bearer token", func(t *testing.T) { - t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "bearer-token"} - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - true, // allowAnonymous=true but identity has token - nil, // no backends needed for validation test - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create anonymous session") - assert.Contains(t, err.Error(), "with bearer token") - }) - - t.Run("rejects bound session without bearer token (nil identity)", func(t *testing.T) { - t.Parallel() - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - nil, // no identity - false, // allowAnonymous=false but no identity - nil, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create bound session") - assert.Contains(t, err.Error(), "without bearer token") - }) - - t.Run("rejects bound session without bearer token (empty token)", func(t *testing.T) { - t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} // empty token - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - false, // allowAnonymous=false but token is empty - nil, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create bound session") - assert.Contains(t, err.Error(), "without bearer token") - }) - - t.Run("allows anonymous session with nil identity", func(t *testing.T) { - t.Parallel() - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - nil, // no identity - true, // allowAnonymous=true - consistent - nil, - ) - require.NoError(t, err) - }) - - t.Run("allows anonymous session with empty token", func(t *testing.T) { - t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - true, // allowAnonymous=true and token is empty - consistent - nil, - ) - require.NoError(t, err) - }) -} - -// --------------------------------------------------------------------------- -// WithHMACSecret defensive copy -// --------------------------------------------------------------------------- - -// TestWithHMACSecret_DefensiveCopy verifies that WithHMACSecret makes a defensive -// copy of the secret to prevent external modification after assignment. -func TestWithHMACSecret_DefensiveCopy(t *testing.T) { - t.Parallel() - - // Create a mutable secret - secretSlice := []byte("original-secret-value") - - // Create factory with the secret - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretSlice)) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - - // Create first session before modification - sess1, err := factory.MakeSessionWithID(context.Background(), "session-1", identity, false, nil) - require.NoError(t, err) - - // Verify first session was created successfully - hash1 := sess1.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, hash1, "first session should have token hash") - - // Maliciously modify the secret slice after passing it to WithHMACSecret - for i := range secretSlice { - secretSlice[i] = 0xFF - } - - // Create second session after modification - should still work correctly - // because WithHMACSecret made a defensive copy - sess2, err := factory.MakeSessionWithID(context.Background(), "session-2", identity, false, nil) - require.NoError(t, err) - - // Verify second session was created successfully - hash2 := sess2.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, hash2, "second session should have token hash") - - // Both sessions should still be able to validate the original token - // (proving the factory used the original secret, not the modified one). - // We verify this by calling a session method that requires authentication. - ctx := context.Background() - - // First session should accept the original token and fail with ErrToolNotFound, - // not an auth error (which would indicate the secret was corrupted) - _, err = sess1.CallTool(ctx, identity, "nonexistent-tool", nil, nil) - assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") - assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), - "should not be an auth error (would indicate corrupted secret)") - - // Second session should also accept the original token and fail with ErrToolNotFound - _, err = sess2.CallTool(ctx, identity, "nonexistent-tool", nil, nil) - assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") - assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), - "should not be an auth error (would indicate corrupted secret)") -} - -// --------------------------------------------------------------------------- -// RestoreSession fail-closed behaviour for absent token-hash key -// --------------------------------------------------------------------------- - -// TestRestoreSession_AbsentTokenHashKey verifies that RestoreSession fails closed -// when the stored metadata is missing MetadataKeyTokenHash entirely. -// -// Background: storedMetadata[key] returns "" for both an absent key and a -// legitimately anonymous session (which stores "" as a sentinel). The factory -// uses the two-value map lookup form to distinguish between the two cases and -// rejects absent keys rather than silently downgrading to anonymous. -func TestRestoreSession_AbsentTokenHashKey(t *testing.T) { - t.Parallel() - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - - t.Run("absent token-hash key is rejected (fail closed)", func(t *testing.T) { - t.Parallel() - - // Metadata that deliberately omits MetadataKeyTokenHash (simulates - // corrupted or truncated session metadata). MetadataKeyBackendIDs is - // present (empty = zero backends) so the earlier backend-IDs guard - // passes and we reach the token-hash guard. - storedMetadata := map[string]string{ - MetadataKeyIdentitySubject: "alice", - MetadataKeyBackendIDs: "", // present, empty = zero backends - // MetadataKeyTokenHash intentionally absent - } - - _, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "token hash metadata key absent") - }) - - t.Run("empty token-hash key (anonymous sentinel) is accepted", func(t *testing.T) { - t.Parallel() - - // Metadata with MetadataKeyTokenHash present but empty — this is what - // PreventSessionHijacking writes for anonymous sessions. - storedMetadata := map[string]string{ - MetadataKeyBackendIDs: "", // present, empty = zero backends - sessiontypes.MetadataKeyTokenHash: "", // present, empty = anonymous - } - - sess, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) - require.NoError(t, err) - require.NotNil(t, sess) - }) -} - -// TestWithHMACSecret_RejectsEmptySecret verifies that WithHMACSecret rejects -// nil or empty secrets to prevent silent security downgrades. -func TestWithHMACSecret_RejectsEmptySecret(t *testing.T) { - t.Parallel() - - t.Run("nil secret is rejected", func(t *testing.T) { - t.Parallel() - - // Create factory with nil secret (should fall back to default) - factory := NewSessionFactory(nil, WithHMACSecret(nil)) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - sess, err := factory.MakeSessionWithID(context.Background(), "test-session", identity, false, nil) - require.NoError(t, err) - - // Should still create a valid session with default secret - hash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.NotEmpty(t, hash, "should use default secret, not nil") - }) - - t.Run("empty secret is rejected", func(t *testing.T) { - t.Parallel() - - // Create factory with empty secret (should fall back to default) - factory := NewSessionFactory(nil, WithHMACSecret([]byte{})) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - sess, err := factory.MakeSessionWithID(context.Background(), "test-session", identity, false, nil) - require.NoError(t, err) - - // Should still create a valid session with default secret - hash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.NotEmpty(t, hash, "should use default secret, not empty slice") - }) -} diff --git a/pkg/vmcp/session/types/session.go b/pkg/vmcp/session/types/session.go index ae8dc7c867..bafb74eadc 100644 --- a/pkg/vmcp/session/types/session.go +++ b/pkg/vmcp/session/types/session.go @@ -169,6 +169,16 @@ const ( // // Fail-closed: a Claims["iss"] or Claims["sub"] that is present but not a // string (a misbehaving validator) is treated as bound, with a WARN logged. +// +// Contract note: this function answers "should this session be CREATED as +// anonymous?" using Token presence as a fast path. It does NOT guarantee that +// a non-anonymous identity will produce a successful binding: an identity with +// Token != "" but missing iss/sub claims passes this check as bound, then +// fails in BindSession with an extraction error. In practice this is +// pathological — all shipping middlewares (JWT validator, LocalUserMiddleware, +// AnonymousMiddleware) populate Claims. Callers who need "will this identity +// actually produce a binding?" must use extractBindingID directly, which is +// deliberately kept internal to the security decorator. func ShouldAllowAnonymous(identity *auth.Identity) bool { if identity == nil { return true From 3fdfd9bb8a77ea000e88889ef504a6ebc8c7309d Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 18 May 2026 11:40:31 +0100 Subject: [PATCH 4/4] Drop HMAC plumbing from serve.go; document Redis trust boundary createSessionFactory no longer takes an HMAC secret or a Kubernetes detection flag. VMCP_SESSION_HMAC_SECRET stays readable from the environment for one deploy cycle (DEBUG-logged and ignored) so the operator-side env-var injection can be removed in a follow-up PR without forcing a coordinated cut-over. Add a startup WARN when incoming auth is configured as "anonymous": AnonymousMiddleware populates the same (iss, sub) for every request, so all callers collide on one identity binding and per-identity hijack prevention degrades to dev-only behaviour. Surface that to operators rather than letting them assume the binding still scopes per user. Document the trust boundary in docs/arch/13-vmcp-scalability.md: the new scheme stores plaintext (iss, sub) at rest in Redis/Valkey, which trades the HMAC's at-rest opacity for refresh correctness. Operators must layer Redis ACLs or NetworkPolicies if a Redis dump revealing identity is unacceptable. Add a TODO on extractBindingID for the forward-looking RFC 7662 case (probe IdP introspection responses for iss + sub) so it surfaces if that becomes a top-level incoming-auth type. The five HMAC-specific serve_test.go cases collapse to one table-driven test of the new signature; the removed cases targeted HMAC validation that no longer exists. Co-Authored-By: Claude Opus 4.7 --- docs/arch/13-vmcp-scalability.md | 24 +++++++ pkg/vmcp/cli/serve.go | 68 ++++++------------- pkg/vmcp/cli/serve_test.go | 59 ++++++---------- .../session/internal/security/security.go | 4 ++ 4 files changed, 68 insertions(+), 87 deletions(-) diff --git a/docs/arch/13-vmcp-scalability.md b/docs/arch/13-vmcp-scalability.md index 3dfb33ecf5..a1a8bf64fa 100644 --- a/docs/arch/13-vmcp-scalability.md +++ b/docs/arch/13-vmcp-scalability.md @@ -84,6 +84,30 @@ line 177). This means: | Inactivity beyond TTL | Redis TTL expiry (automatic, no application-side action needed) | | Pod-local cache eviction (LRU) | `onEvict` callback closes backend connections only; the Redis metadata key is **not** deleted and expires via TTL | +### Identity-binding storage and Redis access control + +Each vMCP session carries an identity binding stored in session metadata under the +key `vmcp.identity.binding`. The canonical format is defined in +`pkg/vmcp/session/binding/binding.go`: a NUL-separated `iss + "\x00" + sub` for +authenticated sessions, and the literal string `"unauthenticated"` for sessions +without an auth identity. + +The binding is stored as **plaintext** in the session store (Redis/Valkey). It is +not a credential — it identifies but does not authenticate a principal — but it is +personally-identifying information (a combination of issuer URL and user subject). + +Operators are responsible for access-controlling the Redis/Valkey instance +equivalently to any other identity store. Concretely: enable Redis ACLs (Redis 6+) +or `requirepass`, restrict network reach with a Kubernetes `NetworkPolicy`, and +avoid sharing the cluster with untrusted workloads. + +The session store prior to issue #5306 held an HMAC of the bearer token rather than +the raw `(iss, sub)` pair. That scheme reduced the value of a Redis dump at the cost +of breaking on every legitimate OAuth token refresh. The current scheme accepts +plaintext PII at rest as the price of correctness; operators who require additional +protection against a Redis compromise must layer Redis-side access controls as +described above. + ## File descriptor limits Each open backend connection consumes one file descriptor on the vMCP pod. A diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index a962f52f27..600fe74375 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -252,6 +252,15 @@ func Serve(ctx context.Context, cfg ServeConfig) error { slog.Info(fmt.Sprintf("Setting up incoming authentication (type: %s)", vmcpCfg.IncomingAuth.Type)) + if vmcpCfg.IncomingAuth.Type == config.IncomingAuthTypeAnonymous { + slog.Warn( + "vMCP is configured with anonymous incoming auth; all anonymous sessions share a single sentinel binding, "+ + "so possession of a session ID is sufficient to act as that session from any source. "+ + "Anonymous mode is intended for development only.", + "incoming_auth_type", config.IncomingAuthTypeAnonymous, + ) + } + // Configure health monitoring if enabled. var healthMonitorConfig *health.MonitorConfig if vmcpCfg.Operational != nil && @@ -336,15 +345,11 @@ func Serve(ctx context.Context, cfg ServeConfig) error { } envReader := &env.OSReader{} - sessionFactory, err := createSessionFactory( - envReader.Getenv("VMCP_SESSION_HMAC_SECRET"), - runtime.IsKubernetesRuntimeWithEnv(envReader), - outgoingRegistry, - agg, - ) - if err != nil { - return err + if hmacSecret := envReader.Getenv("VMCP_SESSION_HMAC_SECRET"); hmacSecret != "" { + slog.Debug("VMCP_SESSION_HMAC_SECRET is set but no longer used after #5306; ignoring", + "env_var", "VMCP_SESSION_HMAC_SECRET") } + sessionFactory := createSessionFactory(outgoingRegistry, agg) // When the optimizer is enabled, its meta-tools must pass through the authz // response filter so they appear in tools/list. @@ -645,51 +650,16 @@ func runDiscovery( return backends, backendClient, outgoingRegistry, nil } -// createSessionFactory creates a MultiSessionFactory with HMAC-SHA256 token binding. -// The HMAC secret and Kubernetes detection are passed in as parameters (typically sourced -// from the VMCP_SESSION_HMAC_SECRET environment variable and runtime environment detection -// by the caller). -// -// Behavior: -// - If hmacSecret is non-empty: validates length and creates factory with the secret. -// - If running in Kubernetes without secret: returns error (production safety requirement). -// - Otherwise: logs warning and creates factory with default insecure secret. +// createSessionFactory creates a MultiSessionFactory backed by the provided outgoing +// auth registry and optional aggregator. When agg is non-nil, sessions gain access +// to aggregated backend metadata; pass nil for single-backend deployments. func createSessionFactory( - hmacSecret string, - isKubernetes bool, outgoingRegistry vmcpauth.OutgoingAuthRegistry, agg aggregator.Aggregator, -) (vmcpsession.MultiSessionFactory, error) { - const minRecommendedSecretLen = 32 - - opts := []vmcpsession.MultiSessionFactoryOption{} +) vmcpsession.MultiSessionFactory { + var opts []vmcpsession.MultiSessionFactoryOption if agg != nil { opts = append(opts, vmcpsession.WithAggregator(agg)) } - - if hmacSecret != "" { - if secretLen := len(hmacSecret); secretLen < minRecommendedSecretLen { - // G706: Safe - only logging integer length, not the secret itself. - slog.Warn( //nolint:gosec - "HMAC secret is shorter than recommended length - consider using a longer secret", - "actual_length", secretLen, - "recommended_length", minRecommendedSecretLen, - ) - } - slog.Info("using provided HMAC secret for session token binding") - opts = append(opts, vmcpsession.WithHMACSecret([]byte(hmacSecret))) - return vmcpsession.NewSessionFactory(outgoingRegistry, opts...), nil - } - - // No secret provided — fail fast in Kubernetes (production environment). - if isKubernetes { - return nil, fmt.Errorf( - "an HMAC secret is required when running in Kubernetes (set VMCP_SESSION_HMAC_SECRET). " + - "Generate a secure secret with: openssl rand -base64 32", - ) - } - - // Development mode: use default insecure secret with warning. - slog.Warn("no HMAC secret provided - using default insecure secret (NOT recommended for production)") - return vmcpsession.NewSessionFactory(outgoingRegistry, opts...), nil + return vmcpsession.NewSessionFactory(outgoingRegistry, opts...) } diff --git a/pkg/vmcp/cli/serve_test.go b/pkg/vmcp/cli/serve_test.go index 667b285779..c486e22a5c 100644 --- a/pkg/vmcp/cli/serve_test.go +++ b/pkg/vmcp/cli/serve_test.go @@ -17,6 +17,7 @@ import ( authserverconfig "github.com/stacklok/toolhive/pkg/authserver" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" aggregatormocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" clientmocks "github.com/stacklok/toolhive/pkg/vmcp/client/mocks" "github.com/stacklok/toolhive/pkg/vmcp/config" @@ -165,45 +166,27 @@ func newSessionFactoryMocks(t *testing.T) (*clientmocks.MockOutgoingAuthRegistry return clientmocks.NewMockOutgoingAuthRegistry(ctrl), aggregatormocks.NewMockAggregator(ctrl) } -func TestCreateSessionFactory_WithHMACSecret(t *testing.T) { +func TestCreateSessionFactory(t *testing.T) { t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("a-sufficiently-long-hmac-secret-value-32b", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_HMACSecretExactly32Bytes(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("12345678901234567890123456789012", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_ShortHMACSecret(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("short", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_NoSecretNonKubernetes(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_NoSecretKubernetes(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("", true, registry, agg) - require.Error(t, err) - require.ErrorContains(t, err, "an HMAC secret is required when running in Kubernetes") - require.Nil(t, factory) + tests := []struct { + name string + useAgg bool + }{ + {name: "with aggregator", useAgg: true}, + {name: "without aggregator", useAgg: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + registry, agg := newSessionFactoryMocks(t) + var aggArg aggregator.Aggregator + if tc.useAgg { + aggArg = agg + } + factory := createSessionFactory(registry, aggArg) + require.NotNil(t, factory) + }) + } } // TestRunDiscovery_KubernetesGroupNotFound exercises the Kubernetes-specific branch diff --git a/pkg/vmcp/session/internal/security/security.go b/pkg/vmcp/session/internal/security/security.go index 881e172a0d..ff17881488 100644 --- a/pkg/vmcp/session/internal/security/security.go +++ b/pkg/vmcp/session/internal/security/security.go @@ -69,6 +69,10 @@ type sessionBindingDecorator struct { // // Callers MUST treat a non-nil error as "no identifying claims available" // and fail closed (do not silently fall through to anonymous). +// +// TODO(#5306-followup): if/when RFC 7662 introspection becomes a top-level +// incoming-auth type, add a startup probe that verifies the IdP emits iss +// and sub in introspection responses (extractBindingID requires both). func extractBindingID(identity *auth.Identity) (string, error) { if identity == nil { return "", fmt.Errorf("auth identity is nil")