Skip to content

Commit 47c3651

Browse files
yroblataskbot
andauthored
Migrate session manager to DataStorage and add RestoreSession (#4464)
* Migrate session manager to DataStorage and add RestoreSession Replace the *transportsession.Manager dependency in sessionmanager.Manager with the DataStorage interface (introduced in split PR 1), enabling pluggable session metadata storage (local or Redis) without coupling live session state to the serialization layer. Key changes: - Add Exists() to the Storage interface with implementations for LocalStorage and RedisStorage; expose TTL() and Exists() on transportsession.Manager - Add RestoreSession() to MultiSessionFactory (and decorating/mock impls); refactor makeSession → makeBaseSession so RestoreSession can reconstruct a live session from stored metadata without a bearer token - Add RestoreHijackPrevention() in pkg/vmcp/session/internal/security to recreate the hijack-prevention decorator from stored hash/salt - Rewrite sessionmanager.Manager to use DataStorage: node-local multiSessions sync.Map for hot-path lookups, singleflight-deduplicated RestoreSession on cache miss, and a background eviction loop that probes storage.Exists() to clean up expired MultiSession objects whose Redis TTL fired silently - Replace *transportsession.Manager in discovery.Middleware with a MultiSessionGetter interface backed by the session manager; remove the 401 for unknown sessions — the SDK now responds 404 via Validate() - Wire server.go to create LocalSessionDataStorage and pass it to sessionmanager.New; route discovery middleware through vmcpSessionMgr Closes: #4220 * fixes from review --------- Co-authored-by: taskbot <taskbot@users.noreply.github.com>
1 parent c3fadd1 commit 47c3651

19 files changed

Lines changed: 2052 additions & 363 deletions

pkg/vmcp/discovery/middleware.go

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@ import (
2525
"fmt"
2626
"log/slog"
2727
"net/http"
28-
"strings"
2928
"time"
3029

31-
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
3230
"github.com/stacklok/toolhive/pkg/vmcp"
3331
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
3432
"github.com/stacklok/toolhive/pkg/vmcp/health"
@@ -40,6 +38,13 @@ const (
4038
discoveryTimeout = 15 * time.Second
4139
)
4240

41+
// MultiSessionGetter retrieves a fully-formed MultiSession by session ID.
42+
// Returns (nil, false) if the session does not exist or has not yet been initialized.
43+
// This interface decouples the discovery middleware from the concrete session manager.
44+
type MultiSessionGetter interface {
45+
GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool)
46+
}
47+
4348
// middlewareConfig holds optional configuration for Middleware.
4449
type middlewareConfig struct {
4550
sessionScopedRouting bool
@@ -87,7 +92,7 @@ func WithDiscoveryTimeout(timeout time.Duration) MiddlewareOption {
8792
func Middleware(
8893
manager Manager,
8994
registry vmcp.BackendRegistry,
90-
sessionManager *transportsession.Manager,
95+
multiSessionGetter MultiSessionGetter,
9196
healthStatusProvider health.StatusProvider,
9297
opts ...MiddlewareOption,
9398
) func(http.Handler) http.Handler {
@@ -102,7 +107,6 @@ func Middleware(
102107
ctx := r.Context()
103108
sessionID := r.Header.Get("Mcp-Session-Id")
104109

105-
var err error
106110
if sessionID == "" {
107111
if cfg.sessionScopedRouting {
108112
// Session-scoped routing registers capabilities via the OnRegisterSession
@@ -111,15 +115,15 @@ func Middleware(
111115
return
112116
}
113117
// Initialize request: discover and cache capabilities in session.
118+
var err error
114119
ctx, err = handleInitializeRequest(ctx, r, manager, registry, healthStatusProvider, cfg.timeout)
120+
if err != nil {
121+
handleDiscoveryError(w, r, err)
122+
return
123+
}
115124
} else {
116-
// Subsequent request: retrieve cached capabilities from session.
117-
ctx, err = handleSubsequentRequest(ctx, r, sessionID, sessionManager)
118-
}
119-
120-
if err != nil {
121-
handleDiscoveryError(w, r, err)
122-
return
125+
// Subsequent request: inject routing context if the session is ready.
126+
ctx = handleSubsequentRequest(ctx, r, sessionID, multiSessionGetter)
123127
}
124128

125129
next.ServeHTTP(w, r.WithContext(ctx))
@@ -257,48 +261,29 @@ func handleInitializeRequest(
257261
}
258262

259263
// handleSubsequentRequest retrieves cached capabilities from the session.
260-
// Returns updated context with capabilities or an error.
264+
// Returns the updated context; never returns an error.
261265
func handleSubsequentRequest(
262266
ctx context.Context,
263267
r *http.Request,
264268
sessionID string,
265-
sessionManager *transportsession.Manager,
266-
) (context.Context, error) {
269+
multiSessionGetter MultiSessionGetter,
270+
) context.Context {
267271
//nolint:gosec // G706: session ID and request fields are not injection vectors
268272
slog.Debug("retrieving capabilities from session for subsequent request",
269273
"session_id", sessionID,
270274
"method", r.Method,
271275
"path", r.URL.Path)
272276

273-
// First, validate the session exists at all.
274-
rawSess, exists := sessionManager.Get(sessionID)
275-
if !exists {
276-
//nolint:gosec // G706: session ID is not an injection vector
277-
slog.Error("session not found",
278-
"session_id", sessionID,
279-
"method", r.Method,
280-
"path", r.URL.Path)
281-
return ctx, fmt.Errorf("session not found: %s", sessionID)
282-
}
283-
284-
// Backend tool handlers (created by DefaultHandlerFactory) resolve their backend
285-
// target by calling router.RouteTool(ctx, name), which reads DiscoveredCapabilities
286-
// from the request context. Inject capabilities built from the session's routing
287-
// table so these handlers can route correctly on subsequent requests.
288-
// Note: composite tool workflow engines are created per-session and route via
289-
// SessionRouter directly, so they no longer depend on this context value.
290-
multiSess, isMulti := rawSess.(vmcpsession.MultiSession)
291-
if !isMulti {
292-
// The session is still a StreamableSession placeholder — Phase 2
293-
// (OnRegisterSession / CreateSession) has not yet replaced it with a
294-
// MultiSession. This can happen if the client sends a request in the
295-
// brief window between receiving the session ID and the hook completing.
296-
// Skip capability injection and let the SDK respond (tools list will be
297-
// temporarily empty, but no 500 is returned to the client).
277+
// Look up the fully-formed MultiSession. Returns (nil, false) if the session does
278+
// not exist yet or is still a placeholder (CreateSession not yet complete). In either
279+
// case, skip capability injection and let the SDK validate/reject the request — the
280+
// SDK's own SessionIdManager.Validate() returns 404 for unknown session IDs.
281+
multiSess, ok := multiSessionGetter.GetMultiSession(sessionID)
282+
if !ok {
298283
//nolint:gosec // G706: session ID is not an injection vector
299-
slog.Debug("session initialisation in progress, skipping capability injection",
284+
slog.Debug("session not found or still initialising, skipping capability injection",
300285
"session_id", sessionID)
301-
return ctx, nil
286+
return ctx
302287
}
303288

304289
routingTable := multiSess.GetRoutingTable()
@@ -309,7 +294,7 @@ func handleSubsequentRequest(
309294
//nolint:gosec // G706: session ID is not an injection vector
310295
slog.Debug("multi-session routing table not yet initialised; skipping capability injection",
311296
"session_id", sessionID)
312-
return ctx, nil
297+
return ctx
313298
}
314299
//nolint:gosec // G706: session ID is not an injection vector
315300
slog.Debug("injecting capabilities from multi-session routing table for composite tool routing",
@@ -319,7 +304,7 @@ func handleSubsequentRequest(
319304
RoutingTable: routingTable,
320305
Tools: multiSess.Tools(),
321306
}
322-
return WithDiscoveredCapabilities(ctx, capabilities), nil
307+
return WithDiscoveredCapabilities(ctx, capabilities)
323308
}
324309

325310
// handleDiscoveryError writes appropriate HTTP error responses based on the error type.
@@ -329,13 +314,6 @@ func handleDiscoveryError(w http.ResponseWriter, _ *http.Request, err error) {
329314
return
330315
}
331316

332-
// Check for session-related errors
333-
errMsg := err.Error()
334-
if strings.Contains(errMsg, "session not found") {
335-
http.Error(w, "Session not found", http.StatusUnauthorized)
336-
return
337-
}
338-
339317
// Default to service unavailable for other errors
340318
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
341319
}

pkg/vmcp/discovery/middleware_test.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,32 @@ import (
1616
"github.com/stretchr/testify/require"
1717
"go.uber.org/mock/gomock"
1818

19-
transportsession "github.com/stacklok/toolhive/pkg/transport/session"
2019
"github.com/stacklok/toolhive/pkg/vmcp"
2120
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
2221
"github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
22+
vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
2323
sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks"
2424
)
2525

26-
// createTestSessionManager creates a session manager with StreamableSession factory for testing.
27-
func createTestSessionManager(t *testing.T) *transportsession.Manager {
28-
t.Helper()
29-
sessionMgr := transportsession.NewManager(30*time.Minute, transportsession.NewStreamableSession)
30-
t.Cleanup(func() { _ = sessionMgr.Stop() })
31-
return sessionMgr
26+
// Ensure stubMultiSessionGetter implements MultiSessionGetter.
27+
var _ MultiSessionGetter = (*stubMultiSessionGetter)(nil)
28+
29+
// stubMultiSessionGetter is a simple in-memory MultiSessionGetter for tests.
30+
type stubMultiSessionGetter struct {
31+
sessions map[string]vmcpsession.MultiSession
32+
}
33+
34+
func newStubMultiSessionGetter() *stubMultiSessionGetter {
35+
return &stubMultiSessionGetter{sessions: make(map[string]vmcpsession.MultiSession)}
36+
}
37+
38+
func (s *stubMultiSessionGetter) GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) {
39+
sess, ok := s.sessions[sessionID]
40+
return sess, ok
41+
}
42+
43+
func (s *stubMultiSessionGetter) add(sessionID string, sess vmcpsession.MultiSession) {
44+
s.sessions[sessionID] = sess
3245
}
3346

3447
// unorderedBackendsMatcher is a gomock matcher that compares backend slices without caring about order.
@@ -134,7 +147,7 @@ func TestMiddleware_InitializeRequest(t *testing.T) {
134147

135148
// Wrap handler with middleware
136149
backendRegistry := vmcp.NewImmutableRegistry(backends)
137-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil)
150+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil)
138151
wrappedHandler := middleware(testHandler)
139152

140153
// Create initialize request (no session ID header)
@@ -186,9 +199,6 @@ func TestMiddleware_SubsequentRequest_SkipsDiscovery(t *testing.T) {
186199
_, _ = w.Write([]byte("success"))
187200
})
188201

189-
// Create session manager and store routing table in a MultiSession
190-
sessionMgr := createTestSessionManager(t)
191-
192202
// Create a routing table for this session
193203
routingTable := &vmcp.RoutingTable{
194204
Tools: map[string]*vmcp.BackendTarget{"tool1": {WorkloadID: "backend1"}},
@@ -198,18 +208,11 @@ func TestMiddleware_SubsequentRequest_SkipsDiscovery(t *testing.T) {
198208

199209
// Add a MockMultiSession with the routing table
200210
mockSess := sessionmocks.NewMockMultiSession(ctrl)
201-
mockSess.EXPECT().ID().Return("dddddddd-1001-1001-1001-000000000001").AnyTimes()
202211
mockSess.EXPECT().GetRoutingTable().Return(routingTable).AnyTimes()
203212
mockSess.EXPECT().Tools().Return(nil).AnyTimes()
204-
mockSess.EXPECT().UpdatedAt().Return(time.Time{}).AnyTimes()
205-
mockSess.EXPECT().CreatedAt().Return(time.Time{}).AnyTimes()
206-
mockSess.EXPECT().Type().Return(transportsession.SessionType("")).AnyTimes()
207-
mockSess.EXPECT().GetData().Return(nil).AnyTimes()
208-
mockSess.EXPECT().SetData(gomock.Any()).AnyTimes()
209-
mockSess.EXPECT().GetMetadata().Return(nil).AnyTimes()
210-
mockSess.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes()
211-
err := sessionMgr.AddSession(mockSess)
212-
require.NoError(t, err, "failed to add session")
213+
214+
sessionMgr := newStubMultiSessionGetter()
215+
sessionMgr.add("dddddddd-1001-1001-1001-000000000001", mockSess)
213216

214217
// Wrap handler with middleware
215218
backendRegistry := vmcp.NewImmutableRegistry(backends)
@@ -254,7 +257,7 @@ func TestMiddleware_DiscoveryTimeout(t *testing.T) {
254257
})
255258

256259
backendRegistry := vmcp.NewImmutableRegistry(backends)
257-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil)
260+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil)
258261
wrappedHandler := middleware(testHandler)
259262

260263
// Initialize request (no session ID) - discovery should happen
@@ -295,7 +298,7 @@ func TestMiddleware_DiscoveryFailure(t *testing.T) {
295298
})
296299

297300
backendRegistry := vmcp.NewImmutableRegistry(backends)
298-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil)
301+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil)
299302
wrappedHandler := middleware(testHandler)
300303

301304
// Initialize request (no session ID) - discovery should happen
@@ -396,7 +399,7 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) {
396399
})
397400

398401
backendRegistry := vmcp.NewImmutableRegistry(backends)
399-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil)
402+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil)
400403
wrappedHandler := middleware(testHandler)
401404

402405
// Initialize request (no session ID) - discovery should happen
@@ -461,7 +464,7 @@ func TestMiddleware_PreservesUserContext(t *testing.T) {
461464
})
462465

463466
backendRegistry := vmcp.NewImmutableRegistry(backends)
464-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil)
467+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil)
465468
wrappedHandler := middleware(testHandler)
466469

467470
// Create initialize request with user context (as auth middleware would)
@@ -513,7 +516,7 @@ func TestMiddleware_ContextTimeoutHandling(t *testing.T) {
513516
})
514517

515518
backendRegistry := vmcp.NewImmutableRegistry(backends)
516-
middleware := Middleware(mockMgr, backendRegistry, createTestSessionManager(t), nil, WithDiscoveryTimeout(testTimeout))
519+
middleware := Middleware(mockMgr, backendRegistry, newStubMultiSessionGetter(), nil, WithDiscoveryTimeout(testTimeout))
517520
wrappedHandler := middleware(testHandler)
518521

519522
// Initialize request (no session ID) - discovery should happen

0 commit comments

Comments
 (0)