diff --git a/docs/arch/13-vmcp-scalability.md b/docs/arch/13-vmcp-scalability.md index 3dfb33ecf5..a1a8bf64fa 100644 --- a/docs/arch/13-vmcp-scalability.md +++ b/docs/arch/13-vmcp-scalability.md @@ -84,6 +84,30 @@ line 177). This means: | Inactivity beyond TTL | Redis TTL expiry (automatic, no application-side action needed) | | Pod-local cache eviction (LRU) | `onEvict` callback closes backend connections only; the Redis metadata key is **not** deleted and expires via TTL | +### Identity-binding storage and Redis access control + +Each vMCP session carries an identity binding stored in session metadata under the +key `vmcp.identity.binding`. The canonical format is defined in +`pkg/vmcp/session/binding/binding.go`: a NUL-separated `iss + "\x00" + sub` for +authenticated sessions, and the literal string `"unauthenticated"` for sessions +without an auth identity. + +The binding is stored as **plaintext** in the session store (Redis/Valkey). It is +not a credential — it identifies but does not authenticate a principal — but it is +personally identifiable information (PII): the combination of issuer URL and user subject. + +Operators are responsible for protecting the Redis/Valkey instance with access +controls equivalent to those of any other identity store. Concretely: enable Redis ACLs (Redis 6+) +or `requirepass`, restrict network reach with a Kubernetes `NetworkPolicy`, and +avoid sharing the cluster with untrusted workloads. + +The session store prior to issue #5306 held an HMAC of the bearer token rather than +the raw `(iss, sub)` pair. That scheme reduced the value of a Redis dump at the cost +of breaking on every legitimate OAuth token refresh. The current scheme accepts +plaintext PII at rest as the price of correctness; operators who require additional +protection against a Redis compromise must layer Redis-side access controls as +described above. + ## File descriptor limits Each open backend connection consumes one file descriptor on the vMCP pod. 
A diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index a962f52f27..600fe74375 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -252,6 +252,15 @@ func Serve(ctx context.Context, cfg ServeConfig) error { slog.Info(fmt.Sprintf("Setting up incoming authentication (type: %s)", vmcpCfg.IncomingAuth.Type)) + if vmcpCfg.IncomingAuth.Type == config.IncomingAuthTypeAnonymous { + slog.Warn( + "vMCP is configured with anonymous incoming auth; all anonymous sessions share a single sentinel binding, "+ + "so possession of a session ID is sufficient to act as that session from any source. "+ + "Anonymous mode is intended for development only.", + "incoming_auth_type", config.IncomingAuthTypeAnonymous, + ) + } + // Configure health monitoring if enabled. var healthMonitorConfig *health.MonitorConfig if vmcpCfg.Operational != nil && @@ -336,15 +345,11 @@ func Serve(ctx context.Context, cfg ServeConfig) error { } envReader := &env.OSReader{} - sessionFactory, err := createSessionFactory( - envReader.Getenv("VMCP_SESSION_HMAC_SECRET"), - runtime.IsKubernetesRuntimeWithEnv(envReader), - outgoingRegistry, - agg, - ) - if err != nil { - return err + if hmacSecret := envReader.Getenv("VMCP_SESSION_HMAC_SECRET"); hmacSecret != "" { + slog.Debug("VMCP_SESSION_HMAC_SECRET is set but no longer used after #5306; ignoring", + "env_var", "VMCP_SESSION_HMAC_SECRET") } + sessionFactory := createSessionFactory(outgoingRegistry, agg) // When the optimizer is enabled, its meta-tools must pass through the authz // response filter so they appear in tools/list. @@ -645,51 +650,16 @@ func runDiscovery( return backends, backendClient, outgoingRegistry, nil } -// createSessionFactory creates a MultiSessionFactory with HMAC-SHA256 token binding. -// The HMAC secret and Kubernetes detection are passed in as parameters (typically sourced -// from the VMCP_SESSION_HMAC_SECRET environment variable and runtime environment detection -// by the caller). 
-// -// Behavior: -// - If hmacSecret is non-empty: validates length and creates factory with the secret. -// - If running in Kubernetes without secret: returns error (production safety requirement). -// - Otherwise: logs warning and creates factory with default insecure secret. +// createSessionFactory creates a MultiSessionFactory backed by the provided outgoing +// auth registry and optional aggregator. When agg is non-nil, sessions gain access +// to aggregated backend metadata; pass nil for single-backend deployments. func createSessionFactory( - hmacSecret string, - isKubernetes bool, outgoingRegistry vmcpauth.OutgoingAuthRegistry, agg aggregator.Aggregator, -) (vmcpsession.MultiSessionFactory, error) { - const minRecommendedSecretLen = 32 - - opts := []vmcpsession.MultiSessionFactoryOption{} +) vmcpsession.MultiSessionFactory { + var opts []vmcpsession.MultiSessionFactoryOption if agg != nil { opts = append(opts, vmcpsession.WithAggregator(agg)) } - - if hmacSecret != "" { - if secretLen := len(hmacSecret); secretLen < minRecommendedSecretLen { - // G706: Safe - only logging integer length, not the secret itself. - slog.Warn( //nolint:gosec - "HMAC secret is shorter than recommended length - consider using a longer secret", - "actual_length", secretLen, - "recommended_length", minRecommendedSecretLen, - ) - } - slog.Info("using provided HMAC secret for session token binding") - opts = append(opts, vmcpsession.WithHMACSecret([]byte(hmacSecret))) - return vmcpsession.NewSessionFactory(outgoingRegistry, opts...), nil - } - - // No secret provided — fail fast in Kubernetes (production environment). - if isKubernetes { - return nil, fmt.Errorf( - "an HMAC secret is required when running in Kubernetes (set VMCP_SESSION_HMAC_SECRET). " + - "Generate a secure secret with: openssl rand -base64 32", - ) - } - - // Development mode: use default insecure secret with warning. 
- slog.Warn("no HMAC secret provided - using default insecure secret (NOT recommended for production)") - return vmcpsession.NewSessionFactory(outgoingRegistry, opts...), nil + return vmcpsession.NewSessionFactory(outgoingRegistry, opts...) } diff --git a/pkg/vmcp/cli/serve_test.go b/pkg/vmcp/cli/serve_test.go index 667b285779..c486e22a5c 100644 --- a/pkg/vmcp/cli/serve_test.go +++ b/pkg/vmcp/cli/serve_test.go @@ -17,6 +17,7 @@ import ( authserverconfig "github.com/stacklok/toolhive/pkg/authserver" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" aggregatormocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" clientmocks "github.com/stacklok/toolhive/pkg/vmcp/client/mocks" "github.com/stacklok/toolhive/pkg/vmcp/config" @@ -165,45 +166,27 @@ func newSessionFactoryMocks(t *testing.T) (*clientmocks.MockOutgoingAuthRegistry return clientmocks.NewMockOutgoingAuthRegistry(ctrl), aggregatormocks.NewMockAggregator(ctrl) } -func TestCreateSessionFactory_WithHMACSecret(t *testing.T) { +func TestCreateSessionFactory(t *testing.T) { t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("a-sufficiently-long-hmac-secret-value-32b", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_HMACSecretExactly32Bytes(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("12345678901234567890123456789012", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_ShortHMACSecret(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("short", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_NoSecretNonKubernetes(t *testing.T) { - t.Parallel() - registry, agg 
:= newSessionFactoryMocks(t) - factory, err := createSessionFactory("", false, registry, agg) - require.NoError(t, err) - require.NotNil(t, factory) -} - -func TestCreateSessionFactory_NoSecretKubernetes(t *testing.T) { - t.Parallel() - registry, agg := newSessionFactoryMocks(t) - factory, err := createSessionFactory("", true, registry, agg) - require.Error(t, err) - require.ErrorContains(t, err, "an HMAC secret is required when running in Kubernetes") - require.Nil(t, factory) + tests := []struct { + name string + useAgg bool + }{ + {name: "with aggregator", useAgg: true}, + {name: "without aggregator", useAgg: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + registry, agg := newSessionFactoryMocks(t) + var aggArg aggregator.Aggregator + if tc.useAgg { + aggArg = agg + } + factory := createSessionFactory(registry, aggArg) + require.NotNil(t, factory) + }) + } } // TestRunDiscovery_KubernetesGroupNotFound exercises the Kubernetes-specific branch diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go index b3d2256d71..d6c4325d45 100644 --- a/pkg/vmcp/server/integration_test.go +++ b/pkg/vmcp/server/integration_test.go @@ -506,8 +506,8 @@ func TestIntegration_AuditLogging(t *testing.T) { // table needed for tool calls and resource reads to be audit-logged correctly. auditSessionFactory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) auditSessionFactory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() diff --git a/pkg/vmcp/server/session_management_integration_test.go b/pkg/vmcp/server/session_management_integration_test.go index 4a8c930d81..3b1eb6eba4 100644 --- a/pkg/vmcp/server/session_management_integration_test.go +++ b/pkg/vmcp/server/session_management_integration_test.go @@ -48,8 +48,8 @@ func newNoopMockFactory(t *testing.T) *sessionfactorymocks.MockMultiSessionFacto t.Helper() ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() @@ -90,13 +90,12 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* t.Helper() state := &mockFactoryState{} factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, identity *auth.Identity, allowAnonymous bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { state.makeWithIDCalled.Store(true) - tokenHash := "" - if identity != nil && identity.Token != "" && !allowAnonymous { - tokenHash = "fake-hash-for-testing" - } + // All sessions carry MetadataKeyIdentityBinding so Terminate takes the + // Phase 2 (storage.Delete) path. The sentinel value is sufficient for + // tests that don't validate the binding content. mock := sessionmocks.NewMockMultiSession(ctrl) mock.EXPECT().ID().Return(id).AnyTimes() mock.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes() @@ -105,7 +104,7 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* mock.EXPECT().GetData().Return(nil).AnyTimes() mock.EXPECT().SetData(gomock.Any()).AnyTimes() mock.EXPECT().GetMetadata().Return(map[string]string{ - vmcpsession.MetadataKeyTokenHash: tokenHash, + vmcpsession.MetadataKeyIdentityBinding: "unauthenticated", }).AnyTimes() mock.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes() toolsCopy := make([]vmcp.Tool, len(tools)) diff --git a/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go b/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go index 3b6395ff72..b131fd10a2 100644 --- a/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go +++ b/pkg/vmcp/server/sessionmanager/horizontal_scaling_integration_test.go @@ -28,9 +28,6 @@ import ( sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -// hmacSecret is a fixed 32-byte secret used across all integration tests. 
-var hmacSecret = []byte("test-hmac-secret-32bytes-exactly") - // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- @@ -63,9 +60,8 @@ func newSharedRedisStorage(t *testing.T, mr *miniredis.Miniredis) transportsessi } // newTestManagerWithSharedStorage creates a Manager backed by the given -// DataStorage, a real session factory with the package-level hmacSecret, and -// an ImmutableRegistry containing backends. Cleanup is registered via -// t.Cleanup. +// DataStorage, a real session factory, and an ImmutableRegistry containing +// backends. Cleanup is registered via t.Cleanup. func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.DataStorage, backends []*vmcp.Backend) *Manager { t.Helper() backendList := make([]vmcp.Backend, len(backends)) @@ -75,7 +71,6 @@ func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.Data registry := vmcp.NewImmutableRegistry(backendList) factory := vmcpsession.NewSessionFactory( newUnauthenticatedAuthRegistry(t), - vmcpsession.WithHMACSecret(hmacSecret), ) sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) require.NoError(t, err) @@ -215,13 +210,26 @@ func TestHorizontalScaling_CrossPodHijackPrevention(t *testing.T) { storage := newSharedRedisStorage(t, mr) backend := startMCPBackend(t, "backend-alpha", "echo") + // Both alice and eve need Claims with iss+sub so the identity-binding + // decorator can extract their (iss, sub) pairs (Token is not used for binding + // in the #5306 model; Claims are the canonical source). 
identity := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, - Token: "alice-bearer-token", + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, } wrongCaller := &auth.Identity{ - PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, - Token: "eve-bearer-token", + PrincipalInfo: auth.PrincipalInfo{ + Subject: "eve", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "eve", + }, + }, } // Pod A: create session bound to alice. diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index 0b01e9b945..1164a8dbc2 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -304,18 +304,11 @@ func (sm *Manager) CreateSession( // Resolve the caller identity (may be nil for anonymous access). identity, _ := auth.IdentityFromContext(ctx) - // Note: Token hash and salt are computed and stored by the session factory - // (MakeSessionWithID below). Token binding enforcement happens at the session - // level via validateCaller(), which uses HMAC-SHA256 with a per-session salt. - // List all available backends from the registry. backends := sm.listAllBackends(ctx) // Build the fully-formed MultiSession using the SDK-assigned session ID. - // Sessions created with an identity are bound to that identity (allowAnonymous=false). - // Sessions created without an identity allow anonymous access (allowAnonymous=true). 
- allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, allowAnonymous, backends) + sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, backends) if err != nil { sm.cleanupFailedPlaceholder(sessionID, placeholder) return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err) @@ -482,7 +475,7 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, fmt.Errorf("Manager.Terminate: failed to load session %q: %w", sessionID, loadErr) } - if _, isFullSession := metadata[sessiontypes.MetadataKeyTokenHash]; isFullSession { + if _, isFullSession := metadata[sessiontypes.MetadataKeyIdentityBinding]; isFullSession { // Phase 2 (full MultiSession): delete from storage. The cache entry will be // evicted lazily on the next Get when checkSession finds the session gone. if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { @@ -701,16 +694,17 @@ func (sm *Manager) loadSession(sessionID string) (vmcpsession.MultiSession, erro } // Don't restore placeholder sessions (Phase 2 never ran). - // PreventSessionHijacking always writes MetadataKeyTokenHash during Phase 2 - // (empty sentinel for anonymous, non-empty hash for authenticated). Its - // absence means Generate() stored this record but CreateSession() never - // completed — treat it as "not found" rather than "corrupted". + // BindSession always writes MetadataKeyIdentityBinding during Phase 2 + // (the unauthenticated sentinel for anonymous sessions, a bound (iss, sub) + // binding for authenticated ones). Its absence means Generate() stored + // this record but CreateSession() never completed — treat it as "not + // found" rather than "corrupted". // // Note: this is intentionally different from RestoreSession's fail-closed // check (absent key → error). 
Here we know a placeholder's empty metadata // is valid storage state produced by Generate(), so we return the // SDK-standard ErrSessionNotFound instead of an error. - if _, hashPresent := metadata[sessiontypes.MetadataKeyTokenHash]; !hashPresent { + if _, bindingPresent := metadata[sessiontypes.MetadataKeyIdentityBinding]; !bindingPresent { return nil, transportsession.ErrSessionNotFound } diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index bffce554f9..8a8fb70cd8 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -65,7 +65,7 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, sess vmcpsession.Mult t.Helper() factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil).AnyTimes() return factory } @@ -75,7 +75,7 @@ func newMockFactoryWithError(t *testing.T, ctrl *gomock.Controller, err error) * t.Helper() factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, err).AnyTimes() return factory } @@ -300,8 +300,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) var createdSess *sessionmocks.MockMultiSession factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = newMockSession(t, ctrl, id, tools) return createdSess, nil }).AnyTimes() @@ -367,8 +367,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factoryCalled := false factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { factoryCalled = true sess := newMockSession(t, ctrl, id, tools) return sess, nil @@ -402,8 +402,8 @@ func TestSessionManager_CreateSession(t *testing.T) { factoryCalled := false factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { factoryCalled = true sess := newMockSession(t, ctrl, id, tools) return sess, nil @@ -440,8 +440,8 @@ func TestSessionManager_CreateSession(t *testing.T) { var createdSess *sessionmocks.MockMultiSession factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { // Sleep to simulate slow backend initialization, creating a window // where the client can terminate the session after the first check passes. time.Sleep(50 * time.Millisecond) @@ -605,18 +605,18 @@ func TestSessionManager_Terminate(t *testing.T) { tools := []vmcp.Tool{{Name: "t1", Description: "tool 1"}} ctrl := gomock.NewController(t) - // tokenHashMeta is carried by the session so CreateSession writes it to + // bindingMeta is carried by the session so CreateSession writes it to // storage and Terminate takes the Phase 2 (storage.Delete) path. - tokenHashMeta := map[string]string{sessiontypes.MetadataKeyTokenHash: ""} + bindingMeta := map[string]string{sessiontypes.MetadataKeyIdentityBinding: "unauthenticated"} var createdSess *sessionmocks.MockMultiSession factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = sessionmocks.NewMockMultiSession(ctrl) createdSess.EXPECT().ID().Return(id).AnyTimes() - createdSess.EXPECT().GetMetadata().Return(tokenHashMeta).AnyTimes() + createdSess.EXPECT().GetMetadata().Return(bindingMeta).AnyTimes() createdSess.EXPECT().Tools().Return(tools).AnyTimes() createdSess.EXPECT().Type().Return(transportsession.SessionType("")).AnyTimes() createdSess.EXPECT().CreatedAt().Return(time.Time{}).AnyTimes() @@ -666,8 +666,8 @@ func TestSessionManager_Terminate(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) // Close is called by onEvict when Terminate removes the cache entry. sess.EXPECT().Close().Return(nil).AnyTimes() @@ -683,10 +683,10 @@ func TestSessionManager_Terminate(t *testing.T) { _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // Seed MetadataKeyIdentityBinding into storage so Terminate recognises this // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. 
_, err = storage.Update(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -847,8 +847,8 @@ func TestSessionManager_GetMultiSession(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) return sess, nil }).Times(1) @@ -873,7 +873,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Cross-pod restore path: session is in storage but not in the in-memory // cache (simulates pod restart or eviction). loadSession is called on Get. - t.Run("restore path: placeholder in storage (absent token hash) is treated as not found", func(t *testing.T) { + t.Run("restore path: placeholder in storage (absent identity binding) is treated as not found", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -885,24 +885,24 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sessionID := "restore-placeholder-session" // Write placeholder metadata directly to storage, bypassing the cache. - // Generate() stores an empty map with no token hash. + // Generate() stores an empty map with no identity binding. _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{}) require.NoError(t, err) - // loadSession detects absent MetadataKeyTokenHash → ErrSessionNotFound. + // loadSession detects absent MetadataKeyIdentityBinding → ErrSessionNotFound. 
multiSess, ok := sm.GetMultiSession(sessionID) assert.False(t, ok, "placeholder should not be restorable") assert.Nil(t, multiSess) }) - t.Run("restore path: fully-initialized zero-backend session (has token hash) is restored", func(t *testing.T) { + t.Run("restore path: fully-initialized zero-backend session (has identity binding) is restored", func(t *testing.T) { t.Parallel() tools := []vmcp.Tool{{Name: "zero-backend-tool", Description: "tool with no backends"}} ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) // MakeSessionWithID is only for Phase 2; unused in the restore path. - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-zero-backend-session" @@ -914,12 +914,13 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) - // Metadata matching what populateBackendMetadata now writes for a - // Phase-2-complete session with zero backends: MetadataKeyBackendIDs - // is always written (empty string for zero backends). + // Metadata matching what BindSession and populateBackendMetadata write + // for a Phase-2-complete anonymous session with zero backends: + // MetadataKeyIdentityBinding holds the unauthenticated sentinel; + // MetadataKeyBackendIDs is always written (empty string for zero backends). 
initializedMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", // anonymous sentinel — present but empty - vmcpsession.MetadataKeyBackendIDs: "", // always written; empty = zero backends + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", // anonymous sentinel + vmcpsession.MetadataKeyBackendIDs: "", // always written; empty = zero backends } _, err := sm.storage.Create(context.Background(), sessionID, initializedMeta) require.NoError(t, err) @@ -941,7 +942,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // RestoreSession. This test documents that backward-compat behaviour. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-legacy-session" @@ -953,9 +954,10 @@ func TestSessionManager_GetMultiSession(t *testing.T) { sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) - // Legacy metadata: token hash present but MetadataKeyBackendIDs absent. + // Metadata with identity binding but MetadataKeyBackendIDs absent + // (sessions written before populateBackendMetadata always wrote the key). legacyMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", // Phase 2 completion marker + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", // Phase 2 completion marker // MetadataKeyBackendIDs intentionally absent (legacy record) } _, err := sm.storage.Create(context.Background(), sessionID, legacyMeta) @@ -976,14 +978,14 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // that stale per-backend session IDs do not persist indefinitely in storage. 
ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-metadata-persist-session" // The restored session returns fresh per-backend session metadata. freshMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", vmcpsession.MetadataKeyBackendIDs: "backend-a", vmcpsession.MetadataKeyBackendSessionPrefix + "backend-a": "fresh-session-id", } @@ -999,7 +1001,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Seed storage with stale per-backend session ID. staleMeta := map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", vmcpsession.MetadataKeyBackendIDs: "backend-a", vmcpsession.MetadataKeyBackendSessionPrefix + "backend-a": "stale-session-id", } @@ -1027,14 +1029,14 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // restored session. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-concurrent-delete-session" restored := sessionmocks.NewMockMultiSession(ctrl) restored.EXPECT().ID().Return(sessionID).AnyTimes() restored.EXPECT().GetMetadata().Return(map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }).AnyTimes() factory.EXPECT(). @@ -1053,7 +1055,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // Seed the inner storage with a valid session record. 
_, err = innerStorage.Create(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -1073,7 +1075,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { // metadata drift on the next liveness check and evict if necessary. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Times(0) sessionID := "restore-update-error-session" @@ -1090,7 +1092,7 @@ func TestSessionManager_GetMultiSession(t *testing.T) { t.Cleanup(func() { _ = cleanup(context.Background()) }) _, err = innerStorage.Create(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -1142,8 +1144,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1202,8 +1204,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). 
- MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1255,8 +1257,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { } factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(callToolResult, nil).Times(1) @@ -1296,8 +1298,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
Return(nil, errors.New("backend exploded")).Times(1) @@ -1328,8 +1330,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -1364,8 +1366,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { var capturedMeta map[string]any factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, _ *auth.Identity, _ string, _ map[string]any, meta map[string]any) (*vmcp.ToolCallResult, error) { @@ -1430,8 +1432,8 @@ func TestSessionManager_GetAdaptedTools(t *testing.T) { authErr := tc.authError factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, tools) sess.EXPECT().CallTool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, authErr).Times(1) @@ -1515,8 +1517,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) // Override default Resources() AnyTimes with a specific return sess.EXPECT().Resources().Return(resources).AnyTimes() @@ -1567,8 +1569,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///data.txt"). @@ -1615,8 +1617,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///broken.txt"). @@ -1662,8 +1664,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///binary.bin"). 
@@ -1725,8 +1727,8 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) sess.EXPECT().Resources().Return(resources).AnyTimes() sess.EXPECT().ReadResource(gomock.Any(), gomock.Any(), "file:///protected.txt"). @@ -1805,8 +1807,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { // Create mock directly (without newMockSession) so there is no // pre-existing Prompts().Return(nil).AnyTimes() that would win // the FIFO expectation race over our specific prompts list. @@ -1866,8 +1868,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -1908,8 +1910,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -1959,8 +1961,8 @@ func TestSessionManager_GetAdaptedPrompts(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() sess.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() @@ -2013,8 +2015,8 @@ func TestSessionManager_DecorateSession(t *testing.T) { ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { return newMockSession(t, ctrl, id, tools), nil }).Times(1) @@ -2076,17 +2078,17 @@ func TestSessionManager_DecorateSession(t *testing.T) { // decorator fn, so the re-check that follows fn() sees it is gone. ctrl := gomock.NewController(t) factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - // The mock session carries MetadataKeyTokenHash so that: + // The mock session carries MetadataKeyIdentityBinding so that: // 1. CreateSession stores it in storage (via sess.GetMetadata()), keeping // cache and storage in sync for checkSession's maps.Equal comparison. // 2. Terminate sees the key and takes the Phase 2 path (storage.Delete). - tokenHashMeta := map[string]string{sessiontypes.MetadataKeyTokenHash: ""} + bindingMeta := map[string]string{sessiontypes.MetadataKeyIdentityBinding: "unauthenticated"} factory.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := sessionmocks.NewMockMultiSession(ctrl) sess.EXPECT().ID().Return(id).AnyTimes() - sess.EXPECT().GetMetadata().Return(tokenHashMeta).AnyTimes() + sess.EXPECT().GetMetadata().Return(bindingMeta).AnyTimes() sess.EXPECT().Close().Return(nil).AnyTimes() // Other methods called by the session manager infrastructure. sess.EXPECT().Type().Return(transportsession.SessionType("")).AnyTimes() @@ -2136,7 +2138,7 @@ func TestSessionManager_CheckSession(t *testing.T) { t.Helper() ctrl := gomock.NewController(t) f := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) - f.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + f.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes().Return(nil, nil) f.EXPECT().RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes().Return(nil, nil) @@ -2450,10 +2452,10 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // Seed MetadataKeyIdentityBinding into storage so Terminate recognises this // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. 
_, err = storage.Update(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", }) require.NoError(t, err) @@ -2577,6 +2579,179 @@ func TestNotifyBackendExpired(t *testing.T) { }) } +// --------------------------------------------------------------------------- +// Tests: Phase 2 marker migration (#5306) +// --------------------------------------------------------------------------- + +// TestLoadSession_Phase2Marker_UsesIdentityBindingKey documents that the +// Phase-2 detection key for the restore path is MetadataKeyIdentityBinding, +// not the legacy MetadataKeyTokenHash. Sessions stored with only the legacy +// key are treated as not found (ErrSessionNotFound) and the client must +// re-initialize. Sessions stored with MetadataKeyIdentityBinding are restored +// normally. +func TestLoadSession_Phase2Marker_UsesIdentityBindingKey(t *testing.T) { + t.Parallel() + + t.Run("legacy session (only MetadataKeyTokenHash) returns not found", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + // RestoreSession must NOT be called for legacy sessions on the restore path. + factory.EXPECT().RestoreSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + sessionID := "legacy-only-token-hash-session" + // Seed storage with only the legacy key — no MetadataKeyIdentityBinding. + _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + }) + require.NoError(t, err) + + // loadSession must treat absent MetadataKeyIdentityBinding as a legacy session + // and return (nil, false) — not attempting RestoreSession. 
+ multiSess, ok := sm.GetMultiSession(sessionID) + assert.False(t, ok, "legacy session with only MetadataKeyTokenHash must not be restored") + assert.Nil(t, multiSess) + }) + + t.Run("session with MetadataKeyIdentityBinding is restored normally", func(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{{Name: "restored-tool", Description: "a restored tool"}} + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT().MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + + sessionID := "identity-binding-session" + restored := newMockSession(t, ctrl, sessionID, tools) + + factory.EXPECT(). + RestoreSession(gomock.Any(), sessionID, gomock.Any(), gomock.Any()). + Return(restored, nil).Times(1) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + _, err := sm.storage.Create(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", + vmcpsession.MetadataKeyBackendIDs: "", + }) + require.NoError(t, err) + + multiSess, ok := sm.GetMultiSession(sessionID) + require.True(t, ok, "session with MetadataKeyIdentityBinding must be restorable") + require.NotNil(t, multiSess) + assert.Equal(t, sessionID, multiSess.ID()) + }) +} + +// TestTerminate_Phase2DetectionUsesIdentityBindingKey verifies that Terminate +// uses MetadataKeyIdentityBinding (not the legacy MetadataKeyTokenHash) to +// distinguish Phase 2 sessions (full MultiSession → Delete) from Phase 1 +// placeholders (→ mark terminated). 
+func TestTerminate_Phase2DetectionUsesIdentityBindingKey(t *testing.T) { + t.Parallel() + + t.Run("session with MetadataKeyIdentityBinding takes delete path", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "s", nil) + sess.EXPECT().Close().Return(nil).AnyTimes() + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // Write MetadataKeyIdentityBinding into storage to simulate a Phase 2 session. + _, err := storage.Update(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyIdentityBinding: "unauthenticated", + }) + require.NoError(t, err) + + // Terminate must take the Phase 2 path: storage.Delete (not marked terminated). + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must be deleted from storage, not just marked terminated. + _, loadErr := storage.Load(context.Background(), sessionID) + assert.ErrorIs(t, loadErr, transportsession.ErrSessionNotFound, + "Phase 2 Terminate must delete the session from storage") + }) + + t.Run("placeholder without MetadataKeyIdentityBinding takes mark-terminated path", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "", nil) + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // No MetadataKeyIdentityBinding in storage — this is a Phase 1 placeholder. + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must remain in storage but marked as terminated (not deleted). 
+ metadata, loadErr := storage.Load(context.Background(), sessionID) + require.NoError(t, loadErr, "placeholder must remain in storage (TTL will clean it)") + assert.Equal(t, MetadataValTrue, metadata[MetadataKeyTerminated], + "Phase 1 Terminate must mark the session terminated, not delete it") + }) +} + +// TestTerminate_LegacyFormatSession_TakesPlaceholderPath is a B5 documentation +// test. It verifies that a legacy session stored in Redis with only the old +// MetadataKeyTokenHash key (no MetadataKeyIdentityBinding) causes Terminate to +// take the placeholder (mark-terminated) path rather than deleting the session. +// +// This is intentional: without the identity binding key, the Manager cannot +// tell whether the record is a real Phase-2 session or a corrupted/partial +// record. Treating it as a placeholder (soft termination) is safe — the TTL +// will eventually clean it up. The comment at session_manager.go line 485 +// references this test. +func TestTerminate_LegacyFormatSession_TakesPlaceholderPath(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sess := newMockSession(t, ctrl, "", nil) + factory := newMockFactory(t, ctrl, sess) + registry := newFakeRegistry() + sm, storage := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + + // Seed storage with only the legacy key — simulates a pre-#5306 session in Redis. + // MetadataKeyIdentityBinding is absent. + _, err := storage.Update(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + }) + require.NoError(t, err) + + // Terminate must take the placeholder path: mark as terminated, NOT delete. + isNotAllowed, err := sm.Terminate(sessionID) + require.NoError(t, err) + assert.False(t, isNotAllowed) + + // Session must still exist in storage — marked terminated, not deleted. + // (Storage cleanup happens via TTL or the next GET → checkSession → eviction.) 
+ metadata, loadErr := storage.Load(context.Background(), sessionID) + require.NoError(t, loadErr, + "legacy session must remain in storage after Terminate (not deleted)") + assert.Equal(t, MetadataValTrue, metadata[MetadataKeyTerminated], + "legacy session Terminate must set MetadataKeyTerminated rather than deleting") +} + // --------------------------------------------------------------------------- // Helper // --------------------------------------------------------------------------- diff --git a/pkg/vmcp/server/telemetry_integration_test.go b/pkg/vmcp/server/telemetry_integration_test.go index 850a350457..70c7a76f8f 100644 --- a/pkg/vmcp/server/telemetry_integration_test.go +++ b/pkg/vmcp/server/telemetry_integration_test.go @@ -111,7 +111,7 @@ func newBackendAwareTestFactory(tools []vmcp.Tool, rt *vmcp.RoutingTable) (*back } func (f *backendAwareTestFactory) MakeSessionWithID( - _ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend, + _ context.Context, id string, _ *auth.Identity, _ []*vmcp.Backend, ) (vmcpsession.MultiSession, error) { return &backendAwareTestSession{ Session: transportsession.NewStreamableSession(id), diff --git a/pkg/vmcp/server/testfactory_test.go b/pkg/vmcp/server/testfactory_test.go index aba748c299..e9b7e62145 100644 --- a/pkg/vmcp/server/testfactory_test.go +++ b/pkg/vmcp/server/testfactory_test.go @@ -27,7 +27,7 @@ type minimalTestFactory struct{} var _ vmcpsession.MultiSessionFactory = (*minimalTestFactory)(nil) func (*minimalTestFactory) MakeSessionWithID( - _ context.Context, _ string, _ *auth.Identity, _ bool, _ []*vmcp.Backend, + _ context.Context, _ string, _ *auth.Identity, _ []*vmcp.Backend, ) (vmcpsession.MultiSession, error) { return nil, fmt.Errorf("minimalTestFactory: MakeSessionWithID not implemented in test helper") } diff --git a/pkg/vmcp/session/binding/binding.go b/pkg/vmcp/session/binding/binding.go new file mode 100644 index 0000000000..8ba3346a8f --- /dev/null +++ 
b/pkg/vmcp/session/binding/binding.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package binding is the single owner of the identity-binding format used by +// vMCP session storage. An identity binding encodes the OIDC principal that +// created a session as a single opaque string suitable for storage in a +// key/value store such as Redis or Valkey. +// +// # Format +// +// A bound identity is encoded as iss + "\x00" + sub. NUL is rejected from +// either half by Format and Parse: OIDC Core does not formally forbid NUL in +// sub, but no real-world issuer emits one, and accepting it would let a +// corrupted or adversarial value re-split during Parse. +// +// Sessions created before any auth middleware ran use the literal +// "unauthenticated" sentinel. The sentinel must not contain '\x00' so it +// cannot collide with any bound form; any future format change must preserve +// that disjointness. +// +// # Trust boundary +// +// Bindings are stored plaintext at rest. They are PII but not credentials — +// they identify but do not authenticate. They carry no freshness signal (no +// exp, no nonce) and are NOT a substitute for token validation: callers must +// only compare a stored binding against a freshly-validated (iss, sub) pair +// from the current request's token. +package binding + +import ( + "errors" + "strings" +) + +// UnauthenticatedSentinel is the binding value stored for sessions that were +// created without an authenticated identity (auth middleware not present or +// identity nil). +const UnauthenticatedSentinel = "unauthenticated" + +// ErrInvalidBinding is returned by Format when either input is empty or +// contains a NUL byte. +var ErrInvalidBinding = errors.New("invalid identity binding") + +// Format returns the canonical on-the-wire form of an identity binding: +// iss + "\x00" + sub. Returns ErrInvalidBinding when either input is empty +// or contains a NUL byte. 
+func Format(iss, sub string) (string, error) { + if iss == "" || sub == "" { + return "", ErrInvalidBinding + } + if strings.ContainsRune(iss, '\x00') || strings.ContainsRune(sub, '\x00') { + return "", ErrInvalidBinding + } + return iss + "\x00" + sub, nil +} + +// Parse splits an on-the-wire binding into its (iss, sub) components. +// Returns ok=true only when s contains exactly one NUL and both halves are +// non-empty. Returns ok=false for the unauthenticated sentinel, for malformed +// input, and for empty strings. Callers must check ok; the empty-string +// return values are not meaningful when ok=false. +func Parse(s string) (iss, sub string, ok bool) { + iss, sub, found := strings.Cut(s, "\x00") + if !found || iss == "" || sub == "" { + return "", "", false + } + // strings.Cut splits on the first NUL, so iss cannot contain one. Sub may + // still contain further NULs from a malformed input; reject those. + if strings.ContainsRune(sub, '\x00') { + return "", "", false + } + return iss, sub, true +} + +// IsUnauthenticated reports whether s is the literal unauthenticated sentinel. +func IsUnauthenticated(s string) bool { + return s == UnauthenticatedSentinel +} diff --git a/pkg/vmcp/session/binding/binding_test.go b/pkg/vmcp/session/binding/binding_test.go new file mode 100644 index 0000000000..67048dd1c7 --- /dev/null +++ b/pkg/vmcp/session/binding/binding_test.go @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0 + +package binding_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" +) + +func TestFormat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + iss string + sub string + want string + wantErr bool + }{ + { + name: "typical OIDC issuer and subject", + iss: "https://idp.example", + sub: "user-42", + want: "https://idp.example\x00user-42", + }, + { + name: "issuer with path", + iss: "https://accounts.google.com/o/oauth2", + sub: "109876543210987654321", + want: "https://accounts.google.com/o/oauth2\x00109876543210987654321", + }, + { + name: "minimal non-empty inputs", + iss: "i", + sub: "s", + want: "i\x00s", + }, + { + name: "email-like subject", + iss: "https://auth.example.com", + sub: "alice@example.com", + want: "https://auth.example.com\x00alice@example.com", + }, + {name: "empty iss", iss: "", sub: "sub", wantErr: true}, + {name: "empty sub", iss: "iss", sub: "", wantErr: true}, + {name: "NUL in iss", iss: "a\x00b", sub: "sub", wantErr: true}, + {name: "NUL in sub", iss: "iss", sub: "a\x00b", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := binding.Format(tt.iss, tt.sub) + if tt.wantErr { + assert.Empty(t, got) + require.ErrorIs(t, err, binding.ErrInvalidBinding) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + + gotIss, gotSub, ok := binding.Parse(got) + require.True(t, ok, "Parse must return ok=true for a value produced by Format") + assert.Equal(t, tt.iss, gotIss) + assert.Equal(t, tt.sub, gotSub) + }) + } +} + +func TestParse_RejectsInvalid(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + }{ + {name: "empty string", input: ""}, + {name: "no NUL separator", input: "nothing-here"}, + {name: "empty iss half", input: "\x00sub"}, + {name: "empty sub half", 
input: "iss\x00"}, + // Defense-in-depth: a value that splits into three parts on NUL must be + // rejected so a corrupt or adversarial store write cannot smuggle extra + // data past Parse. + {name: "trailing NUL with extra data", input: "iss\x00sub\x00trailer"}, + // Callers must call IsUnauthenticated before Parse on the restore path; + // the sentinel must not accidentally parse as a bound binding. + {name: "unauthenticated sentinel", input: binding.UnauthenticatedSentinel}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + iss, sub, ok := binding.Parse(tt.input) + assert.False(t, ok) + assert.Empty(t, iss) + assert.Empty(t, sub) + }) + } +} + +func TestIsUnauthenticated(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + {name: "sentinel", input: binding.UnauthenticatedSentinel, want: true}, + {name: "empty string", input: "", want: false}, + {name: "arbitrary string", input: "foo", want: false}, + {name: "bound binding", input: "iss\x00sub", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, binding.IsUnauthenticated(tt.input)) + }) + } +} diff --git a/pkg/vmcp/session/connector_integration_test.go b/pkg/vmcp/session/connector_integration_test.go index 3fc0e2d58f..c66b1bdeaf 100644 --- a/pkg/vmcp/session/connector_integration_test.go +++ b/pkg/vmcp/session/connector_integration_test.go @@ -5,7 +5,6 @@ package session import ( "context" - "encoding/hex" "net/http" "net/http/httptest" "sync" @@ -22,7 +21,7 @@ import ( vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/session/internal/security" + internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" sessiontypes 
"github.com/stacklok/toolhive/pkg/vmcp/session/types" ) @@ -115,7 +114,7 @@ func TestSessionFactory_Integration_CapabilityDiscovery(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -144,7 +143,7 @@ func TestSessionFactory_Integration_CallTool(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -167,7 +166,7 @@ func TestSessionFactory_Integration_ReadResource(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -190,7 +189,7 @@ func TestSessionFactory_Integration_GetPrompt(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -219,7 +218,7 @@ func 
TestSessionFactory_Integration_MultipleBackends(t *testing.T) { } factory := NewSessionFactory(newUnauthenticatedRegistry(t)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) @@ -229,21 +228,46 @@ func TestSessionFactory_Integration_MultipleBackends(t *testing.T) { } // --------------------------------------------------------------------------- -// Token-binding integration tests — HMAC rejection for ReadResource / GetPrompt +// Identity-binding integration tests — CallerRejection for ReadResource / GetPrompt // --------------------------------------------------------------------------- // TestTokenBinding_CallerRejection verifies that the hijack-prevention decorator // is applied to all three protected methods (CallTool, ReadResource, GetPrompt): -// each rejects a wrong token (ErrUnauthorizedCaller) and a nil caller -// (ErrNilCaller) before any backend routing occurs, so nilBackendConnector suffices. -func TestTokenBinding_CallerRejection(t *testing.T) { +// each rejects a wrong caller (ErrUnauthorizedCaller) and a nil caller +// (ErrNilCaller) before any backend routing occurs. +// +// The identity binding is derived from Claims["iss"] and Claims["sub"] per the +// new #5306 model (no HMAC secret required). +func TestIdentityBinding_CallerRejection(t *testing.T) { t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: "alice-token"} - wrongCaller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "bob"}, Token: "wrong-token"} + // Both alice and bob need valid Claims so the binding decorator can extract + // their (iss, sub) pairs. alice creates the session; bob is the wrong caller. 
+ alice := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } + wrongCaller := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "bob", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "bob", + }, + }, + } - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) + // No backend connector needed: auth validation fires before any routing. + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), alice, nil) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) @@ -266,7 +290,7 @@ func TestTokenBinding_CallerRejection(t *testing.T) { } for _, fn := range callFns { - t.Run(fn.name+"/wrong token", func(t *testing.T) { + t.Run(fn.name+"/wrong caller", func(t *testing.T) { t.Parallel() assert.ErrorIs(t, fn.call(wrongCaller), sessiontypes.ErrUnauthorizedCaller) }) @@ -280,7 +304,7 @@ func TestTokenBinding_CallerRejection(t *testing.T) { // TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend verifies that a // bound session accepts ReadResource and GetPrompt calls from the correct caller // when a real backend is connected. 
-func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { +func TestIdentityBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { t.Parallel() baseURL := startInProcessMCPServer(t) @@ -291,15 +315,24 @@ func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { TransportType: "streamable-http", } - const rawToken = "alice-real-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} + // Identity with Claims so binding.Format(iss, sub) succeeds and the session + // is bound to the caller identity. + identity := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } - factory := NewSessionFactory(newUnauthenticatedRegistry(t), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, []*vmcp.Backend{backend}) + factory := NewSessionFactory(newUnauthenticatedRegistry(t)) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, sess.Close()) }) - t.Run("allows ReadResource with correct token", func(t *testing.T) { + t.Run("allows ReadResource with correct caller", func(t *testing.T) { t.Parallel() result, err := sess.ReadResource(context.Background(), identity, "test://data") require.NoError(t, err) @@ -308,7 +341,7 @@ func TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { assert.Equal(t, "hello", result.Contents[0].Text) }) - t.Run("allows GetPrompt with correct token", func(t *testing.T) { + t.Run("allows GetPrompt with correct caller", func(t *testing.T) { t.Parallel() result, err := sess.GetPrompt(context.Background(), identity, "greet", nil) require.NoError(t, err) @@ -319,71 +352,43 @@ func 
TestTokenBinding_ReadResource_And_GetPrompt_WithRealBackend(t *testing.T) { }) } -// TestTokenBinding_DifferentSecretsProduceDifferentHashes verifies that two -// session factories configured with different HMAC secrets store different token -// hashes for the same raw bearer token. This is the key isolation property that -// prevents sessions from one secret epoch from being validated against another. -func TestTokenBinding_DifferentSecretsProduceDifferentHashes(t *testing.T) { - t.Parallel() - - const rawToken = "shared-token-same-for-both" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: rawToken} - - factoryA := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("secret-A-exactly-32-bytes-long!!"))) - factoryB := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("secret-B-exactly-32-bytes-long!!"))) - - sessA, err := factoryA.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessA.Close() }) - - sessB, err := factoryB.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessB.Close() }) - - hashA := sessA.GetMetadata()[MetadataKeyTokenHash] - hashB := sessB.GetMetadata()[MetadataKeyTokenHash] - - assert.NotEmpty(t, hashA) - assert.NotEmpty(t, hashB) - assert.NotEqual(t, hashA, hashB, - "different HMAC secrets must produce different token hashes for the same input token") -} - -// TestRestoreHijackPrevention_Integration_RoundTrip verifies the full +// TestIdentityBinding_RestoreSession_RoundTrip verifies the full // store-then-restore flow across a real factory-created session: // -// 1. Create a session via the factory (writes tokenHash + tokenSalt to metadata). -// 2. Extract the persisted values. -// 3. Wrap a fresh base session with RestoreHijackPrevention using those values. -// 4. 
Confirm the restored decorator accepts the original token and rejects others. -func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { +// 1. Create a session via the factory (writes MetadataKeyIdentityBinding to metadata). +// 2. Extract the persisted binding. +// 3. Restore the session via RestoreSession using the persisted metadata. +// 4. Confirm the restored decorator accepts the original caller and rejects others. +func TestIdentityBinding_RestoreSession_RoundTrip(t *testing.T) { t.Parallel() - const rawToken = "integration-token" - hmacSecret := []byte("test-hmac-secret-exactly-32bytes") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} + identity := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "alice", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }, + }, + } - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(hmacSecret)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) + connector := func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } + factory := newSessionFactoryWithConnector(connector) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, nil) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) // Extract persisted values — these simulate what would be read back from Redis. meta := sess.GetMetadata() - persistedHash := meta[MetadataKeyTokenHash] - persistedSalt := meta[sessiontypes.MetadataKeyTokenSalt] - require.NotEmpty(t, persistedHash, "factory must write tokenHash to metadata") - require.NotEmpty(t, persistedSalt, "factory must write tokenSalt to metadata") - - // Simulate "Pod B": restore the decorator from persisted metadata. 
- // We use a nil-connector session as the inner session (no real backend needed - // to test auth path). - innerSess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = innerSess.Close() }) + storedBinding := meta[MetadataKeyIdentityBinding] + require.NotEmpty(t, storedBinding, "factory must write MetadataKeyIdentityBinding to metadata") - restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, hmacSecret) + // Simulate "Pod B": restore the session from persisted metadata. + restored, err := factory.RestoreSession(context.Background(), uuid.New().String(), meta, nil) require.NoError(t, err) + t.Cleanup(func() { _ = restored.Close() }) ctx := context.Background() @@ -393,8 +398,16 @@ func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { require.NotErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) require.NotErrorIs(t, err, sessiontypes.ErrNilCaller) - // A different caller is rejected at the auth layer — before any backend routing. - wrongCaller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, Token: "eve-token"} + // A different caller is rejected at the auth layer. + wrongCaller := &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: "eve", + Claims: map[string]any{ + "iss": "https://idp.example", + "sub": "eve", + }, + }, + } _, err = restored.CallTool(ctx, wrongCaller, "any-tool", nil, nil) require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) @@ -403,70 +416,6 @@ func TestRestoreHijackPrevention_Integration_RoundTrip(t *testing.T) { require.ErrorIs(t, err, sessiontypes.ErrNilCaller) } -// TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch verifies -// that a session restored on a replica with a different HMAC secret rejects -// the original caller's token, documenting the operational requirement that -// all replicas must share the same secret. 
-func TestRestoreHijackPrevention_Integration_CrossReplicaSecretMismatch(t *testing.T) { - t.Parallel() - - secretA := []byte("secret-A-exactly-32-bytes-long!!") - secretB := []byte("secret-B-exactly-32-bytes-long!!") - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: "alice-token"} - - // Pod A creates the session with secretA, persisting the hash. - factoryA := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretA)) - sessA, err := factoryA.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sessA.Close() }) - - persistedHash := sessA.GetMetadata()[MetadataKeyTokenHash] - persistedSalt := sessA.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - - // Pod B restores with secretB — the persisted hash was computed with secretA, - // so validation will produce a different HMAC and reject the caller. - factoryB := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretB)) - innerSess, err := factoryB.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = innerSess.Close() }) - - restored, err := security.RestoreHijackPrevention(innerSess, persistedHash, persistedSalt, secretB) - require.NoError(t, err) - - _, err = restored.CallTool(context.Background(), identity, "any-tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, - "cross-replica secret mismatch must reject the original caller") -} - -// TestTokenBinding_MetadataEncoding verifies that the token hash and salt stored -// in session metadata are valid hex strings of the expected lengths: -// - token hash: 64 hex chars (32-byte HMAC-SHA256) -// - token salt: 32 hex chars (16-byte random salt) -func TestTokenBinding_MetadataEncoding(t *testing.T) { - t.Parallel() - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: 
"test-token-123"} - - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret([]byte("test-hmac-secret-exactly-32bytes"))) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - t.Cleanup(func() { _ = sess.Close() }) - - tokenHash := sess.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, tokenHash) - assert.Len(t, tokenHash, 64, "HMAC-SHA256 hex-encoded hash must be 64 characters") - hashBytes, err := hex.DecodeString(tokenHash) - require.NoError(t, err, "token hash must be valid hex") - assert.Len(t, hashBytes, 32, "decoded token hash must be 32 bytes") - - tokenSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.NotEmpty(t, tokenSalt) - saltBytes, err := hex.DecodeString(tokenSalt) - require.NoError(t, err, "token salt must be valid hex") - assert.Len(t, saltBytes, 16, "decoded token salt must be 16 bytes") -} - // startInProcessMCPServerWithHeaderCapture starts an in-process MCP server and // returns the base URL along with a function that returns all Mcp-Session-Id // header values received by the server from clients. @@ -529,7 +478,7 @@ func TestSessionFactory_Integration_RestoreSession_SendsStoredSessionHintToBacke // Create the original session — the backend assigns a session ID over // streamable-HTTP and we store it in metadata. 
- orig, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + orig, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) t.Cleanup(func() { _ = orig.Close() }) diff --git a/pkg/vmcp/session/decorating_factory.go b/pkg/vmcp/session/decorating_factory.go index f536e9ef0f..2254634a23 100644 --- a/pkg/vmcp/session/decorating_factory.go +++ b/pkg/vmcp/session/decorating_factory.go @@ -51,10 +51,9 @@ func (f *decoratingMultiSessionFactory) MakeSessionWithID( ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) { - sess, err := f.base.MakeSessionWithID(ctx, id, identity, allowAnonymous, backends) + sess, err := f.base.MakeSessionWithID(ctx, id, identity, backends) if err != nil { return nil, err } diff --git a/pkg/vmcp/session/decorating_factory_test.go b/pkg/vmcp/session/decorating_factory_test.go index 361ad261e9..ae4776ed85 100644 --- a/pkg/vmcp/session/decorating_factory_test.go +++ b/pkg/vmcp/session/decorating_factory_test.go @@ -34,7 +34,7 @@ func TestNewDecoratingFactory_DecoratorsAppliedInOrder(t *testing.T) { sess := sessionmocks.NewMockMultiSession(ctrl) base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
Return(sess, nil) var order []int @@ -48,7 +48,7 @@ func TestNewDecoratingFactory_DecoratorsAppliedInOrder(t *testing.T) { } factory := session.NewDecoratingFactory(base, dec1, dec2) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.NoError(t, err) assert.Equal(t, []int{1, 2}, order) } @@ -62,7 +62,7 @@ func TestNewDecoratingFactory_DecoratorError_ClosesSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) decErr := errors.New("decorator boom") @@ -70,7 +70,7 @@ func TestNewDecoratingFactory_DecoratorError_ClosesSession(t *testing.T) { return nil, decErr }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.ErrorIs(t, err, decErr) } @@ -85,7 +85,7 @@ func TestNewDecoratingFactory_SecondDecoratorError_ClosesCurrentSession(t *testi base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
Return(sess, nil) decErr := errors.New("second decorator boom") @@ -93,7 +93,7 @@ func TestNewDecoratingFactory_SecondDecoratorError_ClosesCurrentSession(t *testi dec2 := func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return nil, decErr } factory := session.NewDecoratingFactory(base, dec1, dec2) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.ErrorIs(t, err, decErr) } @@ -106,14 +106,14 @@ func TestNewDecoratingFactory_NilReturnWithNoError_ClosesSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return nil, nil // buggy decorator }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "nil session") } @@ -127,7 +127,7 @@ func TestNewDecoratingFactory_CloseErrorOnDecoratorFailure_DoesNotSuppressOrigin base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
Return(sess, nil) decErr := errors.New("decorator error") @@ -135,7 +135,7 @@ func TestNewDecoratingFactory_CloseErrorOnDecoratorFailure_DoesNotSuppressOrigin return nil, decErr }) - _, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + _, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) // The original decorator error, not the close error, is returned. require.ErrorIs(t, err, decErr) } @@ -149,14 +149,14 @@ func TestNewDecoratingFactory_HappyPath_ReturnsFinalSession(t *testing.T) { base := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) base.EXPECT(). - MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(sess, nil) factory := session.NewDecoratingFactory(base, func(_ context.Context, _ session.MultiSession) (session.MultiSession, error) { return finalSess, nil }, ) - got, err := factory.MakeSessionWithID(context.Background(), "id", nil, true, nil) + got, err := factory.MakeSessionWithID(context.Background(), "id", nil, nil) require.NoError(t, err) assert.Equal(t, finalSess, got) } diff --git a/pkg/vmcp/session/default_session_test.go b/pkg/vmcp/session/default_session_test.go index a0055f57e9..a2848748e8 100644 --- a/pkg/vmcp/session/default_session_test.go +++ b/pkg/vmcp/session/default_session_test.go @@ -549,7 +549,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) @@ -567,9 +567,9 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - s1, err := 
factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + s1, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) - s2, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + s2, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) assert.NotEqual(t, s1.ID(), s2.ID()) @@ -582,7 +582,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(successConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, nil) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, nil) require.NoError(t, err) require.NotNil(t, sess) @@ -598,7 +598,7 @@ func TestNewSessionFactory_MakeSession(t *testing.T) { factory := newSessionFactoryWithConnector(successConnector) // Mix of valid and nil entries; nil must not cause a panic. 
backends := []*vmcp.Backend{nil, backend, nil} - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) require.NotNil(t, sess) @@ -626,7 +626,7 @@ func TestNewSessionFactory_PartialInitialisation(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err, "partial init must not return an error") require.NotNil(t, sess) @@ -685,7 +685,7 @@ func TestNewSessionFactory_ConnectorReturnsNilWithoutError(t *testing.T) { } factory := newSessionFactoryWithConnector(wrappedConnector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err) require.NotNil(t, sess) assert.Empty(t, sess.Tools()) @@ -712,7 +712,7 @@ func TestNewSessionFactory_ConnectorReturnsConnWithError(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "partial failure must not abort the session") require.NotNil(t, sess) assert.Empty(t, sess.Tools()) @@ -741,7 +741,7 @@ func TestNewSessionFactory_CapabilityNameConflictIsResolvedDeterministically(t * } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := 
factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) require.NotNil(t, sess) defer func() { require.NoError(t, sess.Close()) }() @@ -771,7 +771,7 @@ func TestNewSessionFactory_AllBackendsFail(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "all-fail must still return a valid (empty) session") require.NotNil(t, sess) @@ -795,7 +795,7 @@ func TestNewSessionFactory_BackendInitTimeout(t *testing.T) { } factory := newSessionFactoryWithConnector(connector, WithBackendInitTimeout(50*time.Millisecond)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, []*vmcp.Backend{backend}) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, []*vmcp.Backend{backend}) require.NoError(t, err, "timeout is a partial failure, not a hard error") require.NotNil(t, sess) @@ -844,7 +844,7 @@ func TestNewSessionFactory_ParallelInit(t *testing.T) { } factory := newSessionFactoryWithConnector(connector, WithMaxBackendInitConcurrency(3)) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), nil, backends) require.NoError(t, err) // All backends must have been initialised. 
@@ -872,46 +872,54 @@ func TestNewSessionFactory_MakeSession_Metadata(t *testing.T) { } tests := []struct { - name string - connector backendConnector - identity *auth.Identity - backends []*vmcp.Backend - wantSubject string // non-empty → assert equal; empty → assert key absent - wantBackendIDs string // always asserted equal (key is always written, "" for zero backends) + name string + connector backendConnector + identity *auth.Identity + backends []*vmcp.Backend + wantIdentityBinding string // expected MetadataKeyIdentityBinding value + wantBackendIDs string // always asserted equal (key is always written, "" for zero backends) }{ { - name: "sets identity subject and backend IDs", - connector: successConnector, - identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user-123"}}, - backends: []*vmcp.Backend{backend1}, - wantSubject: "user-123", - wantBackendIDs: "b1", + // Identity with Subject but no Claims — no extractable (iss, sub) + // pair → binding is the unauthenticated sentinel. + name: "writes unauthenticated sentinel when identity has no iss/sub claims", + connector: successConnector, + identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user-123"}}, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "omits subject when identity is nil", - connector: successConnector, - identity: nil, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "b1", + // Nil identity → unauthenticated sentinel. 
+ name: "writes unauthenticated sentinel when identity is nil", + connector: successConnector, + identity: nil, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "omits subject when subject is empty", - connector: successConnector, - identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: ""}}, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "b1", + // Empty Subject, no Claims → unauthenticated sentinel. + name: "writes unauthenticated sentinel when identity is empty", + connector: successConnector, + identity: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: ""}}, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1", }, { - name: "backend IDs are sorted", - connector: successConnector, - backends: []*vmcp.Backend{backend2, backend1}, // intentionally reversed - wantBackendIDs: "b1,b2", + name: "backend IDs are sorted", + connector: successConnector, + backends: []*vmcp.Backend{backend2, backend1}, // intentionally reversed + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "b1,b2", }, { - name: "writes empty backend IDs when no backends connect", - connector: failConnector, - backends: []*vmcp.Backend{backend1}, - wantBackendIDs: "", // key present, value empty — explicit zero-backend sentinel + name: "writes empty backend IDs when no backends connect", + connector: failConnector, + backends: []*vmcp.Backend{backend1}, + wantIdentityBinding: "unauthenticated", + wantBackendIDs: "", // key present, value empty — explicit zero-backend sentinel }, } @@ -920,19 +928,17 @@ func TestNewSessionFactory_MakeSession_Metadata(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(tt.connector) - sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), tt.identity, true, tt.backends) + sess, err := factory.MakeSessionWithID(context.Background(), uuid.New().String(), tt.identity, 
tt.backends) require.NoError(t, err) require.NotNil(t, sess) defer func() { require.NoError(t, sess.Close()) }() meta := sess.GetMetadata() - if tt.wantSubject != "" { - assert.Equal(t, tt.wantSubject, meta[MetadataKeyIdentitySubject]) - } else { - _, ok := meta[MetadataKeyIdentitySubject] - assert.False(t, ok, "identity subject key should be absent") - } + // MetadataKeyIdentityBinding is always written by BindSession. + bindingVal, bindingPresent := meta[MetadataKeyIdentityBinding] + assert.True(t, bindingPresent, "MetadataKeyIdentityBinding must always be written") + assert.Equal(t, tt.wantIdentityBinding, bindingVal) // MetadataKeyBackendIDs is always written (even "" for zero backends). backendIDsVal, backendIDsPresent := meta[MetadataKeyBackendIDs] @@ -1136,11 +1142,11 @@ func TestMakeSessionWithID_InvalidIDReturnsError(t *testing.T) { return nil, nil, nil }) - _, err := f.MakeSessionWithID(context.Background(), "", nil, true, nil) + _, err := f.MakeSessionWithID(context.Background(), "", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "must not be empty") - _, err = f.MakeSessionWithID(context.Background(), "bad id", nil, true, nil) + _, err = f.MakeSessionWithID(context.Background(), "bad id", nil, nil) require.Error(t, err) assert.Contains(t, err.Error(), "invalid character") } diff --git a/pkg/vmcp/session/factory.go b/pkg/vmcp/session/factory.go index a5e115eacd..a7ca9803f2 100644 --- a/pkg/vmcp/session/factory.go +++ b/pkg/vmcp/session/factory.go @@ -19,6 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" "github.com/stacklok/toolhive/pkg/vmcp/session/internal/security" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" @@ -28,11 +29,6 @@ const ( defaultMaxBackendInitConcurrency = 10 
defaultBackendInitTimeout = 30 * time.Second - // MetadataKeyIdentitySubject is the transport-session metadata key that - // holds the subject claim of the authenticated caller (identity.Subject). - // Set at session creation; empty for anonymous callers. - MetadataKeyIdentitySubject = "vmcp.identity.subject" - // MetadataKeyBackendIDs is the transport-session metadata key that holds // a comma-separated, sorted list of successfully-connected backend IDs. // The key is always written, even as an empty string for zero-backend @@ -46,18 +42,6 @@ const ( MetadataKeyBackendSessionPrefix = "vmcp.backend.session." ) -var ( - // defaultHMACSecret is the fallback HMAC secret used when WithHMACSecret is not provided. - // WARNING: This is INSECURE and should ONLY be used for testing/development. - // Production deployments MUST provide a secure secret via WithHMACSecret option. - // - // NOTE: In multi-replica deployments, all replicas must use the same HMAC secret, - // injected via the VMCP_SESSION_HMAC_SECRET environment variable. If replicas use - // different secrets, cross-pod token validation will silently reject legitimate - // callers. The default insecure secret must NOT be used in production. - defaultHMACSecret = []byte("insecure-default-for-testing-only-change-in-production") -) - // MultiSessionFactory creates new MultiSessions for connecting clients. type MultiSessionFactory interface { // MakeSessionWithID creates a new MultiSession with a specific session ID. @@ -67,9 +51,8 @@ type MultiSessionFactory interface { // The id parameter must be non-empty and should be a valid MCP session ID // (visible ASCII characters, 0x21 to 0x7E per the MCP specification). // - // The allowAnonymous parameter controls whether the session allows nil caller - // identity. If false, all session method calls must provide a valid caller - // that matches the session creator's identity. 
+ // Whether the session allows anonymous (nil) caller identity is derived + // internally from identity via ShouldAllowAnonymous. // // All other behaviour (partial initialisation, bounded concurrency, etc.) // is identical to MakeSession. @@ -77,18 +60,19 @@ type MultiSessionFactory interface { ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) // RestoreSession reconstructs a live MultiSession from persisted metadata. // It reconnects to the backends whose IDs are listed in storedMetadata under // MetadataKeyBackendIDs, rebuilds the routing table, and reapplies the - // hijack-prevention decorator using the stored token hash and salt. + // session-binding decorator from the stored identity binding. // // Use this when the node-local session cache misses — for example after a // pod restart or when a request is routed to a different pod. It is more // expensive than a cache hit because it opens new backend connections. + // Because MCP clients cannot be serialised, sticky sessions (session affinity + // at the load balancer) minimise how often this path is taken. // // allBackends is the current backend list from the registry; RestoreSession // filters it to the subset originally included in this session. @@ -131,7 +115,6 @@ type defaultMultiSessionFactory struct { connector backendConnector maxConcurrency int backendInitTimeout time.Duration - hmacSecret []byte // Server-managed secret for HMAC-SHA256 token hashing aggregator aggregator.Aggregator // Optional: applies tool transforms (overrides, conflict resolution, filter) } @@ -158,27 +141,6 @@ func WithBackendInitTimeout(d time.Duration) MultiSessionFactoryOption { } } -// WithHMACSecret sets the server-managed secret used for HMAC-SHA256 token hashing. -// The secret should be 32+ bytes and loaded from secure configuration (e.g., environment -// variable, secret management system). 
-// -// The secret is defensively copied to prevent external modification after assignment. -// Empty or nil secrets are rejected (function is a no-op) to prevent accidental security downgrades. -// -// If not set, a default insecure secret is used (NOT RECOMMENDED for production). -func WithHMACSecret(secret []byte) MultiSessionFactoryOption { - return func(f *defaultMultiSessionFactory) { - // Reject empty/nil secrets to prevent silent security downgrade - if len(secret) == 0 { - slog.Warn("WithHMACSecret: empty or nil secret rejected, falling back to default insecure secret", - "recommendation", "provide a secure secret via VMCP_SESSION_HMAC_SECRET environment variable") - return - } - // Make a defensive copy to prevent external modification - f.hmacSecret = append([]byte(nil), secret...) - } -} - // WithAggregator configures the factory to apply per-backend tool overrides, // conflict resolution, and advertising filters when building sessions. // If not set, raw backend tool names are used unchanged. @@ -202,7 +164,6 @@ func newSessionFactoryWithConnector(connector backendConnector, opts ...MultiSes connector: connector, maxConcurrency: defaultMaxBackendInitConcurrency, backendInitTimeout: defaultBackendInitTimeout, - hmacSecret: defaultHMACSecret, // Initialize with default (insecure) secret } for _, opt := range opts { opt(f) @@ -349,29 +310,11 @@ func (f *defaultMultiSessionFactory) MakeSessionWithID( ctx context.Context, id string, identity *auth.Identity, - allowAnonymous bool, backends []*vmcp.Backend, ) (MultiSession, error) { if err := validateSessionID(id); err != nil { return nil, err } - - // Validate allowAnonymous is consistent with identity to prevent security footguns. - // If identity has a token, allowAnonymous must be false (caller wants a bound session). - // If identity is nil or has no token, allowAnonymous should be true (anonymous session). 
- if identity != nil && identity.Token != "" && allowAnonymous { - return nil, fmt.Errorf( - "invalid session configuration: cannot create anonymous session " + - "(allowAnonymous=true) with bearer token (identity.Token is non-empty)", - ) - } - if (identity == nil || identity.Token == "") && !allowAnonymous { - return nil, fmt.Errorf( - "invalid session configuration: cannot create bound session " + - "(allowAnonymous=false) without bearer token (identity is nil or has empty token)", - ) - } - return f.makeSession(ctx, id, identity, backends) } @@ -406,17 +349,15 @@ func populateBackendMetadata(transportSess transportsession.Session, results []i transportSess.SetMetadata(MetadataKeyBackendSessionPrefix+r.target.WorkloadID, sessID) } } - // Always write MetadataKeyBackendIDs, even for zero-backend sessions (""). - // This distinguishes an explicit zero-backend state from absent/corrupted metadata - // in RestoreSession, preventing filterBackendsByStoredIDs from silently - // falling back to all backends when the key is missing. + // Always write MetadataKeyBackendIDs — key presence distinguishes explicit + // zero-backend from absent/corrupted metadata (see const doc). transportSess.SetMetadata(MetadataKeyBackendIDs, strings.Join(ids, ",")) } // makeBaseSession initialises backends and assembles a defaultMultiSession -// WITHOUT applying the hijack-prevention security wrapper. +// WITHOUT applying the session-binding security wrapper. // Callers are responsible for wrapping the result with the appropriate decorator -// (PreventSessionHijacking for new sessions, RestoreHijackPrevention for restored ones). +// (BindSession for new sessions, RestoreSessionBinding for restored ones). 
func (f *defaultMultiSessionFactory) makeBaseSession( ctx context.Context, sessID string, @@ -488,9 +429,6 @@ func (f *defaultMultiSessionFactory) makeBaseSession( } transportSess := transportsession.NewStreamableSession(sessID) - if identity != nil && identity.Subject != "" { - transportSess.SetMetadata(MetadataKeyIdentitySubject, identity.Subject) - } populateBackendMetadata(transportSess, results) return &defaultMultiSession{ @@ -507,7 +445,7 @@ func (f *defaultMultiSessionFactory) makeBaseSession( } // makeSession is the shared implementation for MakeSession and MakeSessionWithID. -// It builds the base session via makeBaseSession, then applies the hijack-prevention +// It builds the base session via makeBaseSession, then applies the session-binding // security wrapper using the caller's identity. func (f *defaultMultiSessionFactory) makeSession( ctx context.Context, @@ -520,9 +458,10 @@ func (f *defaultMultiSessionFactory) makeSession( return nil, err } - // Apply hijack prevention: computes token binding, stores metadata, and wraps - // the session with validation logic. - decorated, err := security.PreventSessionHijacking(baseSession, f.hmacSecret, identity) + // Apply session binding: extracts the (iss, sub) identity tuple, stores it in + // session metadata under MetadataKeyIdentityBinding, and wraps the session with + // validation logic that checks every subsequent caller against that binding. + decorated, err := security.BindSession(baseSession, identity) if err != nil { _ = baseSession.Close() return nil, err @@ -532,8 +471,10 @@ func (f *defaultMultiSessionFactory) makeSession( // RestoreSession implements MultiSessionFactory. // It reconnects to the backends whose IDs are listed in storedMetadata, rebuilds -// the routing table, and reapplies the hijack-prevention decorator from the stored -// token hash and salt — without recomputing them from a (unavailable) token. 
+// the routing table, and reapplies the session-binding decorator from the stored +// identity binding. Because the original bearer token is not available at restore +// time, identity is reconstructed from the (iss, sub) tuple in +// MetadataKeyIdentityBinding. func (f *defaultMultiSessionFactory) RestoreSession( ctx context.Context, id string, @@ -557,13 +498,44 @@ func (f *defaultMultiSessionFactory) RestoreSession( // Filter allBackends to the subset originally connected in this session. filteredBackends := filterBackendsByStoredIDs(allBackends, storedBackendIDs) - // Reconstruct a minimal identity from stored metadata. The original bearer - // token is never persisted (only its HMAC-SHA256 hash is), so Token is empty. - // The security decorator is restored from the stored hash/salt below. + // Reconstruct identity from the stored identity binding so makeBaseSession + // can pass it to backend connectors (e.g. for outgoing auth injection). + // The original bearer token is not available at restore time — only the + // (iss, sub) tuple stored in MetadataKeyIdentityBinding is used. + storedBinding, hasBinding := storedMetadata[sessiontypes.MetadataKeyIdentityBinding] + if !hasBinding { + // A legacy token-hash key marks a pre-#5306 session rather than corrupted metadata — safe to invalidate. + if _, hasLegacy := storedMetadata[sessiontypes.MetadataKeyTokenHash]; hasLegacy { + slog.Warn("RestoreSession: legacy session missing identity binding; invalidating", + "reason", "legacy_session_missing_identity_binding", + ) + return nil, transportsession.ErrSessionNotFound + } + return nil, fmt.Errorf("RestoreSession: %q metadata key absent (corrupted session metadata)", + sessiontypes.MetadataKeyIdentityBinding) + } + + // Restore identity from the stored (iss, sub) binding. Token is intentionally + // empty — the original bearer token is not persisted. 
Outgoing-auth strategies + // must derive credentials from Claims or a token store keyed by tsid; a + // strategy that reads identity.Token will reject the request with an + // identity-has-no-token error. The anonymous sentinel (identity == nil) is + // handled by the IsUnauthenticated guard below. var identity *auth.Identity - if subject := storedMetadata[MetadataKeyIdentitySubject]; subject != "" { - identity = &auth.Identity{} - identity.Subject = subject + if !binding.IsUnauthenticated(storedBinding) { + iss, sub, ok := binding.Parse(storedBinding) + if !ok { + return nil, fmt.Errorf("RestoreSession: stored identity binding is malformed: %q", storedBinding) + } + identity = &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Subject: sub, + Claims: map[string]any{ + "iss": iss, + "sub": sub, + }, + }, + } } // Extract stored per-backend session IDs as hints so each backend can @@ -576,45 +548,23 @@ func (f *defaultMultiSessionFactory) RestoreSession( } // Build the base session (backend connections + routing table) without the - // security wrapper. The wrapper is applied separately using stored hash/salt. + // security wrapper. The wrapper is applied separately below. baseSession, err := f.makeBaseSession(ctx, id, identity, filteredBackends, sessionHints) if err != nil { return nil, fmt.Errorf("RestoreSession: failed to rebuild backend connections: %w", err) } - // Restore only the security keys (token hash and salt) from stored metadata. - // MetadataKeyIdentitySubject is already set by makeBaseSession via the - // reconstructed identity. MetadataKeyBackendIDs and the per-backend session - // keys (MetadataKeyBackendSessionPrefix.*) are freshly computed by - // makeBaseSession from the actual reconnected backends; overwriting them with - // stored values would make metadata inconsistent if any backend failed to - // reconnect during restore. 
- for _, key := range []string{ - sessiontypes.MetadataKeyTokenHash, - sessiontypes.MetadataKeyTokenSalt, - } { - if v, ok := storedMetadata[key]; ok { - baseSession.SetMetadata(key, v) - } - } + // Restore only the identity-binding key from stored metadata. The other + // keys (MetadataKeyBackendIDs, MetadataKeyBackendSessionPrefix.*) are + // freshly computed by makeBaseSession from the actual reconnected backends; + // overwriting them with stored values would make metadata inconsistent if + // any backend failed to reconnect during restore. + baseSession.SetMetadata(sessiontypes.MetadataKeyIdentityBinding, storedBinding) - // Recreate the hijack-prevention decorator using the stored hash and salt, - // not by recomputing from identity.Token (which is unavailable at restore time). - // - // Fail closed if the token-hash key is entirely absent from stored metadata: - // PreventSessionHijacking always writes the key (empty string for anonymous, - // non-empty for authenticated), so an absent key indicates corrupted or - // truncated metadata — not a legitimately anonymous session. 
- storedHash, hashKeyPresent := storedMetadata[sessiontypes.MetadataKeyTokenHash] - if !hashKeyPresent { - _ = baseSession.Close() - return nil, fmt.Errorf("RestoreSession: token hash metadata key absent (corrupted session metadata)") - } - storedSalt := storedMetadata[sessiontypes.MetadataKeyTokenSalt] - restored, err := security.RestoreHijackPrevention(baseSession, storedHash, storedSalt, f.hmacSecret) + restored, err := security.RestoreSessionBinding(baseSession, storedBinding) if err != nil { _ = baseSession.Close() - return nil, fmt.Errorf("RestoreSession: failed to restore hijack prevention: %w", err) + return nil, fmt.Errorf("RestoreSession: failed to restore session binding: %w", err) } return restored, nil } diff --git a/pkg/vmcp/session/factory_metadata_test.go b/pkg/vmcp/session/factory_metadata_test.go index 81f1ed50fc..4d26808cda 100644 --- a/pkg/vmcp/session/factory_metadata_test.go +++ b/pkg/vmcp/session/factory_metadata_test.go @@ -41,7 +41,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { {ID: "backend-a"}, {ID: "backend-b"}, } - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, backends) require.NoError(t, err) meta := sess.GetMetadata() @@ -56,7 +56,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { t.Parallel() factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, nil) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, nil) require.NoError(t, err) meta := sess.GetMetadata() @@ -85,7 +85,7 @@ func TestMakeSession_PersistsBackendSessionIDs(t *testing.T) { {ID: "backend-ok"}, {ID: "backend-fail"}, } - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), 
nil, backends) require.NoError(t, err) meta := sess.GetMetadata() @@ -123,7 +123,7 @@ func TestRestoreSession_FreshlyPopulatesMetadataKeyBackendIDs(t *testing.T) { sessionID := "restore-test-session" // Create the initial session so we have a real token hash in metadata. - original, err := factory.MakeSessionWithID(t.Context(), sessionID, nil, true, backends) + original, err := factory.MakeSessionWithID(t.Context(), sessionID, nil, backends) require.NoError(t, err) t.Cleanup(func() { _ = original.Close() }) @@ -195,7 +195,7 @@ func TestRestoreSession_PassesStoredSessionHintToConnector(t *testing.T) { } // Create the original session — connector receives empty hints. - original, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, backends) + original, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, backends) require.NoError(t, err) t.Cleanup(func() { _ = original.Close() }) @@ -241,7 +241,7 @@ func TestMakeSession_PassesEmptySessionHintToConnector(t *testing.T) { } factory := newSessionFactoryWithConnector(connector) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, []*vmcp.Backend{{ID: "backend-a"}}) + sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, []*vmcp.Backend{{ID: "backend-a"}}) require.NoError(t, err) t.Cleanup(func() { _ = sess.Close() }) diff --git a/pkg/vmcp/session/identity_binding_test.go b/pkg/vmcp/session/identity_binding_test.go new file mode 100644 index 0000000000..0363e0bf59 --- /dev/null +++ b/pkg/vmcp/session/identity_binding_test.go @@ -0,0 +1,270 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "errors" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" + internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +// nilBackendConnector is a connector that returns (nil, nil, nil), causing the +// backend to be skipped during init. This lets us exercise session-metadata +// logic without real backend connections. +func nilBackendConnector() backendConnector { + return func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { + return nil, nil, nil + } +} + +// identityWithClaims builds an *auth.Identity whose Claims map is set verbatim +// from claims. Used in tests that need specific claim values without setting +// the Subject field (binding extraction reads only Claims["iss"] and Claims["sub"]). 
+func identityWithClaims(token string, claims map[string]any) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: claims}, + Token: token, + } +} + +func TestMakeSession_StoresIdentityBinding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identity *auth.Identity + wantBinding string + }{ + { + name: "authenticated_oidc", + identity: identityWithClaims("bearer-token", map[string]any{"iss": "https://idp.example", "sub": "user-42"}), + wantBinding: "https://idp.example\x00user-42", + }, + { + name: "nil_identity_anonymous", + identity: nil, + wantBinding: binding.UnauthenticatedSentinel, + }, + { + // LocalUserMiddleware sets Token="" and populates Claims with + // iss="toolhive-local" and sub=<username>. + name: "local_user_shape", + identity: identityWithClaims("", map[string]any{"iss": "toolhive-local", "sub": "alice"}), + wantBinding: "toolhive-local\x00alice", + }, + { + // AnonymousMiddleware (dev-only) sets Token="" with iss="toolhive-local" + // and sub="anonymous". All such sessions share one binding — intentional. 
+ name: "anonymous_middleware_shape", + identity: identityWithClaims("", map[string]any{"iss": "toolhive-local", "sub": "anonymous"}), + wantBinding: "toolhive-local\x00anonymous", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + sess, err := factory.MakeSessionWithID( + t.Context(), uuid.New().String(), tt.identity, nil, + ) + require.NoError(t, err) + require.NotNil(t, sess) + t.Cleanup(func() { _ = sess.Close() }) + + meta := sess.GetMetadata() + assert.Equal(t, tt.wantBinding, meta[MetadataKeyIdentityBinding], + "MetadataKeyIdentityBinding must equal expected binding") + }) + } +} + +// TestMakeSession_RejectsBoundSessionWithoutIdentifyingClaims verifies the +// ordering invariant from the factory through to BindSession: creating a session +// with an identity that carries no valid (iss, sub) pair returns an error from +// MakeSessionWithID. +func TestMakeSession_RejectsBoundSessionWithoutIdentifyingClaims(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + + // Token is present but Claims are empty, so BindSession's extractBindingID fails. + identity := identityWithClaims("x", map[string]any{}) + + _, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, nil) + require.Error(t, err, "session creation must fail when bound identity lacks identifying claims") +} + +func TestRestoreSession_ErrorCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storedMetadata map[string]string + wantNotFoundErr bool // true → must be ErrSessionNotFound; false → must NOT be + wantErrContains string + }{ + { + // A session carrying the legacy MetadataKeyTokenHash but no + // MetadataKeyIdentityBinding is invalidated with ErrSessionNotFound + // so the MCP client can re-initialize. 
+ name: "legacy_token_hash_only", + storedMetadata: map[string]string{ + MetadataKeyBackendIDs: "", + sessiontypes.MetadataKeyTokenHash: "deadbeefdeadbeef", + }, + wantNotFoundErr: true, + }, + { + // Genuinely corrupted metadata (no binding, no legacy key) must + // NOT masquerade as a session-not-found; it is a distinct error. + name: "absent_identity_binding_key", + storedMetadata: map[string]string{ + MetadataKeyBackendIDs: "", + }, + wantNotFoundErr: false, + wantErrContains: "absent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + _, err := factory.RestoreSession(t.Context(), uuid.New().String(), tt.storedMetadata, nil) + require.Error(t, err) + + if tt.wantNotFoundErr { + require.True(t, errors.Is(err, transportsession.ErrSessionNotFound), + "legacy token-hash-only session must return ErrSessionNotFound") + } else { + assert.False(t, errors.Is(err, transportsession.ErrSessionNotFound), + "corrupted metadata must not return ErrSessionNotFound") + } + if tt.wantErrContains != "" { + assert.Contains(t, err.Error(), tt.wantErrContains) + } + }) + } +} + +// TestRestoreSession_PopulatesBothSubjectFieldAndClaims verifies that after +// RestoreSession the reconstructed identity's binding is stored and the +// decorator accepts a matching caller. 
+func TestRestoreSession_PopulatesBothSubjectFieldAndClaims(t *testing.T) { + t.Parallel() + + factory := newSessionFactoryWithConnector(nilBackendConnector()) + + const storedBinding = "https://idp.example\x00alice" + storedMetadata := map[string]string{ + MetadataKeyBackendIDs: "", + sessiontypes.MetadataKeyIdentityBinding: storedBinding, + } + + sess, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) + require.NoError(t, err) + require.NotNil(t, sess) + t.Cleanup(func() { _ = sess.Close() }) + + meta := sess.GetMetadata() + assert.Equal(t, storedBinding, meta[sessiontypes.MetadataKeyIdentityBinding]) + + // Call a tool with the expected identity to verify the decorator accepts it. + caller := identityWithClaims("any-token", map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }) + _, err = sess.CallTool(t.Context(), caller, "nonexistent", nil, nil) + // ErrToolNotFound is acceptable — it means auth passed. + if err != nil { + assert.ErrorIs(t, err, ErrToolNotFound, + "restored session must accept matching caller; any error must be ErrToolNotFound (not auth)") + } +} + +// TestRestoreSession_ReconstructsIdentityWithEmptyTokenButPopulatedClaims pins +// the contract that RestoreSession passes a reconstructed identity to the +// backendConnector whose Token field is intentionally empty (the bearer token is +// not persisted), but whose Claims["iss"] and Claims["sub"] are populated from +// the stored identity binding. 
+func TestRestoreSession_ReconstructsIdentityWithEmptyTokenButPopulatedClaims(t *testing.T) { + t.Parallel() + + const ( + origIss = "https://idp.example" + origSub = "carol" + ) + + var capturedIdentity *auth.Identity + capturingConnector := func( + _ context.Context, + _ *vmcp.BackendTarget, + id *auth.Identity, + _ string, + ) (internalbk.Session, *vmcp.CapabilityList, error) { + capturedIdentity = id + return nil, nil, nil + } + + // Step 1: create the original session with an authenticated identity. + originalIdentity := identityWithClaims("bearer-AT1", map[string]any{ + "iss": origIss, + "sub": origSub, + }) + + factory := newSessionFactoryWithConnector(capturingConnector) + multiSess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), originalIdentity, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = multiSess.Close() }) + + // Step 2: capture the persisted metadata (simulates what Redis would hold). + meta := multiSess.GetMetadata() + require.NotEmpty(t, meta[sessiontypes.MetadataKeyIdentityBinding], + "factory must write MetadataKeyIdentityBinding to metadata") + + capturedIdentity = nil + + // Step 3: restore the session on "Pod B" with a backend present so the + // connector is actually invoked. + backend := &vmcp.Backend{ + ID: "test-backend", + Name: "test-backend", + } + storedMeta := make(map[string]string, len(meta)+1) + for k, v := range meta { + storedMeta[k] = v + } + storedMeta[MetadataKeyBackendIDs] = backend.ID + + restored, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMeta, []*vmcp.Backend{backend}) + require.NoError(t, err) + t.Cleanup(func() { _ = restored.Close() }) + + // Step 4: the connector must receive an identity with an empty Token but + // populated Claims. Any outgoing-auth strategy that reads identity.Token + // will silently produce unauthenticated backend requests after a pod restart. 
+ require.NotNil(t, capturedIdentity, "connector must be called with a non-nil identity for an authenticated session") + assert.Empty(t, capturedIdentity.Token, + "restored identity.Token must be empty — bearer token is not persisted across pod restarts") + assert.Equal(t, origIss, capturedIdentity.Claims["iss"], + "restored identity.Claims[iss] must match original issuer") + assert.Equal(t, origSub, capturedIdentity.Claims["sub"], + "restored identity.Claims[sub] must match original subject") + assert.Equal(t, origSub, capturedIdentity.Subject, + "restored identity.Subject must match original subject") +} diff --git a/pkg/vmcp/session/internal/security/hijack_prevention_test.go b/pkg/vmcp/session/internal/security/hijack_prevention_test.go index d0746ee030..4f5ad9b67c 100644 --- a/pkg/vmcp/session/internal/security/hijack_prevention_test.go +++ b/pkg/vmcp/session/internal/security/hijack_prevention_test.go @@ -5,22 +5,19 @@ package security import ( "context" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -var ( - // Test HMAC secret and salt for consistent test results - testSecret = []byte("test-secret") - testTokenSalt = []byte("test-salt-123456") // 16 bytes -) - // mockSession is a minimal implementation of MultiSession for testing. // It embeds the interface so only the methods exercised by tests need to be defined. type mockSession struct { @@ -56,204 +53,241 @@ func (*mockSession) GetPrompt(_ context.Context, _ *auth.Identity, _ string, _ m func (*mockSession) Close() error { return nil } -// TestValidateCaller_EdgeCases tests edge cases in caller validation logic. 
-func TestValidateCaller_EdgeCases(t *testing.T) { +// newDecoratedSession creates a mockSession wrapped with BindSession using the given identity. +func newDecoratedSession(t *testing.T, identity *auth.Identity) sessiontypes.MultiSession { + t.Helper() + base := newMockSession("test-session") + decorated, err := BindSession(base, identity) + require.NoError(t, err) + require.NotNil(t, decorated) + return decorated +} + +// authedIdentity builds an authenticated identity with the given issuer and subject. +// token is used as the raw bearer token (may be any non-empty string). +func authedIdentity(token, iss, sub string) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{ + "iss": iss, + "sub": sub, + }, + }, + Token: token, + } +} + +// identityWithClaims builds an *auth.Identity whose Claims map is set verbatim +// from claims. Used in tests that need malformed claim values (missing keys, +// non-string values, NUL bytes) that authedIdentity cannot express. +func identityWithClaims(token string, claims map[string]any) *auth.Identity { + return &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: claims}, + Token: token, + } +} + +// TestBindSession_AcceptsRefreshedTokenWithSameIdentity is the regression test +// for issue #5306. A caller presenting a new bearer token (refreshed access +// token) but the same (iss, sub) identity must be accepted because validation +// is now identity-based, not token-hash-based. +func TestBindSession_AcceptsRefreshedTokenWithSameIdentity(t *testing.T) { + t.Parallel() + + const iss = "https://idp.example" + const sub = "user-42" + + // Session created with AT1. + creator := authedIdentity("AT1", iss, sub) + decorated := newDecoratedSession(t, creator) + + // Subsequent call arrives with AT2 (refreshed token) but same (iss, sub). 
+ refreshed := authedIdentity("AT2", iss, sub) + _, err := decorated.CallTool(context.Background(), refreshed, "tool", nil, nil) + require.NoError(t, err, "refreshed token with same identity must be accepted") +} + +func TestBindSession_RejectsInvalidIdentityAtCreation(t *testing.T) { t.Parallel() tests := []struct { - name string - allowAnonymous bool - boundTokenHash string - caller *auth.Identity - wantErr error + name string + identity *auth.Identity + }{ + {name: "missing_sub_claim", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example"})}, + {name: "missing_iss_claim", identity: identityWithClaims("tok", map[string]any{"sub": "alice"})}, + {name: "both_claims_empty", identity: identityWithClaims("tok", map[string]any{})}, + {name: "non_string_iss", identity: identityWithClaims("tok", map[string]any{"iss": 42, "sub": "alice"})}, + {name: "non_string_sub", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example", "sub": true})}, + {name: "nul_byte_in_iss", identity: identityWithClaims("tok", map[string]any{"iss": "bad\x00iss", "sub": "alice"})}, + {name: "nul_byte_in_sub", identity: identityWithClaims("tok", map[string]any{"iss": "https://idp.example", "sub": "bad\x00sub"})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + base := &metadataObservingSession{mockSession: newMockSession("test-session")} + decorated, err := BindSession(base, tt.identity) + + require.Error(t, err) + assert.Nil(t, decorated) + assert.False(t, base.setMetadataCalled, "SetMetadata must not be called if binding extraction fails") + }) + } +} + +func TestBindSession_RejectsMismatchedCaller(t *testing.T) { + t.Parallel() + + const boundIss = "https://idp.example" + const boundSub = "alice" + + tests := []struct { + name string + caller *auth.Identity + }{ + {name: "different_sub", caller: authedIdentity("tok2", boundIss, "bob")}, + {name: "different_iss", caller: authedIdentity("tok2", 
"https://idp-b.example", boundSub)}, + {name: "missing_iss_claim", caller: identityWithClaims("tok2", map[string]any{"sub": boundSub})}, + {name: "missing_sub_claim", caller: identityWithClaims("tok2", map[string]any{"iss": boundIss})}, + {name: "both_claims_empty", caller: identityWithClaims("tok2", map[string]any{})}, + {name: "non_string_iss_claim", caller: identityWithClaims("tok2", map[string]any{"iss": []string{boundIss}, "sub": boundSub})}, + {name: "non_string_sub_claim", caller: identityWithClaims("tok2", map[string]any{"iss": boundIss, "sub": 12345})}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + decorated := newDecoratedSession(t, authedIdentity("tok", boundIss, boundSub)) + _, err := decorated.CallTool(context.Background(), tt.caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }) + } +} + +func TestBindSession_AnonymousSession(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + assertFn func(t *testing.T, decorated sessiontypes.MultiSession) }{ { - name: "anonymous session with nil caller", - allowAnonymous: true, - boundTokenHash: "", - caller: nil, - wantErr: nil, // Should succeed - }, - { - name: "anonymous session rejects caller with token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, // Prevent session upgrade attack - }, - { - name: "bound session with nil caller", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: nil, - wantErr: sessiontypes.ErrNilCaller, - }, - { - name: "bound session with matching token", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "correct-token"}, - wantErr: nil, // Should succeed - }, - 
{ - name: "bound session with wrong token", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "wrong-token"}, - wantErr: sessiontypes.ErrUnauthorizedCaller, - }, - { - name: "bound session with empty token in identity", - allowAnonymous: false, - boundTokenHash: hashToken("correct-token", testSecret, testTokenSalt), - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: sessiontypes.ErrUnauthorizedCaller, + name: "nil_identity_stores_sentinel", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + meta := decorated.GetMetadata() + assert.Equal(t, binding.UnauthenticatedSentinel, meta[sessiontypes.MetadataKeyIdentityBinding], + "anonymous session must store UnauthenticatedSentinel in metadata") + }, }, { - name: "anonymous session accepts caller with empty token", - allowAnonymous: true, - boundTokenHash: "", - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: nil, // Empty token is equivalent to no token + name: "rejects_caller_presenting_token", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + caller := &auth.Identity{Token: "some-token"} + _, err := decorated.CallTool(context.Background(), caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }, }, { - name: "misconfigured bound session with empty hash rejects empty token", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""}, - wantErr: sessiontypes.ErrSessionOwnerUnknown, // Fail closed + name: "accepts_nil_caller", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + _, err := decorated.CallTool(context.Background(), nil, "tool", nil, 
nil) + require.NoError(t, err) + }, }, { - name: "misconfigured bound session with empty hash rejects nil caller", - allowAnonymous: false, - boundTokenHash: "", // Misconfiguration: bound but no hash - caller: nil, - wantErr: sessiontypes.ErrNilCaller, // Nil check happens first + name: "accepts_caller_with_empty_token", + assertFn: func(t *testing.T, decorated sessiontypes.MultiSession) { + t.Helper() + caller := &auth.Identity{Token: ""} + _, err := decorated.CallTool(context.Background(), caller, "tool", nil, nil) + require.NoError(t, err) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - - // Create a base session - baseSession := newMockSession("test-session") - - // Wrap with decorator that has the test configuration - decorator := &hijackPreventionDecorator{ - MultiSession: baseSession, - allowAnonymous: tt.allowAnonymous, - boundTokenHash: tt.boundTokenHash, - tokenSalt: testTokenSalt, - hmacSecret: testSecret, - } - - ctx := context.Background() - - // Test all three decorated methods to verify validation is integrated correctly - toolResult, errCallTool := decorator.CallTool(ctx, tt.caller, "test-tool", nil, nil) - resourceResult, errReadResource := decorator.ReadResource(ctx, tt.caller, "test://uri") - promptResult, errGetPrompt := decorator.GetPrompt(ctx, tt.caller, "test-prompt", nil) - - if tt.wantErr != nil { - require.ErrorIs(t, errCallTool, tt.wantErr) - require.ErrorIs(t, errReadResource, tt.wantErr) - require.ErrorIs(t, errGetPrompt, tt.wantErr) - assert.Nil(t, toolResult) - assert.Nil(t, resourceResult) - assert.Nil(t, promptResult) - } else { - require.NoError(t, errCallTool) - require.NoError(t, errReadResource) - require.NoError(t, errGetPrompt) - assert.NotNil(t, toolResult) - assert.NotNil(t, resourceResult) - assert.NotNil(t, promptResult) - } + base := newMockSession("test-session") + decorated, err := BindSession(base, nil) + require.NoError(t, err) + require.NotNil(t, decorated) + tt.assertFn(t, 
decorated) }) } } -// TestPreventSessionHijacking_NilSession tests that a nil session is rejected before any method call. -func TestPreventSessionHijacking_NilSession(t *testing.T) { +func TestBindSession_NilSession(t *testing.T) { t.Parallel() - decorated, err := PreventSessionHijacking(nil, testSecret, &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"}) + identity := authedIdentity("tok", "https://idp.example", "alice") + decorated, err := BindSession(nil, identity) require.Error(t, err) assert.Nil(t, decorated) } -// TestPreventSessionHijacking_BasicFunctionality tests the main entry point. -func TestPreventSessionHijacking_BasicFunctionality(t *testing.T) { +func TestBindSession_BoundRejectsNilCaller(t *testing.T) { t.Parallel() - t.Run("authenticated session", func(t *testing.T) { - t.Parallel() - - baseSession := newMockSession("test-session") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - - decorated, err := PreventSessionHijacking(baseSession, testSecret, identity) - require.NoError(t, err) - require.NotNil(t, decorated) - - // Verify metadata was set (no cast needed - returns concrete type) - metadata := decorated.GetMetadata() - assert.NotEmpty(t, metadata[metadataKeyTokenHash]) - assert.NotEmpty(t, metadata[metadataKeyTokenSalt]) - }) + decorated := newDecoratedSession(t, authedIdentity("tok", "https://idp.example", "alice")) - t.Run("anonymous session", func(t *testing.T) { - t.Parallel() - - baseSession := newMockSession("test-session") - - decorated, err := PreventSessionHijacking(baseSession, testSecret, nil) - require.NoError(t, err) - require.NotNil(t, decorated) - - // Verify metadata was set (empty for anonymous, no cast needed) - metadata := decorated.GetMetadata() - assert.Empty(t, metadata[metadataKeyTokenHash]) - assert.Empty(t, metadata[metadataKeyTokenSalt]) - }) + _, err := decorated.CallTool(context.Background(), nil, "tool", nil, nil) + 
require.ErrorIs(t, err, sessiontypes.ErrNilCaller) } -// TestRestoreHijackPrevention tests restoration of the hijack-prevention decorator. -func TestRestoreHijackPrevention(t *testing.T) { +// TestBindSession_ConcurrentRefreshRace verifies that two goroutines calling +// CallTool concurrently with different bearer tokens but the same (iss, sub) +// both succeed. This tests the core fix for issue #5306. +func TestBindSession_ConcurrentRefreshRace(t *testing.T) { t.Parallel() - t.Run("anonymous session (empty hash and salt)", func(t *testing.T) { - t.Parallel() + const iss = "https://idp.example" + const sub = "alice" - base := newMockSession("s1") - restored, err := RestoreHijackPrevention(base, "", "", testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - }) + decorated := newDecoratedSession(t, authedIdentity("AT1", iss, sub)) - t.Run("hash present but salt absent is rejected", func(t *testing.T) { - t.Parallel() + const goroutines = 20 - base := newMockSession("s2") - _, err := RestoreHijackPrevention(base, "somehash", "", testSecret) - require.Error(t, err) - assert.Contains(t, err.Error(), "salt is missing") - }) + errs := make([]error, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + for i := range goroutines { + go func(i int) { + defer wg.Done() + // Each goroutine uses a distinct token string but the same identity. 
+ caller := authedIdentity("refreshed-token-"+string(rune('A'+i%26)), iss, sub) + _, errs[i] = decorated.CallTool(context.Background(), caller, "tool", nil, nil) + }(i) + } - t.Run("salt present but hash absent is rejected", func(t *testing.T) { - t.Parallel() + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for concurrent CallTool goroutines") + } - base := newMockSession("s3") - _, err := RestoreHijackPrevention(base, "", "deadbeef", testSecret) - require.Error(t, err) - assert.Contains(t, err.Error(), "hash is missing") - }) + for i, err := range errs { + assert.NoError(t, err, "goroutine %d must succeed with same identity and different token", i) + } +} - t.Run("nil session is rejected", func(t *testing.T) { - t.Parallel() +// metadataObservingSession wraps mockSession and records whether SetMetadata +// was ever called. Used in tests that assert SetMetadata is NOT called before +// a binding is validated. 
+type metadataObservingSession struct { + *mockSession + setMetadataCalled bool +} - _, err := RestoreHijackPrevention(nil, "", "", testSecret) - require.Error(t, err) - }) +func (m *metadataObservingSession) SetMetadata(key, value string) { + m.setMetadataCalled = true + m.mockSession.SetMetadata(key, value) } diff --git a/pkg/vmcp/session/internal/security/restore_test.go b/pkg/vmcp/session/internal/security/restore_test.go index 03de0f5a62..c9f98854ea 100644 --- a/pkg/vmcp/session/internal/security/restore_test.go +++ b/pkg/vmcp/session/internal/security/restore_test.go @@ -5,126 +5,99 @@ package security import ( "context" - "encoding/hex" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -func TestRestoreHijackPrevention_NilSession(t *testing.T) { +func TestRestoreSessionBinding(t *testing.T) { t.Parallel() - restored, err := RestoreHijackPrevention(nil, "somehash", hex.EncodeToString(testTokenSalt), testSecret) - require.Error(t, err) - assert.Nil(t, restored) + tests := []struct { + name string + storedBinding string + wantErr bool + // checkFn is called with the restored session when wantErr is false. + // Leave nil to skip behavioral assertions. 
+ checkFn func(t *testing.T, restored sessiontypes.MultiSession) + }{ + { + name: "bound_round_trip_accepts_matching_caller", + storedBinding: "https://idp.example\x00alice", + checkFn: func(t *testing.T, restored sessiontypes.MultiSession) { + t.Helper() + matchingCaller := identityWithClaims("some-token", map[string]any{ + "iss": "https://idp.example", + "sub": "alice", + }) + _, err := restored.CallTool(context.Background(), matchingCaller, "tool", nil, nil) + require.NoError(t, err, "matching (iss, sub) caller must be accepted after restore") + + intruder := identityWithClaims("other-token", map[string]any{ + "iss": "https://idp.example", + "sub": "bob", + }) + _, err = restored.CallTool(context.Background(), intruder, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) + }, + }, + { + name: "unauthenticated_sentinel_accepts_nil_rejects_token", + storedBinding: binding.UnauthenticatedSentinel, + checkFn: func(t *testing.T, restored sessiontypes.MultiSession) { + t.Helper() + _, err := restored.CallTool(context.Background(), nil, "tool", nil, nil) + require.NoError(t, err, "nil caller must be accepted for anonymous sessions") + + caller := &auth.Identity{Token: "some-token"} + _, err = restored.CallTool(context.Background(), caller, "tool", nil, nil) + require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, + "caller presenting a token must be rejected (session-upgrade attack prevention)") + }, + }, + { + name: "corrupted_binding_no_nul", + storedBinding: "no-nul-here", + wantErr: true, + }, + { + name: "empty_string_rejected", + storedBinding: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + base := newMockSession("sess") + restored, err := RestoreSessionBinding(base, tt.storedBinding) + + if tt.wantErr { + require.Error(t, err) + assert.Nil(t, restored) + return + } + + require.NoError(t, err) + require.NotNil(t, restored) + if tt.checkFn != nil { + 
tt.checkFn(t, restored) + } + }) + } } -func TestRestoreHijackPrevention_MissingSalt(t *testing.T) { +func TestRestoreSessionBinding_NilSession(t *testing.T) { t.Parallel() - // Non-empty tokenHash with empty tokenSaltHex is malformed state. - base := newMockSession("sess") - restored, err := RestoreHijackPrevention(base, "nonemptyhash", "", testSecret) + restored, err := RestoreSessionBinding(nil, binding.UnauthenticatedSentinel) require.Error(t, err) assert.Nil(t, restored) } - -func TestRestoreHijackPrevention_InvalidSaltHex(t *testing.T) { - t.Parallel() - - base := newMockSession("sess") - restored, err := RestoreHijackPrevention(base, "nonemptyhash", "gg", testSecret) - require.Error(t, err) - assert.Nil(t, restored) -} - -func TestRestoreHijackPrevention_AnonymousSession(t *testing.T) { - t.Parallel() - - base := newMockSession("sess") - // tokenHash="" and tokenSaltHex="" → anonymous. - restored, err := RestoreHijackPrevention(base, "", "", testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - - ctx := context.Background() - - // Nil caller is accepted. - _, err = restored.CallTool(ctx, nil, "tool", nil, nil) - require.NoError(t, err) - - // Caller presenting a token is rejected (session upgrade attack prevention). - caller := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "u"}, Token: "t"} - _, err = restored.CallTool(ctx, caller, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) -} - -func TestRestoreHijackPrevention_AuthenticatedRoundTrip(t *testing.T) { - t.Parallel() - - // --- "Pod A": create a session, persist hash+salt from metadata. 
--- - base := newMockSession("sess") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "bearer-token"} - - created, err := PreventSessionHijacking(base, testSecret, identity) - require.NoError(t, err) - - meta := created.GetMetadata() - persistedHash := meta[metadataKeyTokenHash] - persistedSalt := meta[metadataKeyTokenSalt] - require.NotEmpty(t, persistedHash, "tokenHash must be persisted") - require.NotEmpty(t, persistedSalt, "tokenSalt must be persisted") - - // --- "Pod B": restore decorator from persisted values. --- - base2 := newMockSession("sess") - restored, err := RestoreHijackPrevention(base2, persistedHash, persistedSalt, testSecret) - require.NoError(t, err) - require.NotNil(t, restored) - - ctx := context.Background() - - // Original token is accepted. - _, err = restored.CallTool(ctx, identity, "tool", nil, nil) - require.NoError(t, err) - - // A different token is rejected. - other := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "wrong-token"} - _, err = restored.CallTool(ctx, other, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller) - - // Nil caller is rejected for a bound session. - _, err = restored.CallTool(ctx, nil, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrNilCaller) -} - -func TestRestoreHijackPrevention_CrossReplicaSecretMismatch(t *testing.T) { - t.Parallel() - - // Pod A creates with secretA. - secretA := []byte("secret-A") - secretB := []byte("secret-B") - - base := newMockSession("sess") - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "token"} - - created, err := PreventSessionHijacking(base, secretA, identity) - require.NoError(t, err) - - meta := created.GetMetadata() - persistedHash := meta[metadataKeyTokenHash] - persistedSalt := meta[metadataKeyTokenSalt] - - // Pod B restores with a different secretB — token validation must fail. 
- base2 := newMockSession("sess") - restored, err := RestoreHijackPrevention(base2, persistedHash, persistedSalt, secretB) - require.NoError(t, err) // Construction succeeds; mismatch only shows at validation time. - - ctx := context.Background() - _, err = restored.CallTool(ctx, identity, "tool", nil, nil) - require.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, - "cross-replica secret mismatch must reject the original token") -} diff --git a/pkg/vmcp/session/internal/security/security.go b/pkg/vmcp/session/internal/security/security.go index 358c4a0bba..ff17881488 100644 --- a/pkg/vmcp/session/internal/security/security.go +++ b/pkg/vmcp/session/internal/security/security.go @@ -1,117 +1,123 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -// Package security provides cryptographic utilities for session token binding -// and hijacking prevention. It handles HMAC-SHA256 token hashing, salt generation, -// and constant-time comparison to prevent timing attacks. +// Package security provides the session-hijack-prevention decorator for +// vMCP sessions. It binds a session to a stable identity tuple (iss, sub) +// extracted from the OIDC identity that created the session, and validates +// that every subsequent request comes from the same identity. +// +// Session bindings are stored as plaintext at rest in the session metadata +// (see pkg/vmcp/session/binding for the format and trust-boundary statement). +// The binding is NOT a credential; it identifies but does not authenticate. +// Callers must validate the request's token independently before passing the +// resulting *auth.Identity to validateCaller. 
package security import ( "context" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/hex" + "crypto/subtle" "fmt" "log/slog" "github.com/stacklok/toolhive/pkg/auth" - pkgsecurity "github.com/stacklok/toolhive/pkg/security" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) -const ( - // SHA256HexLen is the length of a hex-encoded SHA256 hash (32 bytes = 64 hex characters) - SHA256HexLen = 64 - - // metadataKeyTokenHash is the session metadata key for the token hash. - // Imported from types package to ensure consistency across all packages. - metadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash - - // metadataKeyTokenSalt is the session metadata key for the token salt. - // Imported from types package to ensure consistency across all packages. - metadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt -) - -// generateSalt generates a cryptographically secure random salt for token hashing. -// Returns 16 bytes of random data from crypto/rand. +// sessionBindingDecorator wraps a session and adds identity-binding validation +// to prevent session hijacking attacks. It validates that all requests come from +// the same identity that created the session. // -// Each session should have a unique salt to provide additional entropy and prevent -// attacks that work across multiple sessions. -func generateSalt() ([]byte, error) { - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return nil, fmt.Errorf("failed to generate salt: %w", err) - } - return salt, nil +// The decorator is applied by BindSession to ALL sessions (both authenticated +// and anonymous). For authenticated sessions, it validates the caller's (iss, sub) +// identity binding matches the creator's binding. For anonymous sessions +// (allowAnonymous=true), it allows nil callers and prevents session upgrade +// attacks by rejecting any token presentation. 
+// +// The decorator embeds MultiSession and only overrides the methods that require +// validation (CallTool, ReadResource, GetPrompt). All other methods are +// automatically delegated to the embedded session. +type sessionBindingDecorator struct { + sessiontypes.MultiSession // embedded — automatic delegation for unwrapped methods + + // boundIdentity is the canonical identity binding written at session + // creation. Immutable after construction. + // + // For sessions allowed to be anonymous, boundIdentity is + // binding.UnauthenticatedSentinel. + // For sessions bound to an authenticated identity, boundIdentity is the + // output of binding.Format(iss, sub). + boundIdentity string + + // allowAnonymous tracks whether the session was created without a bound + // identity. Used to reject session-upgrade attacks (caller presents a + // token on an anonymous session). + allowAnonymous bool } -// hashToken returns the hex-encoded HMAC-SHA256 hash of a raw bearer token string. -// Uses HMAC with a server-managed secret and per-session salt to prevent offline -// attacks if session storage is compromised. +// extractBindingID derives the canonical identity-binding string from the +// given auth identity's OIDC claims. It reads "iss" and "sub" from +// identity.Claims (not identity.Subject) so that JWT-validation and +// introspection paths canonicalize against the same source. // -// For empty tokens (anonymous sessions) it returns the empty string, which is -// the sentinel value used to identify sessions created without credentials. -// The raw token is never stored — only the hash. +// Returns ("", error) when: +// - identity is nil. +// - identity.Claims is missing "iss" or "sub". +// - either claim is present but not a string. +// - binding.Format rejects the (iss, sub) pair (empty halves or stray NULs). 
// -// Parameters: -// - token: The bearer token to hash -// - secret: Server-managed HMAC secret (should be 32+ bytes) -// - salt: Per-session random salt (typically 16 bytes) +// Callers MUST treat a non-nil error as "no identifying claims available" +// and fail closed (do not silently fall through to anonymous). // -// Security: Uses HMAC-SHA256 instead of plain SHA256 to prevent rainbow table -// attacks and offline brute force if session state leaks from Redis/Valkey. -func hashToken(token string, secret, salt []byte) string { - if token == "" { - return "" +// TODO(#5306-followup): if/when RFC 7662 introspection becomes a top-level +// incoming-auth type, add a startup probe that verifies the IdP emits iss +// and sub in introspection responses (extractBindingID requires both). +func extractBindingID(identity *auth.Identity) (string, error) { + if identity == nil { + return "", fmt.Errorf("auth identity is nil") } - h := hmac.New(sha256.New, secret) - h.Write(salt) - h.Write([]byte(token)) - return hex.EncodeToString(h.Sum(nil)) -} -// hijackPreventionDecorator wraps a session and adds token binding validation -// to prevent session hijacking attacks. It validates that all requests come from -// the same identity that created the session. -// -// The decorator is applied by PreventSessionHijacking to ALL sessions (both authenticated -// and anonymous). For authenticated sessions, it validates the caller's token matches -// the creator's token. For anonymous sessions (allowAnonymous=true), it allows nil -// callers and prevents session upgrade attacks by rejecting any token presentation. -// -// The decorator embeds MultiSession and only overrides the methods that require -// validation (CallTool, ReadResource, GetPrompt). All other methods are automatically -// delegated to the embedded session. 
-type hijackPreventionDecorator struct { - sessiontypes.MultiSession // Embedded interface - provides automatic delegation for most methods + issRaw, issPresent := identity.Claims["iss"] + if !issPresent { + return "", fmt.Errorf("auth identity is missing iss claim") + } + iss, issIsString := issRaw.(string) + if !issIsString { + return "", fmt.Errorf("auth identity has non-string iss claim") + } + + subRaw, subPresent := identity.Claims["sub"] + if !subPresent { + return "", fmt.Errorf("auth identity is missing sub claim") + } + sub, subIsString := subRaw.(string) + if !subIsString { + return "", fmt.Errorf("auth identity has non-string sub claim") + } - // Token binding fields: enforce that subsequent requests come from the same - // identity that created the session. - // These fields are immutable after decorator creation (no mutex needed). - boundTokenHash string // HMAC-SHA256 hash of creator's token (empty for anonymous) - tokenSalt []byte // Random salt used for HMAC (empty for anonymous) - hmacSecret []byte // Server-managed secret for HMAC-SHA256 - allowAnonymous bool // Whether to allow nil caller + b, err := binding.Format(iss, sub) + if err != nil { + return "", fmt.Errorf("auth identity (iss, sub) pair is invalid: %w", err) + } + return b, nil } -// validateCaller checks if the provided caller identity matches the session owner. -// Returns nil if validation succeeds, or an error if: -// - The session requires a bound identity but caller is nil (ErrNilCaller) -// - The caller's token hash doesn't match the session owner (ErrUnauthorizedCaller) -// - An anonymous session receives a caller with a non-empty token (ErrUnauthorizedCaller) +// validateCaller checks the caller against the session's bound identity. // -// For anonymous sessions (allowAnonymous=true, boundTokenHash=""), validation succeeds -// only when the caller is nil or has an empty token (prevents session upgrade attacks). 
-func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { - // No lock needed - token binding fields are immutable after decorator creation - - // Anonymous sessions: reject callers that present tokens - if d.allowAnonymous && d.boundTokenHash == "" { - // Prevent session upgrade attack: anonymous sessions cannot accept tokens +// Returns: +// - ErrSessionOwnerUnknown when the decorator was constructed with neither +// an anonymous marker nor a real binding (programming error). +// - ErrNilCaller when a bound session receives nil. +// - ErrUnauthorizedCaller when: +// - an anonymous session receives a caller presenting a token (upgrade attack), +// - the caller's identity binding does not match the session's bound identity. +func (d sessionBindingDecorator) validateCaller(caller *auth.Identity) error { + // Anonymous path: sessions that were created without a bound identity. + if d.allowAnonymous && binding.IsUnauthenticated(d.boundIdentity) { + // Prevent session upgrade attack: anonymous sessions cannot accept tokens. if caller != nil && caller.Token != "" { - slog.Warn("token validation failed: session upgrade attack prevented", + slog.Warn("identity binding validation failed: session upgrade attack prevented", "reason", "token_presented_to_anonymous_session", ) return sessiontypes.ErrUnauthorizedCaller @@ -119,32 +125,43 @@ func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { return nil } - // Bound sessions require a caller + // Bound sessions require a non-nil caller. if caller == nil { - slog.Warn("token validation failed: nil caller for bound session", + slog.Warn("identity binding validation failed: nil caller for bound session", "reason", "nil_caller", ) return sessiontypes.ErrNilCaller } - // Defensive check: bound sessions must have a non-empty token hash. - // This prevents misconfigured sessions from accepting empty tokens. 
- // Scenario: if boundTokenHash="" and caller.Token="", both would hash to "", - // and ConstantTimeHashCompare would return true (both empty case). - if d.boundTokenHash == "" { - slog.Error("token validation failed: bound session has empty token hash", - "reason", "misconfigured_session", - ) - return sessiontypes.ErrSessionOwnerUnknown + // Defensive check: the stored binding must be parseable. An unparseable value + // means the session was misconfigured at construction time — fail closed + // rather than accepting or rejecting based on garbage state. + if !binding.IsUnauthenticated(d.boundIdentity) { + if _, _, ok := binding.Parse(d.boundIdentity); !ok { + slog.Error("identity binding validation failed: stored binding is not parseable", + "reason", "misconfigured_session", + ) + return sessiontypes.ErrSessionOwnerUnknown + } } - // Compute caller's token hash using the same HMAC secret and salt - callerHash := hashToken(caller.Token, d.hmacSecret, d.tokenSalt) + // Compute the caller's binding from their identity claims. + callerBinding, err := extractBindingID(caller) + if err != nil { + slog.Warn("identity binding validation failed: could not extract caller binding", + "reason", "caller_binding_extraction_failed", + "error", err, + ) + return sessiontypes.ErrUnauthorizedCaller + } - // Constant-time comparison to prevent timing attacks - if !pkgsecurity.ConstantTimeHashCompare(d.boundTokenHash, callerHash, SHA256HexLen) { - slog.Warn("token validation failed: token hash mismatch", - "reason", "token_hash_mismatch", + // ConstantTimeCompare is constant-time over content but short-circuits on + // length mismatch. Leaking binding length is acceptable: iss is the OIDC + // issuer (public, in the discovery document) and sub is an opaque + // identifier whose length is typically per-issuer. Neither is secret. 
+ if subtle.ConstantTimeCompare([]byte(d.boundIdentity), []byte(callerBinding)) != 1 { + slog.Warn("identity binding validation failed: identity binding mismatch", + "reason", "identity_binding_mismatch", ) return sessiontypes.ErrUnauthorizedCaller } @@ -153,14 +170,13 @@ func (d hijackPreventionDecorator) validateCaller(caller *auth.Identity) error { } // CallTool validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) CallTool( +func (d sessionBindingDecorator) CallTool( ctx context.Context, caller *auth.Identity, toolName string, arguments map[string]any, meta map[string]any, ) (*vmcp.ToolCallResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -169,12 +185,11 @@ func (d hijackPreventionDecorator) CallTool( } // ReadResource validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) ReadResource( +func (d sessionBindingDecorator) ReadResource( ctx context.Context, caller *auth.Identity, uri string, ) (*vmcp.ResourceReadResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -183,13 +198,12 @@ func (d hijackPreventionDecorator) ReadResource( } // GetPrompt validates the caller identity before delegating to the embedded session. -func (d hijackPreventionDecorator) GetPrompt( +func (d sessionBindingDecorator) GetPrompt( ctx context.Context, caller *auth.Identity, name string, arguments map[string]any, ) (*vmcp.PromptGetResult, error) { - // Validate caller identity if err := d.validateCaller(caller); err != nil { return nil, err } @@ -197,148 +211,101 @@ func (d hijackPreventionDecorator) GetPrompt( return d.MultiSession.GetPrompt(ctx, caller, name, arguments) } -// RestoreHijackPrevention recreates the hijack-prevention decorator from persisted -// metadata, rather than recomputing token binding from an identity. 
Use this when -// reconstructing a MultiSession after a pod restart or cross-pod failover where the -// original bearer token is no longer available but the stored hash and salt are. +// BindSession wraps a session with identity-binding validation. It writes +// the canonical binding into session metadata under MetadataKeyIdentityBinding +// and returns a decorator that validates the caller on every operation. +// +// Whether the session is anonymous is derived from the identity via +// types.ShouldAllowAnonymous. +// +// For bound sessions, the (iss, sub) tuple is extracted from identity.Claims. +// If the tuple cannot be extracted (missing claims, non-string claims, or +// invalid format), BindSession returns an error BEFORE writing anything to +// the session metadata — preserving the invariant that a session always has +// a valid MetadataKeyIdentityBinding value once written. // -// If tokenHash is empty the session is treated as anonymous (allowAnonymous=true). -// The hmacSecret must be the same server-managed secret used at creation time. -func RestoreHijackPrevention( +// Returns an error if session is nil, or if a bound identity is required +// but no valid (iss, sub) binding can be produced. +func BindSession( session sessiontypes.MultiSession, - tokenHash string, - tokenSaltHex string, - hmacSecret []byte, + identity *auth.Identity, ) (sessiontypes.MultiSession, error) { if session == nil { return nil, fmt.Errorf("session must not be nil") } - // Both fields must be either both present or both absent. Any other - // combination indicates corrupted or incomplete metadata and must be - // rejected to fail closed: - // - hash present, salt absent: HMAC comparison will always fail, - // producing a silently broken (always-rejecting) decorator. - // - hash absent, salt present: session would be treated as anonymous, - // silently downgrading a bound session and bypassing token validation. 
- if tokenHash != "" && tokenSaltHex == "" { - return nil, fmt.Errorf("RestoreHijackPrevention: stored token hash is present but salt is missing " + - "(incomplete session metadata)") - } - if tokenHash == "" && tokenSaltHex != "" { - return nil, fmt.Errorf("RestoreHijackPrevention: stored token salt is present but hash is missing " + - "(incomplete session metadata)") - } - - allowAnonymous := tokenHash == "" + allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - var tokenSalt []byte - if tokenSaltHex != "" { - var decErr error - tokenSalt, decErr = hex.DecodeString(tokenSaltHex) - if decErr != nil { - return nil, fmt.Errorf("failed to decode stored token salt: %w", decErr) + // Determine the binding value to write. For bound sessions, extract the + // (iss, sub) binding before touching session metadata — if extraction fails, + // we must not leave a partial or sentinel value in the session. + bid, err := func() (string, error) { + if allowAnonymous { + return binding.UnauthenticatedSentinel, nil + } + b, err := extractBindingID(identity) + if err != nil { + return "", fmt.Errorf("BindSession: cannot derive identity binding: %w", err) } + return b, nil + }() + if err != nil { + return nil, err } - // Make defensive copies to prevent external mutation after construction. - var hmacSecretCopy, tokenSaltCopy []byte - if len(hmacSecret) > 0 { - hmacSecretCopy = append([]byte(nil), hmacSecret...) - } - if len(tokenSalt) > 0 { - tokenSaltCopy = append([]byte(nil), tokenSalt...) - } + // Write the resolved binding. This is the only metadata key written here; + // backend IDs and per-backend session keys are written by makeBaseSession. 
+ session.SetMetadata(sessiontypes.MetadataKeyIdentityBinding, bid) - return &hijackPreventionDecorator{ + return &sessionBindingDecorator{ MultiSession: session, + boundIdentity: bid, allowAnonymous: allowAnonymous, - hmacSecret: hmacSecretCopy, - boundTokenHash: tokenHash, - tokenSalt: tokenSaltCopy, }, nil } -// PreventSessionHijacking wraps a session with hijack prevention security measures. -// It computes token binding hashes, stores them in session metadata, and returns -// a decorated session that validates caller identity on every operation. -// -// Whether the session is anonymous is derived from the identity: nil identity or -// empty token means anonymous, a non-empty token means bound/authenticated. +// RestoreSessionBinding recreates the session-binding decorator from a +// persisted binding string read out of session metadata. Use this when +// reconstructing a MultiSession after a pod restart or cross-pod failover. // -// For authenticated sessions (identity.Token != ""): -// - Generates a unique random salt -// - Computes HMAC-SHA256 hash of the bearer token -// - Stores hash and salt in session metadata -// - Returns decorator that validates every request against the creator's token +// This function is the symmetric counterpart of BindSession for the restore +// path. It is invoked by the session factory after RestoreSession deserializes +// the binding from session metadata. Unlike BindSession, it does NOT write +// metadata — the factory has already restored the metadata layer separately. 
// -// For anonymous sessions (identity == nil or identity.Token == ""): -// - Stores an empty string sentinel for the token hash metadata key -// - Omits the salt metadata key entirely (no salt is generated for anonymous sessions) -// - Returns decorator that allows nil callers and rejects token presentation +// storedBinding must be either: +// - binding.UnauthenticatedSentinel (the session was anonymous), or +// - a valid bound binding (binding.Parse returns ok). // -// Security: -// - Makes defensive copies of secret and salt to prevent external mutation -// - Uses constant-time comparison to prevent timing attacks -// - Prevents session upgrade attacks (anonymous → authenticated) -// - Raw tokens are never stored, only HMAC-SHA256 hashes -// -// Returns an error if: -// - session is nil -// - salt generation fails -func PreventSessionHijacking( +// Anything else (empty string, malformed value) is rejected as corrupted +// metadata. +func RestoreSessionBinding( session sessiontypes.MultiSession, - hmacSecret []byte, - identity *auth.Identity, + storedBinding string, ) (sessiontypes.MultiSession, error) { if session == nil { return nil, fmt.Errorf("session must not be nil") } - allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) - - // Note: Pass-through methods (ID, Type, CreatedAt, etc.) are validated by the - // type system when the decorator is used. We don't validate them here to keep - // the constructor simple and allow minimal mocks for testing. 
- - var boundTokenHash string - var tokenSalt []byte - var err error - // Compute token binding for authenticated sessions - if !allowAnonymous && identity != nil && identity.Token != "" { - // Generate unique salt for this session - tokenSalt, err = generateSalt() - if err != nil { - return nil, fmt.Errorf("failed to generate token salt: %w", err) - } - // Compute HMAC-SHA256 hash with server secret and per-session salt - boundTokenHash = hashToken(identity.Token, hmacSecret, tokenSalt) - } - - // Store hash and salt in session metadata for persistence, auditing, - // and backward compatibility - session.SetMetadata(metadataKeyTokenHash, boundTokenHash) - if len(tokenSalt) > 0 { - session.SetMetadata(metadataKeyTokenSalt, hex.EncodeToString(tokenSalt)) + if binding.IsUnauthenticated(storedBinding) { + return &sessionBindingDecorator{ + MultiSession: session, + boundIdentity: storedBinding, + allowAnonymous: true, + }, nil } - // Make defensive copies of slices to prevent external mutation - var hmacSecretCopy, tokenSaltCopy []byte - if len(hmacSecret) > 0 { - hmacSecretCopy = append([]byte(nil), hmacSecret...) - } - if len(tokenSalt) > 0 { - tokenSaltCopy = append([]byte(nil), tokenSalt...) + // Validate the stored binding is parseable. We do not use iss/sub here — + // the factory calls binding.Parse separately when it needs them to + // reconstruct identity. This call is purely a validation gate. + if _, _, ok := binding.Parse(storedBinding); !ok { + return nil, fmt.Errorf("RestoreSessionBinding: stored binding is neither the unauthenticated sentinel " + + "nor a valid bound binding (corrupted metadata)") } - // Wrap with hijackPreventionDecorator for runtime validation. - // The decorator embeds the MultiSession interface, so all methods are automatically - // delegated except for the three we override (CallTool, ReadResource, GetPrompt). 
- return &hijackPreventionDecorator{ + return &sessionBindingDecorator{ MultiSession: session, - allowAnonymous: allowAnonymous, - hmacSecret: hmacSecretCopy, - boundTokenHash: boundTokenHash, - tokenSalt: tokenSaltCopy, + boundIdentity: storedBinding, + allowAnonymous: false, }, nil } diff --git a/pkg/vmcp/session/mocks/mock_factory.go b/pkg/vmcp/session/mocks/mock_factory.go index e548519ea0..3d850a1eb4 100644 --- a/pkg/vmcp/session/mocks/mock_factory.go +++ b/pkg/vmcp/session/mocks/mock_factory.go @@ -44,18 +44,18 @@ func (m *MockMultiSessionFactory) EXPECT() *MockMultiSessionFactoryMockRecorder } // MakeSessionWithID mocks base method. -func (m *MockMultiSessionFactory) MakeSessionWithID(ctx context.Context, id string, identity *auth.Identity, allowAnonymous bool, backends []*vmcp.Backend) (session.MultiSession, error) { +func (m *MockMultiSessionFactory) MakeSessionWithID(ctx context.Context, id string, identity *auth.Identity, backends []*vmcp.Backend) (session.MultiSession, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeSessionWithID", ctx, id, identity, allowAnonymous, backends) + ret := m.ctrl.Call(m, "MakeSessionWithID", ctx, id, identity, backends) ret0, _ := ret[0].(session.MultiSession) ret1, _ := ret[1].(error) return ret0, ret1 } // MakeSessionWithID indicates an expected call of MakeSessionWithID. 
-func (mr *MockMultiSessionFactoryMockRecorder) MakeSessionWithID(ctx, id, identity, allowAnonymous, backends any) *gomock.Call { +func (mr *MockMultiSessionFactoryMockRecorder) MakeSessionWithID(ctx, id, identity, backends any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeSessionWithID", reflect.TypeOf((*MockMultiSessionFactory)(nil).MakeSessionWithID), ctx, id, identity, allowAnonymous, backends) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeSessionWithID", reflect.TypeOf((*MockMultiSessionFactory)(nil).MakeSessionWithID), ctx, id, identity, backends) } // RestoreSession mocks base method. diff --git a/pkg/vmcp/session/session.go b/pkg/vmcp/session/session.go index c8bc492cac..0b615603b2 100644 --- a/pkg/vmcp/session/session.go +++ b/pkg/vmcp/session/session.go @@ -11,18 +11,13 @@ import ( // backward compatibility and convenience. type MultiSession = sessiontypes.MultiSession +// Re-exports from the types package for convenience. See the types package for +// authoritative documentation. const ( - // MetadataKeyTokenHash is the session metadata key that holds the HMAC-SHA256 - // hash of the bearer token used to create the session. For authenticated sessions - // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty - // string sentinel. The raw token is never stored — only the hash. - // - // Re-exported from types package for convenience. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); invalidated on read. MetadataKeyTokenHash = sessiontypes.MetadataKeyTokenHash - - // MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded - // random salt used for HMAC-SHA256 token hashing. Omitted for anonymous sessions. - // - // Re-exported from types package for convenience. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306). 
MetadataKeyTokenSalt = sessiontypes.MetadataKeyTokenSalt + + MetadataKeyIdentityBinding = sessiontypes.MetadataKeyIdentityBinding ) diff --git a/pkg/vmcp/session/token_binding_test.go b/pkg/vmcp/session/token_binding_test.go deleted file mode 100644 index db8fdf412e..0000000000 --- a/pkg/vmcp/session/token_binding_test.go +++ /dev/null @@ -1,343 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package session - -import ( - "context" - "errors" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/vmcp" - internalbk "github.com/stacklok/toolhive/pkg/vmcp/session/internal/backend" - sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" -) - -// --------------------------------------------------------------------------- -// makeSession stores token hash in metadata -// --------------------------------------------------------------------------- - -// nilBackendConnector is a connector that returns (nil, nil, nil), causing the -// backend to be skipped during init. This lets us exercise session-metadata -// logic without real backend connections. 
-func nilBackendConnector() backendConnector { - return func(_ context.Context, _ *vmcp.BackendTarget, _ *auth.Identity, _ string) (internalbk.Session, *vmcp.CapabilityList, error) { - return nil, nil, nil - } -} - -func TestMakeSession_StoresTokenHash(t *testing.T) { - t.Parallel() - - t.Run("authenticated session stores HMAC-SHA256 hash", func(t *testing.T) { - t.Parallel() - - const rawToken = "test-bearer-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, Token: rawToken} - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, false, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - // Verify token hash is stored - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set") - assert.NotEmpty(t, storedHash, "Token hash must be non-empty for authenticated session") - assert.Len(t, storedHash, 64, "HMAC-SHA256 hex-encoded hash should be 64 characters") - // Raw token must never appear in metadata. 
- assert.NotEqual(t, rawToken, storedHash) - - // Verify salt is stored for authenticated sessions - storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") - assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") - }) - - t.Run("anonymous session stores empty sentinel", func(t *testing.T) { - t.Parallel() - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), nil, true, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set even for anonymous sessions") - assert.Empty(t, storedHash, "anonymous session must store empty sentinel") - - // Salt must not be present for anonymous sessions - storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "anonymous session must not store a salt") - }) - - t.Run("identity with empty token stores empty sentinel", func(t *testing.T) { - t.Parallel() - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), uuid.New().String(), identity, true, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - storedHash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.Empty(t, storedHash, "empty-token identity must store empty sentinel") - - // Salt must not be present for empty-token (anonymous) sessions - storedSalt := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - assert.Empty(t, storedSalt, "empty-token identity must not store a salt") - }) - - t.Run("MakeSessionWithID also stores token hash", func(t *testing.T) { - t.Parallel() - - const rawToken = 
"id-specific-token" - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "bob"}, Token: rawToken} - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - sess, err := factory.MakeSessionWithID(t.Context(), "explicit-session-id", identity, false, nil) - require.NoError(t, err) - require.NotNil(t, sess) - - // Verify token hash - storedHash, present := sess.GetMetadata()[MetadataKeyTokenHash] - require.True(t, present, "MetadataKeyTokenHash must be set") - assert.NotEmpty(t, storedHash, "Token hash must be non-empty") - assert.Len(t, storedHash, 64, "HMAC-SHA256 hex-encoded hash should be 64 characters") - - // Verify salt is stored for authenticated sessions - storedSalt, saltPresent := sess.GetMetadata()[sessiontypes.MetadataKeyTokenSalt] - require.True(t, saltPresent, "MetadataKeyTokenSalt must be set for authenticated sessions") - assert.NotEmpty(t, storedSalt, "Salt must be non-empty for authenticated session") - }) -} - -// --------------------------------------------------------------------------- -// MakeSessionWithID validation -// --------------------------------------------------------------------------- - -// TestMakeSessionWithID_ValidationOfAllowAnonymous tests that MakeSessionWithID -// validates consistency between identity and allowAnonymous parameters. 
-func TestMakeSessionWithID_ValidationOfAllowAnonymous(t *testing.T) { - t.Parallel() - - factory := NewSessionFactory(nil) - - t.Run("rejects anonymous session with bearer token", func(t *testing.T) { - t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "bearer-token"} - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - true, // allowAnonymous=true but identity has token - nil, // no backends needed for validation test - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create anonymous session") - assert.Contains(t, err.Error(), "with bearer token") - }) - - t.Run("rejects bound session without bearer token (nil identity)", func(t *testing.T) { - t.Parallel() - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - nil, // no identity - false, // allowAnonymous=false but no identity - nil, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create bound session") - assert.Contains(t, err.Error(), "without bearer token") - }) - - t.Run("rejects bound session without bearer token (empty token)", func(t *testing.T) { - t.Parallel() - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} // empty token - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - false, // allowAnonymous=false but token is empty - nil, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot create bound session") - assert.Contains(t, err.Error(), "without bearer token") - }) - - t.Run("allows anonymous session with nil identity", func(t *testing.T) { - t.Parallel() - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - nil, // no identity - true, // allowAnonymous=true - consistent - nil, - ) - require.NoError(t, err) - }) - - t.Run("allows anonymous session with empty token", func(t *testing.T) { - t.Parallel() - identity 
:= &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: ""} - _, err := factory.MakeSessionWithID( - context.Background(), - "test-session", - identity, - true, // allowAnonymous=true and token is empty - consistent - nil, - ) - require.NoError(t, err) - }) -} - -// --------------------------------------------------------------------------- -// WithHMACSecret defensive copy -// --------------------------------------------------------------------------- - -// TestWithHMACSecret_DefensiveCopy verifies that WithHMACSecret makes a defensive -// copy of the secret to prevent external modification after assignment. -func TestWithHMACSecret_DefensiveCopy(t *testing.T) { - t.Parallel() - - // Create a mutable secret - secretSlice := []byte("original-secret-value") - - // Create factory with the secret - factory := newSessionFactoryWithConnector(nilBackendConnector(), WithHMACSecret(secretSlice)) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - - // Create first session before modification - sess1, err := factory.MakeSessionWithID(context.Background(), "session-1", identity, false, nil) - require.NoError(t, err) - - // Verify first session was created successfully - hash1 := sess1.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, hash1, "first session should have token hash") - - // Maliciously modify the secret slice after passing it to WithHMACSecret - for i := range secretSlice { - secretSlice[i] = 0xFF - } - - // Create second session after modification - should still work correctly - // because WithHMACSecret made a defensive copy - sess2, err := factory.MakeSessionWithID(context.Background(), "session-2", identity, false, nil) - require.NoError(t, err) - - // Verify second session was created successfully - hash2 := sess2.GetMetadata()[MetadataKeyTokenHash] - require.NotEmpty(t, hash2, "second session should have token hash") - - // Both sessions should still be able to validate the 
original token - // (proving the factory used the original secret, not the modified one). - // We verify this by calling a session method that requires authentication. - ctx := context.Background() - - // First session should accept the original token and fail with ErrToolNotFound, - // not an auth error (which would indicate the secret was corrupted) - _, err = sess1.CallTool(ctx, identity, "nonexistent-tool", nil, nil) - assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") - assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), - "should not be an auth error (would indicate corrupted secret)") - - // Second session should also accept the original token and fail with ErrToolNotFound - _, err = sess2.CallTool(ctx, identity, "nonexistent-tool", nil, nil) - assert.ErrorIs(t, err, ErrToolNotFound, "should fail with tool not found error") - assert.False(t, errors.Is(err, sessiontypes.ErrUnauthorizedCaller), - "should not be an auth error (would indicate corrupted secret)") -} - -// --------------------------------------------------------------------------- -// RestoreSession fail-closed behaviour for absent token-hash key -// --------------------------------------------------------------------------- - -// TestRestoreSession_AbsentTokenHashKey verifies that RestoreSession fails closed -// when the stored metadata is missing MetadataKeyTokenHash entirely. -// -// Background: storedMetadata[key] returns "" for both an absent key and a -// legitimately anonymous session (which stores "" as a sentinel). The factory -// uses the two-value map lookup form to distinguish between the two cases and -// rejects absent keys rather than silently downgrading to anonymous. 
-func TestRestoreSession_AbsentTokenHashKey(t *testing.T) { - t.Parallel() - - factory := newSessionFactoryWithConnector(nilBackendConnector()) - - t.Run("absent token-hash key is rejected (fail closed)", func(t *testing.T) { - t.Parallel() - - // Metadata that deliberately omits MetadataKeyTokenHash (simulates - // corrupted or truncated session metadata). MetadataKeyBackendIDs is - // present (empty = zero backends) so the earlier backend-IDs guard - // passes and we reach the token-hash guard. - storedMetadata := map[string]string{ - MetadataKeyIdentitySubject: "alice", - MetadataKeyBackendIDs: "", // present, empty = zero backends - // MetadataKeyTokenHash intentionally absent - } - - _, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "token hash metadata key absent") - }) - - t.Run("empty token-hash key (anonymous sentinel) is accepted", func(t *testing.T) { - t.Parallel() - - // Metadata with MetadataKeyTokenHash present but empty — this is what - // PreventSessionHijacking writes for anonymous sessions. - storedMetadata := map[string]string{ - MetadataKeyBackendIDs: "", // present, empty = zero backends - sessiontypes.MetadataKeyTokenHash: "", // present, empty = anonymous - } - - sess, err := factory.RestoreSession(t.Context(), uuid.New().String(), storedMetadata, nil) - require.NoError(t, err) - require.NotNil(t, sess) - }) -} - -// TestWithHMACSecret_RejectsEmptySecret verifies that WithHMACSecret rejects -// nil or empty secrets to prevent silent security downgrades. 
-func TestWithHMACSecret_RejectsEmptySecret(t *testing.T) { - t.Parallel() - - t.Run("nil secret is rejected", func(t *testing.T) { - t.Parallel() - - // Create factory with nil secret (should fall back to default) - factory := NewSessionFactory(nil, WithHMACSecret(nil)) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - sess, err := factory.MakeSessionWithID(context.Background(), "test-session", identity, false, nil) - require.NoError(t, err) - - // Should still create a valid session with default secret - hash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.NotEmpty(t, hash, "should use default secret, not nil") - }) - - t.Run("empty secret is rejected", func(t *testing.T) { - t.Parallel() - - // Create factory with empty secret (should fall back to default) - factory := NewSessionFactory(nil, WithHMACSecret([]byte{})) - - identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: "user"}, Token: "test-token"} - sess, err := factory.MakeSessionWithID(context.Background(), "test-session", identity, false, nil) - require.NoError(t, err) - - // Should still create a valid session with default secret - hash := sess.GetMetadata()[MetadataKeyTokenHash] - assert.NotEmpty(t, hash, "should use default secret, not empty slice") - }) -} diff --git a/pkg/vmcp/session/types/session.go b/pkg/vmcp/session/types/session.go index f5752b48f1..bafb74eadc 100644 --- a/pkg/vmcp/session/types/session.go +++ b/pkg/vmcp/session/types/session.go @@ -12,10 +12,12 @@ package types import ( "context" "errors" + "log/slog" "github.com/stacklok/toolhive/pkg/auth" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/binding" ) // Caller represents the ability to invoke MCP protocol operations against a @@ -139,37 +141,86 @@ type MultiSession interface { } const ( - // MetadataKeyTokenHash is the session metadata key 
that holds the HMAC-SHA256 - // hash of the bearer token used to create the session. For authenticated sessions - // this is hex(HMAC-SHA256(bearerToken)). For anonymous sessions this is the empty - // string sentinel. The raw token is never stored — only the hash. + // MetadataKeyTokenHash held hex(HMAC-SHA256(bearerToken)) for authenticated + // sessions and "" for anonymous sessions. // - // This constant is the single source of truth used by the session factory and - // security layer to store and validate token binding metadata. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); sessions written + // under this key are invalidated on read. Constant removed in the follow-up PR. MetadataKeyTokenHash = "vmcp.token.hash" //nolint:gosec // This is a metadata key name, not a credential. - // MetadataKeyTokenSalt is the session metadata key that holds the hex-encoded - // random salt used for HMAC-SHA256 token hashing. Each authenticated session has a - // unique salt to prevent attacks across multiple sessions. Anonymous sessions do not - // generate a salt and this key is omitted from their metadata. + // MetadataKeyTokenSalt held the hex-encoded per-session salt used for + // HMAC-SHA256 token hashing. // - // This constant is the single source of truth used by the session factory and - // security layer to store and validate token binding metadata. + // Legacy: superseded by MetadataKeyIdentityBinding (#5306); removed in the follow-up PR. MetadataKeyTokenSalt = "vmcp.token.salt" //nolint:gosec // This is a metadata key name, not a credential. + + // MetadataKeyIdentityBinding is the session metadata key that holds the + // identity-binding string in the format defined by + // [pkg/vmcp/session/binding]. Bound sessions store binding.Format(iss, sub); + // unauthenticated sessions store binding.UnauthenticatedSentinel. + // + // Storage is plaintext PII — the Redis/Valkey instance must be access-controlled. 
+ MetadataKeyIdentityBinding = "vmcp.identity.binding" ) -// ShouldAllowAnonymous determines if a session should allow anonymous access -// based on the creator's identity. Sessions without an identity (nil) or with -// an empty token are treated as anonymous. +// ShouldAllowAnonymous reports whether a session has no per-user binding and +// should be treated as anonymous. The identity is bound when Token is set, or +// when Claims yields a (iss, sub) pair accepted by binding.Format. +// +// Fail-closed: a Claims["iss"] or Claims["sub"] that is present but not a +// string (a misbehaving validator) is treated as bound, with a WARN logged. +// +// Contract note: this function answers "should this session be CREATED as +// anonymous?" using Token presence as a fast path. It does NOT guarantee that +// a non-anonymous identity will produce a successful binding: an identity with +// Token != "" but missing iss/sub claims passes this check as bound, then +// fails in BindSession with an extraction error. In practice this is +// pathological — all shipping middlewares (JWT validator, LocalUserMiddleware, +// AnonymousMiddleware) populate Claims. Callers who need "will this identity +// actually produce a binding?" must use extractBindingID directly, which is +// deliberately kept internal to the security decorator. func ShouldAllowAnonymous(identity *auth.Identity) bool { - return identity == nil || identity.Token == "" + if identity == nil { + return true + } + if identity.Token != "" { + return false + } + iss, issOK := claimString(identity.Claims, "iss") + sub, subOK := claimString(identity.Claims, "sub") + if !issOK || !subOK { + slog.Warn("auth identity has present-but-non-string iss/sub claim; treating as bound") + return false + } + if _, err := binding.Format(iss, sub); err == nil { + return false + } + return true +} + +// claimString returns (value, ok) for a string claim. 
ok is true when the +// claim is missing entirely (absence is benign — the caller proceeds to +// binding.Format which rejects empty halves) or when the claim is present +// and a string (including the empty string). ok is false only when the +// claim is present but not a string — the caller must treat that as a +// misconfigured validator and fail closed. +func claimString(claims map[string]any, key string) (string, bool) { + v, present := claims[key] + if !present { + return "", true + } + s, isString := v.(string) + if !isString { + return "", false + } + return s, true } -// Token binding errors returned by Caller methods when caller identity +// Identity binding errors returned by Caller methods when caller identity // validation fails. var ( // ErrUnauthorizedCaller is returned when the caller identity does not - // match the session owner's identity (token hash mismatch). + // match the session owner's identity (identity binding mismatch). ErrUnauthorizedCaller = errors.New("caller identity does not match session owner") // ErrNilCaller is returned when a bound session receives a nil caller. diff --git a/pkg/vmcp/session/types/session_test.go b/pkg/vmcp/session/types/session_test.go new file mode 100644 index 0000000000..3e6b19ce12 --- /dev/null +++ b/pkg/vmcp/session/types/session_test.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stacklok/toolhive/pkg/auth" +) + +func TestShouldAllowAnonymous(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + identity *auth.Identity + want bool + }{ + { + name: "nil identity", + identity: nil, + want: true, + }, + { + name: "token present", + identity: &auth.Identity{ + Token: "x", + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{"iss": "https://idp.example", "sub": "alice"}, + }, + }, + want: false, + }, + { + name: "empty token with valid iss+sub claims", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{ + Claims: map[string]any{"iss": "toolhive-local", "sub": "alice"}, + }, + }, + want: false, + }, + { + name: "empty token and nil claims", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: nil}, + }, + want: true, + }, + { + name: "empty token and missing iss", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"sub": "x"}}, + }, + want: true, + }, + { + name: "empty token and missing sub", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "x"}}, + }, + want: true, + }, + { + name: "empty token and empty string iss and sub", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "", "sub": ""}}, + }, + want: true, + }, + { + // Fail-closed: non-string claim from a misbehaving validator → bound, not anonymous. 
+ name: "non-string iss fails closed", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": 42, "sub": "x"}}, + }, + want: false, + }, + { + name: "non-string sub fails closed", + identity: &auth.Identity{ + PrincipalInfo: auth.PrincipalInfo{Claims: map[string]any{"iss": "x", "sub": []string{"foo"}}}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, ShouldAllowAnonymous(tt.identity)) + }) + } +}