Skip to content

Commit 804df32

Browse files
JAORMXclaude
andcommitted
Allow embedders to inject custom session DataStorage
The vMCP server currently only supports the "memory" and "redis" providers encoded in SessionStorageConfig. Embedders that run on Postgres, DynamoDB, Spanner, etc. would have to fork server.go or fall back to in-memory sessions (no persistence across pod restarts). Add an optional Config.DataStorage field of type transportsession.DataStorage. When non-nil, the server uses the caller-supplied store directly and the SessionStorage enum is ignored. Setting both is rejected at New() so misconfiguration surfaces loudly instead of silently favouring one. Caller owns the lifecycle: the server never calls Close() on a caller-supplied store, matching the existing convention for every other caller-supplied dependency on Config (TelemetryProvider, StatusReporter, Watcher). The server-built path is unchanged — buildSessionDataStorage now returns a closer so that lifecycle is tracked explicitly rather than via an ownership bool. Closes #4928 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 68f4c2f commit 804df32

2 files changed

Lines changed: 269 additions & 18 deletions

File tree

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package server_test
5+
6+
import (
7+
"sync/atomic"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/mock/gomock"
14+
15+
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
16+
"github.com/stacklok/toolhive/pkg/vmcp"
17+
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
18+
discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
19+
"github.com/stacklok/toolhive/pkg/vmcp/mocks"
20+
routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
21+
"github.com/stacklok/toolhive/pkg/vmcp/server"
22+
)
23+
24+
// countingDataStorage wraps a real LocalSessionDataStorage and counts how
25+
// many times Close has been invoked. Used to assert that Server.Stop does
26+
// not close a caller-supplied DataStorage.
27+
type countingDataStorage struct {
28+
transportsession.DataStorage
29+
closeCalls atomic.Int32
30+
}
31+
32+
func (c *countingDataStorage) Close() error {
33+
c.closeCalls.Add(1)
34+
return c.DataStorage.Close()
35+
}
36+
37+
func newCountingDataStorage(t *testing.T) *countingDataStorage {
38+
t.Helper()
39+
inner, err := transportsession.NewLocalSessionDataStorage(5 * time.Minute)
40+
require.NoError(t, err)
41+
return &countingDataStorage{DataStorage: inner}
42+
}
43+
44+
func TestNew_CallerOwnedDataStorageNotClosedOnStop(t *testing.T) {
45+
t.Parallel()
46+
47+
spy := newCountingDataStorage(t)
48+
// The spy is caller-owned; close the inner LocalSessionDataStorage
49+
// directly at the end of the test so the counter is not ticked by
50+
// cleanup — the post-Stop assertion below must reflect only the server's
51+
// behaviour. Err ignored: closing an already-closed local store is a
52+
// no-op in this implementation.
53+
t.Cleanup(func() {
54+
_ = spy.DataStorage.Close()
55+
})
56+
57+
ctrl := gomock.NewController(t)
58+
t.Cleanup(ctrl.Finish)
59+
60+
mockRouter := routerMocks.NewMockRouter(ctrl)
61+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
62+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
63+
mockDiscoveryMgr.EXPECT().Stop().Times(1)
64+
65+
srv, err := server.New(
66+
t.Context(),
67+
&server.Config{
68+
Host: "127.0.0.1",
69+
Port: 0,
70+
SessionFactory: newNoopMockFactory(t),
71+
DataStorage: spy,
72+
},
73+
mockRouter,
74+
mockBackendClient,
75+
mockDiscoveryMgr,
76+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
77+
nil,
78+
)
79+
require.NoError(t, err)
80+
81+
err = srv.Stop(t.Context())
82+
require.NoError(t, err)
83+
84+
assert.Equal(t, int32(0), spy.closeCalls.Load(),
85+
"server must not close a caller-supplied DataStorage")
86+
}
87+
88+
func TestNew_BothSessionStorageAndDataStorageErrors(t *testing.T) {
89+
t.Parallel()
90+
91+
spy := newCountingDataStorage(t)
92+
// Err ignored: closing an already-closed local store is a no-op.
93+
t.Cleanup(func() {
94+
_ = spy.DataStorage.Close()
95+
})
96+
97+
ctrl := gomock.NewController(t)
98+
t.Cleanup(ctrl.Finish)
99+
100+
mockRouter := routerMocks.NewMockRouter(ctrl)
101+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
102+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
103+
104+
_, err := server.New(
105+
t.Context(),
106+
&server.Config{
107+
Host: "127.0.0.1",
108+
Port: 0,
109+
SessionFactory: newNoopMockFactory(t),
110+
SessionStorage: &vmcpconfig.SessionStorageConfig{
111+
Provider: "redis",
112+
Address: "127.0.0.1:6379",
113+
},
114+
DataStorage: spy,
115+
},
116+
mockRouter,
117+
mockBackendClient,
118+
mockDiscoveryMgr,
119+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
120+
nil,
121+
)
122+
require.Error(t, err)
123+
assert.Contains(t, err.Error(), "DataStorage")
124+
assert.Contains(t, err.Error(), "SessionStorage")
125+
assert.Equal(t, int32(0), spy.closeCalls.Load(),
126+
"server must not close a caller-supplied DataStorage on misconfiguration")
127+
}
128+
129+
func TestNew_ServerBuiltDataStorageStopSucceeds(t *testing.T) {
130+
// Guards against accidental regression of the server-owned close path
131+
// when Close moved from an inline Stop() block onto sessionDataStorageCloser.
132+
// Stop() must still complete without error when the server built the store.
133+
// This is a smoke test — it cannot observe Close on the internal
134+
// LocalSessionDataStorage because that type is constructed inside New().
135+
t.Parallel()
136+
137+
ctrl := gomock.NewController(t)
138+
t.Cleanup(ctrl.Finish)
139+
140+
mockRouter := routerMocks.NewMockRouter(ctrl)
141+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
142+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
143+
mockDiscoveryMgr.EXPECT().Stop().Times(1)
144+
145+
srv, err := server.New(
146+
t.Context(),
147+
&server.Config{
148+
Host: "127.0.0.1",
149+
Port: 0,
150+
SessionFactory: newNoopMockFactory(t),
151+
SessionStorage: &vmcpconfig.SessionStorageConfig{Provider: "memory"},
152+
},
153+
mockRouter,
154+
mockBackendClient,
155+
mockDiscoveryMgr,
156+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
157+
nil,
158+
)
159+
require.NoError(t, err)
160+
161+
require.NoError(t, srv.Stop(t.Context()))
162+
}

pkg/vmcp/server/server.go

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,38 @@ type Config struct {
180180
// When provider is "redis", a Redis-backed store is created for cross-pod
181181
// session persistence; the Redis password is read from the
182182
// THV_SESSION_REDIS_PASSWORD environment variable.
183+
//
184+
// Mutually exclusive with DataStorage: setting both is rejected at New().
183185
SessionStorage *vmcpconfig.SessionStorageConfig
186+
187+
// DataStorage optionally injects a caller-supplied session metadata store,
188+
// bypassing the built-in memory/redis providers. When non-nil, the server
189+
// uses this store as-is and SessionStorage is ignored in its entirety (no
190+
// field of SessionStorage is consulted). Setting both DataStorage and a
191+
// non-empty SessionStorage.Provider is rejected at New() as ambiguous
192+
// configuration.
193+
//
194+
// Lifecycle: the caller owns it. The server does NOT call Close() on a
195+
// caller-supplied DataStorage, even on error paths in New() or during
196+
// Stop(). The caller is responsible for invoking Close() exactly once
197+
// after Server.Stop() returns (not before — the session manager may issue
198+
// final Update calls during Stop). The caller is likewise responsible for
199+
// configuring the store's TTL; cfg.SessionTTL applies only to the
200+
// transport-level session manager, not to the caller-supplied DataStorage.
201+
//
202+
// Sensitive material: the store holds HMAC-hashed token material and
203+
// other session metadata. Embedders should treat the backing datastore as
204+
// sensitive (dedicated credentials, encryption at rest, restricted read
205+
// access). Implementations must not include credentials in Close() error
206+
// messages — those errors are surfaced through Server.Stop().
207+
//
208+
// This seam lets embedders satisfy transportsession.DataStorage against
209+
// datastores other than the built-in providers (e.g. Postgres, DynamoDB)
210+
// without forking the server. It enables cross-replica session metadata
211+
// sharing when backed by a shared store; it does NOT solve cross-replica
212+
// message delivery — callers still need session affinity at the load
213+
// balancer for streaming responses.
214+
DataStorage transportsession.DataStorage
184215
}
185216

186217
// Server is the Virtual MCP Server that aggregates multiple backends.
@@ -223,10 +254,16 @@ type Server struct {
223254
sessionManager *transportsession.Manager
224255

225256
// sessionDataStorage is the pluggable key-value backend for session metadata.
226-
// Currently always LocalSessionDataStorage (in-memory, single-process).
227-
// Redis-backed storage for multi-pod deployments is not yet wired.
257+
// It may be LocalSessionDataStorage (in-memory, single-process), a Redis-backed
258+
// store, or a caller-supplied implementation injected via Config.DataStorage.
228259
sessionDataStorage transportsession.DataStorage
229260

261+
// sessionDataStorageCloser closes sessionDataStorage on shutdown. It is
262+
// set only when the server built the store itself (memory or redis
263+
// providers). When Config.DataStorage was supplied by the caller, this is
264+
// nil and the caller is responsible for closing the store.
265+
sessionDataStorageCloser func(context.Context) error
266+
230267
// Capability adapter for converting aggregator types to SDK types
231268
capabilityAdapter *adapter.CapabilityAdapter
232269

@@ -256,21 +293,51 @@ type Server struct {
256293
}
257294

258295
// buildSessionDataStorage constructs the DataStorage backend from cfg.
259-
// When cfg.SessionStorage is nil or provider is "memory" (or empty), local in-process
260-
// storage is used. When provider is "redis", a Redis-backed store is created
261-
// using the address, DB, and key prefix from cfg.SessionStorage; the password
262-
// is read from the THV_SESSION_REDIS_PASSWORD environment variable.
263-
// Any other provider value is a misconfiguration and returns an error.
264-
func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession.DataStorage, error) {
296+
//
297+
// Resolution order:
298+
//
299+
// 1. cfg.DataStorage (caller-supplied) takes precedence. When non-nil, the
300+
// store is returned as-is with a nil closer — the caller owns the
301+
// lifecycle. Setting both cfg.DataStorage and a non-empty
302+
// cfg.SessionStorage.Provider is rejected as ambiguous.
303+
// 2. cfg.SessionStorage.Provider "memory" (or empty, or nil SessionStorage):
304+
// local in-process storage is created.
305+
// 3. cfg.SessionStorage.Provider "redis": a Redis-backed store is created
306+
// using the address, DB, and key prefix from cfg.SessionStorage. The
307+
// password is read from the THV_SESSION_REDIS_PASSWORD environment
308+
// variable.
309+
// 4. Any other provider value is a misconfiguration and returns an error.
310+
//
311+
// For cases 2 and 3 (server-built stores), the returned closer wraps the
312+
// store's Close method. For case 1 (caller-supplied), the closer is nil.
313+
// New() routes the returned closer through Server.sessionDataStorageCloser
314+
// so Close is invoked on shutdown (and on New() error after this point) —
315+
// but only for server-built stores.
316+
func buildSessionDataStorage(
317+
ctx context.Context,
318+
cfg *Config,
319+
) (transportsession.DataStorage, func(context.Context) error, error) {
320+
if cfg.DataStorage != nil {
321+
if cfg.SessionStorage != nil && cfg.SessionStorage.Provider != "" {
322+
return nil, nil, fmt.Errorf(
323+
"cannot set both Config.DataStorage and Config.SessionStorage.Provider (%q); pick one",
324+
cfg.SessionStorage.Provider)
325+
}
326+
return cfg.DataStorage, nil, nil
327+
}
265328
// Default to in-process storage when session storage is not configured,
266329
// or when the provider is explicitly "memory" or left empty.
267330
if cfg.SessionStorage == nil ||
268331
cfg.SessionStorage.Provider == "" ||
269332
strings.EqualFold(cfg.SessionStorage.Provider, "memory") {
270-
return transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
333+
store, err := transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
334+
if err != nil {
335+
return nil, nil, err
336+
}
337+
return store, closerFor(store), nil
271338
}
272339
if cfg.SessionStorage.Provider != "redis" {
273-
return nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
340+
return nil, nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
274341
cfg.SessionStorage.Provider)
275342
}
276343
keyPrefix := cfg.SessionStorage.KeyPrefix
@@ -288,7 +355,19 @@ func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession
288355
"db", cfg.SessionStorage.DB,
289356
"key_prefix", keyPrefix,
290357
)
291-
return transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL)
358+
store, err := transportsession.NewRedisSessionDataStorage(ctx, redisCfg, cfg.SessionTTL)
359+
if err != nil {
360+
return nil, nil, err
361+
}
362+
return store, closerFor(store), nil
363+
}
364+
365+
// closerFor adapts DataStorage.Close (no context) to the
366+
// func(context.Context) error signature used by Server.sessionDataStorageCloser.
367+
func closerFor(store transportsession.DataStorage) func(context.Context) error {
368+
return func(context.Context) error {
369+
return store.Close()
370+
}
292371
}
293372

294373
// New creates a new Virtual MCP Server instance.
@@ -412,16 +491,18 @@ func New(
412491
// keyed by the same session ID.
413492
sessionManager := transportsession.NewManager(cfg.SessionTTL, transportsession.NewStreamableSession)
414493

415-
sessionDataStorage, err := buildSessionDataStorage(ctx, cfg)
494+
sessionDataStorage, sessionDataStorageCloser, err := buildSessionDataStorage(ctx, cfg)
416495
if err != nil {
417496
return nil, fmt.Errorf("failed to create session data storage: %w", err)
418497
}
419-
// Close sessionDataStorage if New() returns an error after this point so the
420-
// background cleanup goroutine does not leak.
421-
closeStorageOnErr := true
498+
// If we built the store ourselves, close it when New() returns an error
499+
// after this point so the background cleanup goroutine does not leak.
500+
// For a caller-supplied store (sessionDataStorageCloser == nil), the
501+
// caller owns the lifecycle and we leave it untouched on every path.
502+
closeStorageOnErr := sessionDataStorageCloser != nil
422503
defer func() {
423504
if closeStorageOnErr {
424-
_ = sessionDataStorage.Close()
505+
_ = sessionDataStorageCloser(ctx)
425506
}
426507
}()
427508

@@ -486,6 +567,12 @@ func New(
486567
srv.shutdownFuncs = append(srv.shutdownFuncs, optimizerCleanup)
487568
}
488569

570+
// Store the session data storage closer on the Server so Stop() can invoke
571+
// it last (after session manager and discovery manager have stopped). For
572+
// a caller-supplied store this is nil and Stop() leaves it alone — the
573+
// caller owns the lifecycle.
574+
srv.sessionDataStorageCloser = sessionDataStorageCloser
575+
489576
// Register OnRegisterSession hook to inject capabilities after SDK registers session.
490577
// See handleSessionRegistration for implementation details.
491578
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
@@ -848,8 +935,10 @@ func (s *Server) Stop(ctx context.Context) error {
848935

849936
// Close session data storage last: HTTP server is down (no new in-flight requests),
850937
// all other components have stopped (no further restore or liveness checks).
851-
if s.sessionDataStorage != nil {
852-
if err := s.sessionDataStorage.Close(); err != nil {
938+
// Only invoked when the server built the store itself; caller-supplied stores
939+
// (Config.DataStorage) are left for the caller to close.
940+
if s.sessionDataStorageCloser != nil {
941+
if err := s.sessionDataStorageCloser(ctx); err != nil {
853942
errs = append(errs, fmt.Errorf("failed to close session data storage: %w", err))
854943
}
855944
}

0 commit comments

Comments
 (0)