|
| 1 | +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package sessionmanager |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + "net/http" |
| 9 | + "net/http/httptest" |
| 10 | + "testing" |
| 11 | + "time" |
| 12 | + |
| 13 | + "github.com/alicebob/miniredis/v2" |
| 14 | + mcpmcp "github.com/mark3labs/mcp-go/mcp" |
| 15 | + mcpserver "github.com/mark3labs/mcp-go/server" |
| 16 | + "github.com/stretchr/testify/assert" |
| 17 | + "github.com/stretchr/testify/require" |
| 18 | + |
| 19 | + "github.com/stacklok/toolhive/pkg/auth" |
| 20 | + transportsession "github.com/stacklok/toolhive/pkg/transport/session" |
| 21 | + "github.com/stacklok/toolhive/pkg/vmcp" |
| 22 | + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" |
| 23 | + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" |
| 24 | + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" |
| 25 | + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" |
| 26 | + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" |
| 27 | +) |
| 28 | + |
// hmacSecret is a fixed 32-byte secret used across all integration tests.
// NOTE(review): the literal is exactly 32 bytes — presumably WithHMACSecret
// enforces a minimum key length; confirm against the session factory.
var hmacSecret = []byte("test-hmac-secret-exactly-32bytes")
| 31 | + |
| 32 | +// --------------------------------------------------------------------------- |
| 33 | +// Helpers |
| 34 | +// --------------------------------------------------------------------------- |
| 35 | + |
| 36 | +// newUnauthenticatedAuthRegistry builds an OutgoingAuthRegistry with only the |
| 37 | +// unauthenticated strategy registered — suitable for tests whose backend MCP |
| 38 | +// servers require no auth. |
| 39 | +func newUnauthenticatedAuthRegistry(t *testing.T) vmcpauth.OutgoingAuthRegistry { |
| 40 | + t.Helper() |
| 41 | + reg := vmcpauth.NewDefaultOutgoingAuthRegistry() |
| 42 | + require.NoError(t, reg.RegisterStrategy(authtypes.StrategyTypeUnauthenticated, strategies.NewUnauthenticatedStrategy())) |
| 43 | + return reg |
| 44 | +} |
| 45 | + |
| 46 | +// newSharedRedisStorage creates a RedisSessionDataStorage pointing at mr. |
| 47 | +// The storage is closed via t.Cleanup. |
| 48 | +func newSharedRedisStorage(t *testing.T, mr *miniredis.Miniredis) transportsession.DataStorage { |
| 49 | + t.Helper() |
| 50 | + storage, err := transportsession.NewRedisSessionDataStorage( |
| 51 | + context.Background(), |
| 52 | + transportsession.RedisConfig{ |
| 53 | + Addr: mr.Addr(), |
| 54 | + KeyPrefix: "test:vmcp:session:", |
| 55 | + }, |
| 56 | + time.Hour, |
| 57 | + ) |
| 58 | + require.NoError(t, err) |
| 59 | + t.Cleanup(func() { _ = storage.Close() }) |
| 60 | + return storage |
| 61 | +} |
| 62 | + |
| 63 | +// newTestManagerWithSharedStorage creates a Manager backed by the given |
| 64 | +// DataStorage, a real session factory with the package-level hmacSecret, and |
| 65 | +// an ImmutableRegistry containing backends. Cleanup is registered via |
| 66 | +// t.Cleanup. |
| 67 | +func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.DataStorage, backends []*vmcp.Backend) *Manager { |
| 68 | + t.Helper() |
| 69 | + backendList := make([]vmcp.Backend, len(backends)) |
| 70 | + for i, b := range backends { |
| 71 | + backendList[i] = *b |
| 72 | + } |
| 73 | + registry := vmcp.NewImmutableRegistry(backendList) |
| 74 | + factory := vmcpsession.NewSessionFactory( |
| 75 | + newUnauthenticatedAuthRegistry(t), |
| 76 | + vmcpsession.WithHMACSecret(hmacSecret), |
| 77 | + ) |
| 78 | + sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) |
| 79 | + require.NoError(t, err) |
| 80 | + t.Cleanup(func() { require.NoError(t, cleanup(context.Background())) }) |
| 81 | + return sm |
| 82 | +} |
| 83 | + |
| 84 | +// createSession runs the two-phase Generate + CreateSession flow. |
| 85 | +// identity may be nil for anonymous sessions. |
| 86 | +// Returns the assigned session ID. |
| 87 | +func createSession(t *testing.T, sm *Manager, identity *auth.Identity) string { |
| 88 | + t.Helper() |
| 89 | + sessionID := sm.Generate() |
| 90 | + require.NotEmpty(t, sessionID) |
| 91 | + ctx := context.Background() |
| 92 | + if identity != nil { |
| 93 | + ctx = auth.WithIdentity(ctx, identity) |
| 94 | + } |
| 95 | + _, err := sm.CreateSession(ctx, sessionID) |
| 96 | + require.NoError(t, err) |
| 97 | + return sessionID |
| 98 | +} |
| 99 | + |
| 100 | +// startMCPBackend starts an in-process streamable-HTTP MCP server that |
| 101 | +// exposes a single tool named toolName (which echoes its "input" argument). |
| 102 | +// The server is shut down when t completes. |
| 103 | +// Returns a *vmcp.Backend pointing at the server. |
| 104 | +func startMCPBackend(t *testing.T, backendID, toolName string) *vmcp.Backend { |
| 105 | + t.Helper() |
| 106 | + mcpSrv := mcpserver.NewMCPServer(backendID, "1.0.0") |
| 107 | + mcpSrv.AddTool( |
| 108 | + mcpmcp.NewTool(toolName, |
| 109 | + mcpmcp.WithDescription("Echoes the input argument"), |
| 110 | + mcpmcp.WithString("input", mcpmcp.Required()), |
| 111 | + ), |
| 112 | + func(_ context.Context, req mcpmcp.CallToolRequest) (*mcpmcp.CallToolResult, error) { |
| 113 | + args, _ := req.Params.Arguments.(map[string]any) |
| 114 | + input, _ := args["input"].(string) |
| 115 | + return &mcpmcp.CallToolResult{ |
| 116 | + Content: []mcpmcp.Content{mcpmcp.NewTextContent(input)}, |
| 117 | + }, nil |
| 118 | + }, |
| 119 | + ) |
| 120 | + streamableSrv := mcpserver.NewStreamableHTTPServer(mcpSrv) |
| 121 | + mux := http.NewServeMux() |
| 122 | + mux.Handle("/mcp", streamableSrv) |
| 123 | + ts := httptest.NewServer(mux) |
| 124 | + t.Cleanup(ts.Close) |
| 125 | + return &vmcp.Backend{ |
| 126 | + ID: backendID, |
| 127 | + Name: backendID, |
| 128 | + BaseURL: ts.URL + "/mcp", |
| 129 | + TransportType: "streamable-http", |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +// startStoppableMCPBackend is like startMCPBackend but returns a stop |
| 134 | +// function instead of registering t.Cleanup. The caller is responsible for |
| 135 | +// calling stop (e.g. to simulate a backend going away mid-test). |
| 136 | +func startStoppableMCPBackend(t *testing.T, backendID, toolName string) (*vmcp.Backend, func()) { |
| 137 | + t.Helper() |
| 138 | + mcpSrv := mcpserver.NewMCPServer(backendID, "1.0.0") |
| 139 | + mcpSrv.AddTool( |
| 140 | + mcpmcp.NewTool(toolName, |
| 141 | + mcpmcp.WithDescription("Echoes the input argument"), |
| 142 | + mcpmcp.WithString("input", mcpmcp.Required()), |
| 143 | + ), |
| 144 | + func(_ context.Context, req mcpmcp.CallToolRequest) (*mcpmcp.CallToolResult, error) { |
| 145 | + args, _ := req.Params.Arguments.(map[string]any) |
| 146 | + input, _ := args["input"].(string) |
| 147 | + return &mcpmcp.CallToolResult{ |
| 148 | + Content: []mcpmcp.Content{mcpmcp.NewTextContent(input)}, |
| 149 | + }, nil |
| 150 | + }, |
| 151 | + ) |
| 152 | + streamableSrv := mcpserver.NewStreamableHTTPServer(mcpSrv) |
| 153 | + mux := http.NewServeMux() |
| 154 | + mux.Handle("/mcp", streamableSrv) |
| 155 | + ts := httptest.NewServer(mux) |
| 156 | + return &vmcp.Backend{ |
| 157 | + ID: backendID, |
| 158 | + Name: backendID, |
| 159 | + BaseURL: ts.URL + "/mcp", |
| 160 | + TransportType: "streamable-http", |
| 161 | + }, ts.Close |
| 162 | +} |
| 163 | + |
| 164 | +// --------------------------------------------------------------------------- |
| 165 | +// AC1: Cross-pod session reconstruction |
| 166 | +// --------------------------------------------------------------------------- |
| 167 | + |
| 168 | +// TestHorizontalScaling_CrossPodReconstruction verifies that a session |
| 169 | +// created on "pod A" (Manager A) can be reconstructed on "pod B" (Manager B) |
| 170 | +// via GetMultiSession → cache miss → RestoreSession from Redis. |
| 171 | +func TestHorizontalScaling_CrossPodReconstruction(t *testing.T) { |
| 172 | + t.Parallel() |
| 173 | + |
| 174 | + mr := miniredis.RunT(t) |
| 175 | + storage := newSharedRedisStorage(t, mr) |
| 176 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 177 | + |
| 178 | + // Pod A: create a session; it is stored in Redis and cached locally in smA. |
| 179 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 180 | + sessionID := createSession(t, smA, nil) |
| 181 | + |
| 182 | + // Pod B: fresh Manager, same Redis storage — session is NOT in local cache. |
| 183 | + smB := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 184 | + |
| 185 | + // GetMultiSession triggers cache miss → loadSession → RestoreSession from Redis. |
| 186 | + sess, ok := smB.GetMultiSession(sessionID) |
| 187 | + require.True(t, ok, "pod B must reconstruct the session from Redis on cache miss") |
| 188 | + require.NotNil(t, sess) |
| 189 | + |
| 190 | + // The restored session must have reconnected to the backend and discovered tools. |
| 191 | + assert.NotEmpty(t, sess.Tools(), "restored session must have the backend's tools") |
| 192 | + assert.Equal(t, "echo", sess.Tools()[0].Name) |
| 193 | +} |
| 194 | + |
| 195 | +// --------------------------------------------------------------------------- |
| 196 | +// AC2: Cross-pod hijack prevention |
| 197 | +// --------------------------------------------------------------------------- |
| 198 | + |
| 199 | +// TestHorizontalScaling_CrossPodHijackPrevention verifies that: |
| 200 | +// - A session bound to alice on pod A can be reconstructed on pod B. |
| 201 | +// - After restoration, a wrong-token caller is rejected (ErrUnauthorizedCaller). |
| 202 | +// - A nil caller is rejected (ErrNilCaller). |
| 203 | +// - The original caller (correct token) is not rejected at the auth layer. |
| 204 | +func TestHorizontalScaling_CrossPodHijackPrevention(t *testing.T) { |
| 205 | + t.Parallel() |
| 206 | + |
| 207 | + mr := miniredis.RunT(t) |
| 208 | + storage := newSharedRedisStorage(t, mr) |
| 209 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 210 | + |
| 211 | + identity := &auth.Identity{ |
| 212 | + PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, |
| 213 | + Token: "alice-bearer-token", |
| 214 | + } |
| 215 | + wrongCaller := &auth.Identity{ |
| 216 | + PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, |
| 217 | + Token: "eve-bearer-token", |
| 218 | + } |
| 219 | + |
| 220 | + // Pod A: create session bound to alice. |
| 221 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 222 | + sessionID := createSession(t, smA, identity) |
| 223 | + |
| 224 | + // Pod B: restore from Redis. |
| 225 | + smB := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 226 | + sess, ok := smB.GetMultiSession(sessionID) |
| 227 | + require.True(t, ok, "session must be restorable on pod B") |
| 228 | + require.NotNil(t, sess) |
| 229 | + |
| 230 | + ctx := context.Background() |
| 231 | + |
| 232 | + // Wrong caller must be rejected before any backend routing. |
| 233 | + _, err := sess.CallTool(ctx, wrongCaller, "echo", map[string]any{"input": "hi"}, nil) |
| 234 | + assert.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, "wrong token must be rejected") |
| 235 | + |
| 236 | + // Nil caller must be rejected. |
| 237 | + _, err = sess.CallTool(ctx, nil, "echo", map[string]any{"input": "hi"}, nil) |
| 238 | + assert.ErrorIs(t, err, sessiontypes.ErrNilCaller, "nil caller must be rejected") |
| 239 | + |
| 240 | + // Original caller must pass auth and successfully route to the backend. |
| 241 | + // The backend is still running, so the call must complete without error. |
| 242 | + _, err = sess.CallTool(ctx, identity, "echo", map[string]any{"input": "hi"}, nil) |
| 243 | + require.NoError(t, err, "correct caller must be able to invoke the tool after restore") |
| 244 | + assert.NotErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, "correct token must pass auth") |
| 245 | + assert.NotErrorIs(t, err, sessiontypes.ErrNilCaller, "correct token must pass nil-caller check") |
| 246 | +} |
| 247 | + |
| 248 | +// --------------------------------------------------------------------------- |
| 249 | +// AC3 is intentionally omitted: LRU eviction (RC-10, issue #4221) was dropped |
| 250 | +// in favour of TTL-based Redis eviction. |
| 251 | +// --------------------------------------------------------------------------- |
| 252 | + |
| 253 | +// --------------------------------------------------------------------------- |
| 254 | +// AC4: All backends fail during RestoreSession → empty routing table |
| 255 | +// --------------------------------------------------------------------------- |
| 256 | + |
| 257 | +// TestHorizontalScaling_AllBackendsFailOnRestore verifies that when all |
| 258 | +// backends are unreachable at restore time, GetMultiSession still returns a |
| 259 | +// valid (non-nil) session with an empty routing table — consistent with the |
| 260 | +// makeSession partial-failure behaviour documented in the spec. |
| 261 | +func TestHorizontalScaling_AllBackendsFailOnRestore(t *testing.T) { |
| 262 | + t.Parallel() |
| 263 | + |
| 264 | + mr := miniredis.RunT(t) |
| 265 | + storage := newSharedRedisStorage(t, mr) |
| 266 | + |
| 267 | + // Use a stoppable backend so we can shut it down mid-test. |
| 268 | + backend, stopBackend := startStoppableMCPBackend(t, "backend-alpha", "echo") |
| 269 | + |
| 270 | + sm := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 271 | + sessionID := createSession(t, sm, nil) |
| 272 | + |
| 273 | + // Stop the backend — RestoreSession will be unable to reconnect. |
| 274 | + stopBackend() |
| 275 | + |
| 276 | + // Evict from the local cache so the next Get takes the restore path. |
| 277 | + sm.sessions.Delete(sessionID) |
| 278 | + |
| 279 | + // GetMultiSession must return a session (not false/nil) even though the |
| 280 | + // backend is unreachable — the session comes back with an empty tool list. |
| 281 | + sess, ok := sm.GetMultiSession(sessionID) |
| 282 | + require.True(t, ok, "GetMultiSession must return ok=true even when backends are unreachable") |
| 283 | + require.NotNil(t, sess) |
| 284 | + assert.Empty(t, sess.Tools(), "routing table must be empty when no backend reconnected") |
| 285 | +} |
| 286 | + |
| 287 | +// --------------------------------------------------------------------------- |
| 288 | +// AC5: RC-16 backend expiry — NotifyBackendExpired removes metadata; |
| 289 | +// subsequent RestoreSession skips the expired backend. |
| 290 | +// --------------------------------------------------------------------------- |
| 291 | + |
| 292 | +// TestHorizontalScaling_BackendExpiry_SkipsExpiredOnRestore verifies that |
| 293 | +// after NotifyBackendExpired removes a backend from Redis metadata, a |
| 294 | +// subsequent RestoreSession on a different pod only connects to the remaining |
| 295 | +// backend and does not include the expired backend's tools. |
| 296 | +func TestHorizontalScaling_BackendExpiry_SkipsExpiredOnRestore(t *testing.T) { |
| 297 | + t.Parallel() |
| 298 | + |
| 299 | + mr := miniredis.RunT(t) |
| 300 | + storage := newSharedRedisStorage(t, mr) |
| 301 | + |
| 302 | + // Two backends with distinct tool names so we can tell them apart. |
| 303 | + backendA := startMCPBackend(t, "backend-alpha", "tool-alpha") |
| 304 | + backendB := startMCPBackend(t, "backend-beta", "tool-beta") |
| 305 | + |
| 306 | + // Pod A: create session connected to both backends. |
| 307 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backendA, backendB}) |
| 308 | + sessionID := createSession(t, smA, nil) |
| 309 | + |
| 310 | + // Verify session A has tools from both backends before expiry. |
| 311 | + sessA, ok := smA.GetMultiSession(sessionID) |
| 312 | + require.True(t, ok) |
| 313 | + toolNames := make(map[string]bool) |
| 314 | + for _, tool := range sessA.Tools() { |
| 315 | + toolNames[tool.Name] = true |
| 316 | + } |
| 317 | + require.True(t, toolNames["tool-alpha"], "session A must have tool-alpha before expiry") |
| 318 | + require.True(t, toolNames["tool-beta"], "session A must have tool-beta before expiry") |
| 319 | + |
| 320 | + // NotifyBackendExpired updates Redis to remove backend-beta and evicts from pod A's cache. |
| 321 | + smA.NotifyBackendExpired(sessionID, backendB.ID) |
| 322 | + |
| 323 | + // Pod C: fresh Manager, same storage and both backends in registry. |
| 324 | + // (backendB is still running — we're testing that RestoreSession filters |
| 325 | + // it out based on the updated Redis metadata, not because it's unreachable.) |
| 326 | + smC := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backendA, backendB}) |
| 327 | + sessC, ok := smC.GetMultiSession(sessionID) |
| 328 | + require.True(t, ok, "session must be restorable after NotifyBackendExpired") |
| 329 | + require.NotNil(t, sessC) |
| 330 | + |
| 331 | + // Restored session must only have tool-alpha; tool-beta was filtered out. |
| 332 | + restoredTools := make(map[string]bool) |
| 333 | + for _, tool := range sessC.Tools() { |
| 334 | + restoredTools[tool.Name] = true |
| 335 | + } |
| 336 | + assert.True(t, restoredTools["tool-alpha"], "restored session must have tool-alpha") |
| 337 | + assert.False(t, restoredTools["tool-beta"], "restored session must NOT have tool-beta after expiry") |
| 338 | +} |
| 339 | + |
| 340 | +// --------------------------------------------------------------------------- |
| 341 | +// AC6: In-memory-only mode (no Redis) — no cross-pod sharing |
| 342 | +// --------------------------------------------------------------------------- |
| 343 | + |
| 344 | +// TestHorizontalScaling_InMemoryOnlyMode verifies that when Redis is not |
| 345 | +// configured (LocalSessionDataStorage), sessions are not visible to a second |
| 346 | +// Manager instance, and single-pod usage continues to work correctly. |
| 347 | +func TestHorizontalScaling_InMemoryOnlyMode(t *testing.T) { |
| 348 | + t.Parallel() |
| 349 | + |
| 350 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 351 | + |
| 352 | + newLocalStorage := func(t *testing.T) transportsession.DataStorage { |
| 353 | + t.Helper() |
| 354 | + s, err := transportsession.NewLocalSessionDataStorage(time.Hour) |
| 355 | + require.NoError(t, err) |
| 356 | + t.Cleanup(func() { _ = s.Close() }) |
| 357 | + return s |
| 358 | + } |
| 359 | + |
| 360 | + // Pod A and pod B each have their own local storage — no sharing. |
| 361 | + storageA := newLocalStorage(t) |
| 362 | + storageB := newLocalStorage(t) |
| 363 | + |
| 364 | + smA := newTestManagerWithSharedStorage(t, storageA, []*vmcp.Backend{backend}) |
| 365 | + smB := newTestManagerWithSharedStorage(t, storageB, []*vmcp.Backend{backend}) |
| 366 | + |
| 367 | + sessionID := createSession(t, smA, nil) |
| 368 | + |
| 369 | + // Pod B must not be able to see pod A's session. |
| 370 | + _, ok := smB.GetMultiSession(sessionID) |
| 371 | + assert.False(t, ok, "in-memory-only: pod B must not see pod A's session") |
| 372 | + |
| 373 | + // Single-pod usage on pod A must still work. |
| 374 | + sess, ok := smA.GetMultiSession(sessionID) |
| 375 | + require.True(t, ok, "pod A must still serve its own session") |
| 376 | + require.NotNil(t, sess) |
| 377 | + assert.NotEmpty(t, sess.Tools(), "session on pod A must have tools") |
| 378 | +} |
0 commit comments