Skip to content

Commit 4f73abd

Browse files
JAORMX and claude committed
Allow embedders to inject custom session DataStorage
The vMCP server currently only supports the "memory" and "redis" providers encoded in SessionStorageConfig. Embedders that run on Postgres, DynamoDB, Spanner, etc. would have to fork server.go or fall back to in-memory sessions (no persistence across pod restarts).

Add an optional Config.DataStorage field of type transportsession.DataStorage. When non-nil, the server uses the caller-supplied store directly and the SessionStorage enum is ignored. Setting both is rejected at New() so misconfiguration surfaces loudly instead of silently favouring one.

Caller owns the lifecycle: the server never calls Close() on a caller-supplied store, matching the existing convention for every other caller-supplied dependency on Config (TelemetryProvider, StatusReporter, Watcher).

The server-built path is unchanged — buildSessionDataStorage now returns a closer so that lifecycle is tracked explicitly rather than via an ownership bool.

Closes #4928

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6f63ac0 commit 4f73abd

2 files changed

Lines changed: 269 additions & 18 deletions

File tree

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package server_test
5+
6+
import (
7+
"sync/atomic"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"go.uber.org/mock/gomock"
14+
15+
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
16+
"github.com/stacklok/toolhive/pkg/vmcp"
17+
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
18+
discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
19+
"github.com/stacklok/toolhive/pkg/vmcp/mocks"
20+
routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
21+
"github.com/stacklok/toolhive/pkg/vmcp/server"
22+
)
23+
24+
// countingDataStorage wraps a real LocalSessionDataStorage and counts how
25+
// many times Close has been invoked. Used to assert that Server.Stop does
26+
// not close a caller-supplied DataStorage.
27+
type countingDataStorage struct {
28+
transportsession.DataStorage
29+
closeCalls atomic.Int32
30+
}
31+
32+
func (c *countingDataStorage) Close() error {
33+
c.closeCalls.Add(1)
34+
return c.DataStorage.Close()
35+
}
36+
37+
func newCountingDataStorage(t *testing.T) *countingDataStorage {
38+
t.Helper()
39+
inner, err := transportsession.NewLocalSessionDataStorage(5 * time.Minute)
40+
require.NoError(t, err)
41+
return &countingDataStorage{DataStorage: inner}
42+
}
43+
44+
func TestNew_CallerOwnedDataStorageNotClosedOnStop(t *testing.T) {
45+
t.Parallel()
46+
47+
spy := newCountingDataStorage(t)
48+
// The spy is caller-owned; close the inner LocalSessionDataStorage
49+
// directly at the end of the test so the counter is not ticked by
50+
// cleanup — the post-Stop assertion below must reflect only the server's
51+
// behaviour. Err ignored: closing an already-closed local store is a
52+
// no-op in this implementation.
53+
t.Cleanup(func() {
54+
_ = spy.DataStorage.Close()
55+
})
56+
57+
ctrl := gomock.NewController(t)
58+
t.Cleanup(ctrl.Finish)
59+
60+
mockRouter := routerMocks.NewMockRouter(ctrl)
61+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
62+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
63+
mockDiscoveryMgr.EXPECT().Stop().Times(1)
64+
65+
srv, err := server.New(
66+
t.Context(),
67+
&server.Config{
68+
Host: "127.0.0.1",
69+
Port: 0,
70+
SessionFactory: newNoopMockFactory(t),
71+
DataStorage: spy,
72+
},
73+
mockRouter,
74+
mockBackendClient,
75+
mockDiscoveryMgr,
76+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
77+
nil,
78+
)
79+
require.NoError(t, err)
80+
81+
err = srv.Stop(t.Context())
82+
require.NoError(t, err)
83+
84+
assert.Equal(t, int32(0), spy.closeCalls.Load(),
85+
"server must not close a caller-supplied DataStorage")
86+
}
87+
88+
func TestNew_BothSessionStorageAndDataStorageErrors(t *testing.T) {
89+
t.Parallel()
90+
91+
spy := newCountingDataStorage(t)
92+
// Err ignored: closing an already-closed local store is a no-op.
93+
t.Cleanup(func() {
94+
_ = spy.DataStorage.Close()
95+
})
96+
97+
ctrl := gomock.NewController(t)
98+
t.Cleanup(ctrl.Finish)
99+
100+
mockRouter := routerMocks.NewMockRouter(ctrl)
101+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
102+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
103+
104+
_, err := server.New(
105+
t.Context(),
106+
&server.Config{
107+
Host: "127.0.0.1",
108+
Port: 0,
109+
SessionFactory: newNoopMockFactory(t),
110+
SessionStorage: &vmcpconfig.SessionStorageConfig{
111+
Provider: "redis",
112+
Address: "127.0.0.1:6379",
113+
},
114+
DataStorage: spy,
115+
},
116+
mockRouter,
117+
mockBackendClient,
118+
mockDiscoveryMgr,
119+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
120+
nil,
121+
)
122+
require.Error(t, err)
123+
assert.Contains(t, err.Error(), "DataStorage")
124+
assert.Contains(t, err.Error(), "SessionStorage")
125+
assert.Equal(t, int32(0), spy.closeCalls.Load(),
126+
"server must not close a caller-supplied DataStorage on misconfiguration")
127+
}
128+
129+
func TestNew_ServerBuiltDataStorageStopSucceeds(t *testing.T) {
130+
// Guards against accidental regression of the server-owned close path
131+
// when Close moved from an inline Stop() block onto sessionDataStorageCloser.
132+
// Stop() must still complete without error when the server built the store.
133+
// This is a smoke test — it cannot observe Close on the internal
134+
// LocalSessionDataStorage because that type is constructed inside New().
135+
t.Parallel()
136+
137+
ctrl := gomock.NewController(t)
138+
t.Cleanup(ctrl.Finish)
139+
140+
mockRouter := routerMocks.NewMockRouter(ctrl)
141+
mockBackendClient := mocks.NewMockBackendClient(ctrl)
142+
mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
143+
mockDiscoveryMgr.EXPECT().Stop().Times(1)
144+
145+
srv, err := server.New(
146+
t.Context(),
147+
&server.Config{
148+
Host: "127.0.0.1",
149+
Port: 0,
150+
SessionFactory: newNoopMockFactory(t),
151+
SessionStorage: &vmcpconfig.SessionStorageConfig{Provider: "memory"},
152+
},
153+
mockRouter,
154+
mockBackendClient,
155+
mockDiscoveryMgr,
156+
vmcp.NewImmutableRegistry([]vmcp.Backend{}),
157+
nil,
158+
)
159+
require.NoError(t, err)
160+
161+
require.NoError(t, srv.Stop(t.Context()))
162+
}

pkg/vmcp/server/server.go

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,38 @@ type Config struct {
181181
// When provider is "redis", a Redis-backed store is created for cross-pod
182182
// session persistence; the Redis password is read from the
183183
// THV_SESSION_REDIS_PASSWORD environment variable.
184+
//
185+
// Mutually exclusive with DataStorage: setting both is rejected at New().
184186
SessionStorage *vmcpconfig.SessionStorageConfig
187+
188+
// DataStorage optionally injects a caller-supplied session metadata store,
189+
// bypassing the built-in memory/redis providers. When non-nil, the server
190+
// uses this store as-is. Only SessionStorage.Provider is checked: setting
191+
// both DataStorage and a non-empty SessionStorage.Provider is rejected at
192+
// New() as ambiguous configuration. No other field of SessionStorage is
193+
// consulted.
194+
//
195+
// Lifecycle: the caller owns it. The server does NOT call Close() on a
196+
// caller-supplied DataStorage, even on error paths in New() or during
197+
// Stop(). The caller is responsible for invoking Close() exactly once
198+
// after Server.Stop() returns (not before — the session manager may issue
199+
// final Update calls during Stop). The caller is likewise responsible for
200+
// configuring the store's TTL; cfg.SessionTTL applies only to the
201+
// transport-level session manager, not to the caller-supplied DataStorage.
202+
//
203+
// Sensitive material: the store holds HMAC-hashed token material and
204+
// other session metadata. Embedders should treat the backing datastore as
205+
// sensitive (dedicated credentials, encryption at rest, restricted read
206+
// access). Implementations must not include credentials in error messages;
207+
// Close() errors go only to the caller, as the server never calls Close() here.
208+
//
209+
// This seam lets embedders satisfy transportsession.DataStorage against
210+
// datastores other than the built-in providers (e.g. Postgres, DynamoDB)
211+
// without forking the server. It enables cross-replica session metadata
212+
// sharing when backed by a shared store; it does NOT solve cross-replica
213+
// message delivery — callers still need session affinity at the load
214+
// balancer for streaming responses.
215+
DataStorage transportsession.DataStorage
185216
}
186217

187218
// Server is the Virtual MCP Server that aggregates multiple backends.
@@ -224,10 +255,16 @@ type Server struct {
224255
sessionManager *transportsession.Manager
225256

226257
// sessionDataStorage is the pluggable key-value backend for session metadata.
227-
// Currently always LocalSessionDataStorage (in-memory, single-process).
228-
// Redis-backed storage for multi-pod deployments is not yet wired.
258+
// It may be LocalSessionDataStorage (in-memory, single-process), a Redis-backed
259+
// store, or a caller-supplied implementation injected via Config.DataStorage.
229260
sessionDataStorage transportsession.DataStorage
230261

262+
// sessionDataStorageCloser closes sessionDataStorage on shutdown. It is
263+
// set only when the server built the store itself (memory or redis
264+
// providers). When Config.DataStorage was supplied by the caller, this is
265+
// nil and the caller is responsible for closing the store.
266+
sessionDataStorageCloser func(context.Context) error
267+
231268
// Capability adapter for converting aggregator types to SDK types
232269
capabilityAdapter *adapter.CapabilityAdapter
233270

@@ -257,21 +294,51 @@ type Server struct {
257294
}
258295

259296
// buildSessionDataStorage constructs the DataStorage backend from cfg.
260-
// When cfg.SessionStorage is nil or provider is "memory" (or empty), local in-process
261-
// storage is used. When provider is "redis", a Redis-backed store is created
262-
// using the address, DB, and key prefix from cfg.SessionStorage; the password
263-
// is read from the THV_SESSION_REDIS_PASSWORD environment variable.
264-
// Any other provider value is a misconfiguration and returns an error.
265-
func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession.DataStorage, error) {
297+
//
298+
// Resolution order:
299+
//
300+
// 1. cfg.DataStorage (caller-supplied) takes precedence. When non-nil, the
301+
// store is returned as-is with a nil closer — the caller owns the
302+
// lifecycle. Setting both cfg.DataStorage and a non-empty
303+
// cfg.SessionStorage.Provider is rejected as ambiguous.
304+
// 2. cfg.SessionStorage.Provider "memory" (or empty, or nil SessionStorage):
305+
// local in-process storage is created.
306+
// 3. cfg.SessionStorage.Provider "redis": a Redis-backed store is created
307+
// using the address, DB, and key prefix from cfg.SessionStorage. The
308+
// password is read from the THV_SESSION_REDIS_PASSWORD environment
309+
// variable.
310+
// 4. Any other provider value is a misconfiguration and returns an error.
311+
//
312+
// For cases 2 and 3 (server-built stores), the returned closer wraps the
313+
// store's Close method. For case 1 (caller-supplied), the closer is nil.
314+
// New() routes the returned closer through Server.sessionDataStorageCloser
315+
// so Close is invoked on shutdown (and on New() error after this point) —
316+
// but only for server-built stores.
317+
func buildSessionDataStorage(
318+
ctx context.Context,
319+
cfg *Config,
320+
) (transportsession.DataStorage, func(context.Context) error, error) {
321+
if cfg.DataStorage != nil {
322+
if cfg.SessionStorage != nil && cfg.SessionStorage.Provider != "" {
323+
return nil, nil, fmt.Errorf(
324+
"cannot set both Config.DataStorage and Config.SessionStorage.Provider (%q); pick one",
325+
cfg.SessionStorage.Provider)
326+
}
327+
return cfg.DataStorage, nil, nil
328+
}
266329
// Default to in-process storage when session storage is not configured,
267330
// or when the provider is explicitly "memory" or left empty.
268331
if cfg.SessionStorage == nil ||
269332
cfg.SessionStorage.Provider == "" ||
270333
strings.EqualFold(cfg.SessionStorage.Provider, "memory") {
271-
return transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
334+
store, err := transportsession.NewLocalSessionDataStorage(cfg.SessionTTL)
335+
if err != nil {
336+
return nil, nil, err
337+
}
338+
return store, closerFor(store), nil
272339
}
273340
if cfg.SessionStorage.Provider != "redis" {
274-
return nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
341+
return nil, nil, fmt.Errorf("unsupported session storage provider %q (supported: \"memory\", \"redis\")",
275342
cfg.SessionStorage.Provider)
276343
}
277344
keyPrefix := cfg.SessionStorage.KeyPrefix
@@ -288,7 +355,19 @@ func buildSessionDataStorage(ctx context.Context, cfg *Config) (transportsession
288355
"db", cfg.SessionStorage.DB,
289356
"key_prefix", keyPrefix,
290357
)
291-
return transportsession.NewRedisSessionDataStorage(ctx, redisCfg, keyPrefix, cfg.SessionTTL)
358+
store, err := transportsession.NewRedisSessionDataStorage(ctx, redisCfg, keyPrefix, cfg.SessionTTL)
359+
if err != nil {
360+
return nil, nil, err
361+
}
362+
return store, closerFor(store), nil
363+
}
364+
365+
// closerFor adapts DataStorage.Close (no context) to the
366+
// func(context.Context) error signature used by Server.sessionDataStorageCloser.
367+
func closerFor(store transportsession.DataStorage) func(context.Context) error {
368+
return func(context.Context) error {
369+
return store.Close()
370+
}
292371
}
293372

294373
// New creates a new Virtual MCP Server instance.
@@ -412,16 +491,18 @@ func New(
412491
// keyed by the same session ID.
413492
sessionManager := transportsession.NewManager(cfg.SessionTTL, transportsession.NewStreamableSession)
414493

415-
sessionDataStorage, err := buildSessionDataStorage(ctx, cfg)
494+
sessionDataStorage, sessionDataStorageCloser, err := buildSessionDataStorage(ctx, cfg)
416495
if err != nil {
417496
return nil, fmt.Errorf("failed to create session data storage: %w", err)
418497
}
419-
// Close sessionDataStorage if New() returns an error after this point so the
420-
// background cleanup goroutine does not leak.
421-
closeStorageOnErr := true
498+
// If we built the store ourselves, close it when New() returns an error
499+
// after this point so the background cleanup goroutine does not leak.
500+
// For a caller-supplied store (sessionDataStorageCloser == nil), the
501+
// caller owns the lifecycle and we leave it untouched on every path.
502+
closeStorageOnErr := sessionDataStorageCloser != nil
422503
defer func() {
423504
if closeStorageOnErr {
424-
_ = sessionDataStorage.Close()
505+
_ = sessionDataStorageCloser(ctx)
425506
}
426507
}()
427508

@@ -486,6 +567,12 @@ func New(
486567
srv.shutdownFuncs = append(srv.shutdownFuncs, optimizerCleanup)
487568
}
488569

570+
// Store the session data storage closer on the Server so Stop() can invoke
571+
// it last (after session manager and discovery manager have stopped). For
572+
// a caller-supplied store this is nil and Stop() leaves it alone — the
573+
// caller owns the lifecycle.
574+
srv.sessionDataStorageCloser = sessionDataStorageCloser
575+
489576
// Register OnRegisterSession hook to inject capabilities after SDK registers session.
490577
// See handleSessionRegistration for implementation details.
491578
hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) {
@@ -848,8 +935,10 @@ func (s *Server) Stop(ctx context.Context) error {
848935

849936
// Close session data storage last: HTTP server is down (no new in-flight requests),
850937
// all other components have stopped (no further restore or liveness checks).
851-
if s.sessionDataStorage != nil {
852-
if err := s.sessionDataStorage.Close(); err != nil {
938+
// Only invoked when the server built the store itself; caller-supplied stores
939+
// (Config.DataStorage) are left for the caller to close.
940+
if s.sessionDataStorageCloser != nil {
941+
if err := s.sessionDataStorageCloser(ctx); err != nil {
853942
errs = append(errs, fmt.Errorf("failed to close session data storage: %w", err))
854943
}
855944
}

0 commit comments

Comments
 (0)