|
| 1 | +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package sessionmanager |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + "net/http" |
| 9 | + "net/http/httptest" |
| 10 | + "sync" |
| 11 | + "testing" |
| 12 | + "time" |
| 13 | + |
| 14 | + "github.com/alicebob/miniredis/v2" |
| 15 | + mcpmcp "github.com/mark3labs/mcp-go/mcp" |
| 16 | + mcpserver "github.com/mark3labs/mcp-go/server" |
| 17 | + "github.com/stretchr/testify/assert" |
| 18 | + "github.com/stretchr/testify/require" |
| 19 | + |
| 20 | + "github.com/stacklok/toolhive/pkg/auth" |
| 21 | + transportsession "github.com/stacklok/toolhive/pkg/transport/session" |
| 22 | + "github.com/stacklok/toolhive/pkg/vmcp" |
| 23 | + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" |
| 24 | + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" |
| 25 | + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" |
| 26 | + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" |
| 27 | + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" |
| 28 | +) |
| 29 | + |
// hmacSecret is a fixed 32-byte secret used across all integration tests.
// The literal is exactly 32 ASCII bytes; presumably the session factory's
// HMAC key must be this length — confirm against WithHMACSecret before
// changing it.
var hmacSecret = []byte("test-hmac-secret-32bytes-exactly")
| 32 | + |
| 33 | +// --------------------------------------------------------------------------- |
| 34 | +// Helpers |
| 35 | +// --------------------------------------------------------------------------- |
| 36 | + |
| 37 | +// newUnauthenticatedAuthRegistry builds an OutgoingAuthRegistry with only the |
| 38 | +// unauthenticated strategy registered — suitable for tests whose backend MCP |
| 39 | +// servers require no auth. |
| 40 | +func newUnauthenticatedAuthRegistry(t *testing.T) vmcpauth.OutgoingAuthRegistry { |
| 41 | + t.Helper() |
| 42 | + reg := vmcpauth.NewDefaultOutgoingAuthRegistry() |
| 43 | + require.NoError(t, reg.RegisterStrategy(authtypes.StrategyTypeUnauthenticated, strategies.NewUnauthenticatedStrategy())) |
| 44 | + return reg |
| 45 | +} |
| 46 | + |
| 47 | +// newSharedRedisStorage creates a RedisSessionDataStorage pointing at mr. |
| 48 | +// The storage is closed via t.Cleanup. |
| 49 | +func newSharedRedisStorage(t *testing.T, mr *miniredis.Miniredis) transportsession.DataStorage { |
| 50 | + t.Helper() |
| 51 | + storage, err := transportsession.NewRedisSessionDataStorage( |
| 52 | + context.Background(), |
| 53 | + transportsession.RedisConfig{ |
| 54 | + Addr: mr.Addr(), |
| 55 | + KeyPrefix: "test:vmcp:session:", |
| 56 | + }, |
| 57 | + time.Hour, |
| 58 | + ) |
| 59 | + require.NoError(t, err) |
| 60 | + t.Cleanup(func() { _ = storage.Close() }) |
| 61 | + return storage |
| 62 | +} |
| 63 | + |
| 64 | +// newTestManagerWithSharedStorage creates a Manager backed by the given |
| 65 | +// DataStorage, a real session factory with the package-level hmacSecret, and |
| 66 | +// an ImmutableRegistry containing backends. Cleanup is registered via |
| 67 | +// t.Cleanup. |
| 68 | +func newTestManagerWithSharedStorage(t *testing.T, storage transportsession.DataStorage, backends []*vmcp.Backend) *Manager { |
| 69 | + t.Helper() |
| 70 | + backendList := make([]vmcp.Backend, len(backends)) |
| 71 | + for i, b := range backends { |
| 72 | + backendList[i] = *b |
| 73 | + } |
| 74 | + registry := vmcp.NewImmutableRegistry(backendList) |
| 75 | + factory := vmcpsession.NewSessionFactory( |
| 76 | + newUnauthenticatedAuthRegistry(t), |
| 77 | + vmcpsession.WithHMACSecret(hmacSecret), |
| 78 | + ) |
| 79 | + sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) |
| 80 | + require.NoError(t, err) |
| 81 | + t.Cleanup(func() { require.NoError(t, cleanup(context.Background())) }) |
| 82 | + return sm |
| 83 | +} |
| 84 | + |
| 85 | +// createSession runs the two-phase Generate + CreateSession flow. |
| 86 | +// identity may be nil for anonymous sessions. |
| 87 | +// Returns the assigned session ID. |
| 88 | +func createSession(t *testing.T, sm *Manager, identity *auth.Identity) string { |
| 89 | + t.Helper() |
| 90 | + sessionID := sm.Generate() |
| 91 | + require.NotEmpty(t, sessionID) |
| 92 | + ctx := context.Background() |
| 93 | + if identity != nil { |
| 94 | + ctx = auth.WithIdentity(ctx, identity) |
| 95 | + } |
| 96 | + _, err := sm.CreateSession(ctx, sessionID) |
| 97 | + require.NoError(t, err) |
| 98 | + return sessionID |
| 99 | +} |
| 100 | + |
| 101 | +// startMCPBackend starts an in-process streamable-HTTP MCP server that |
| 102 | +// exposes a single tool named toolName (which echoes its "input" argument). |
| 103 | +// The server is shut down when t completes. |
| 104 | +// Returns a *vmcp.Backend pointing at the server. |
| 105 | +func startMCPBackend(t *testing.T, backendID, toolName string) *vmcp.Backend { |
| 106 | + t.Helper() |
| 107 | + mcpSrv := mcpserver.NewMCPServer(backendID, "1.0.0") |
| 108 | + mcpSrv.AddTool( |
| 109 | + mcpmcp.NewTool(toolName, |
| 110 | + mcpmcp.WithDescription("Echoes the input argument"), |
| 111 | + mcpmcp.WithString("input", mcpmcp.Required()), |
| 112 | + ), |
| 113 | + func(_ context.Context, req mcpmcp.CallToolRequest) (*mcpmcp.CallToolResult, error) { |
| 114 | + args, _ := req.Params.Arguments.(map[string]any) |
| 115 | + input, _ := args["input"].(string) |
| 116 | + return &mcpmcp.CallToolResult{ |
| 117 | + Content: []mcpmcp.Content{mcpmcp.NewTextContent(input)}, |
| 118 | + }, nil |
| 119 | + }, |
| 120 | + ) |
| 121 | + streamableSrv := mcpserver.NewStreamableHTTPServer(mcpSrv) |
| 122 | + mux := http.NewServeMux() |
| 123 | + mux.Handle("/mcp", streamableSrv) |
| 124 | + ts := httptest.NewServer(mux) |
| 125 | + t.Cleanup(ts.Close) |
| 126 | + return &vmcp.Backend{ |
| 127 | + ID: backendID, |
| 128 | + Name: backendID, |
| 129 | + BaseURL: ts.URL + "/mcp", |
| 130 | + TransportType: "streamable-http", |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +// startStoppableMCPBackend is like startMCPBackend but also returns a stop |
| 135 | +// function so the caller can shut the backend down mid-test (e.g. to simulate |
| 136 | +// a backend going away). The stop function is idempotent and is also |
| 137 | +// registered with t.Cleanup so the server is always closed even if the test |
| 138 | +// fails before the caller invokes stop. |
| 139 | +func startStoppableMCPBackend(t *testing.T, backendID, toolName string) (*vmcp.Backend, func()) { |
| 140 | + t.Helper() |
| 141 | + mcpSrv := mcpserver.NewMCPServer(backendID, "1.0.0") |
| 142 | + mcpSrv.AddTool( |
| 143 | + mcpmcp.NewTool(toolName, |
| 144 | + mcpmcp.WithDescription("Echoes the input argument"), |
| 145 | + mcpmcp.WithString("input", mcpmcp.Required()), |
| 146 | + ), |
| 147 | + func(_ context.Context, req mcpmcp.CallToolRequest) (*mcpmcp.CallToolResult, error) { |
| 148 | + args, _ := req.Params.Arguments.(map[string]any) |
| 149 | + input, _ := args["input"].(string) |
| 150 | + return &mcpmcp.CallToolResult{ |
| 151 | + Content: []mcpmcp.Content{mcpmcp.NewTextContent(input)}, |
| 152 | + }, nil |
| 153 | + }, |
| 154 | + ) |
| 155 | + streamableSrv := mcpserver.NewStreamableHTTPServer(mcpSrv) |
| 156 | + mux := http.NewServeMux() |
| 157 | + mux.Handle("/mcp", streamableSrv) |
| 158 | + ts := httptest.NewServer(mux) |
| 159 | + var once sync.Once |
| 160 | + stop := func() { once.Do(ts.Close) } |
| 161 | + t.Cleanup(stop) |
| 162 | + return &vmcp.Backend{ |
| 163 | + ID: backendID, |
| 164 | + Name: backendID, |
| 165 | + BaseURL: ts.URL + "/mcp", |
| 166 | + TransportType: "streamable-http", |
| 167 | + }, stop |
| 168 | +} |
| 169 | + |
| 170 | +// --------------------------------------------------------------------------- |
| 171 | +// AC1: Cross-pod session reconstruction |
| 172 | +// --------------------------------------------------------------------------- |
| 173 | + |
| 174 | +// TestHorizontalScaling_CrossPodReconstruction verifies that a session |
| 175 | +// created on "pod A" (Manager A) can be reconstructed on "pod B" (Manager B) |
| 176 | +// via GetMultiSession → cache miss → RestoreSession from Redis. |
| 177 | +func TestHorizontalScaling_CrossPodReconstruction(t *testing.T) { |
| 178 | + t.Parallel() |
| 179 | + |
| 180 | + mr := miniredis.RunT(t) |
| 181 | + storage := newSharedRedisStorage(t, mr) |
| 182 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 183 | + |
| 184 | + // Pod A: create a session; it is stored in Redis and cached locally in smA. |
| 185 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 186 | + sessionID := createSession(t, smA, nil) |
| 187 | + |
| 188 | + // Pod B: fresh Manager, same Redis storage — session is NOT in local cache. |
| 189 | + smB := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 190 | + |
| 191 | + // GetMultiSession triggers cache miss → loadSession → RestoreSession from Redis. |
| 192 | + sess, ok := smB.GetMultiSession(sessionID) |
| 193 | + require.True(t, ok, "pod B must reconstruct the session from Redis on cache miss") |
| 194 | + require.NotNil(t, sess) |
| 195 | + |
| 196 | + // The restored session must have reconnected to the backend and discovered tools. |
| 197 | + require.NotEmpty(t, sess.Tools(), "restored session must have the backend's tools") |
| 198 | + assert.Equal(t, "echo", sess.Tools()[0].Name) |
| 199 | +} |
| 200 | + |
| 201 | +// --------------------------------------------------------------------------- |
| 202 | +// AC2: Cross-pod hijack prevention |
| 203 | +// --------------------------------------------------------------------------- |
| 204 | + |
| 205 | +// TestHorizontalScaling_CrossPodHijackPrevention verifies that: |
| 206 | +// - A session bound to alice on pod A can be reconstructed on pod B. |
| 207 | +// - After restoration, a wrong-token caller is rejected (ErrUnauthorizedCaller). |
| 208 | +// - A nil caller is rejected (ErrNilCaller). |
| 209 | +// - The original caller (correct token) is not rejected at the auth layer. |
| 210 | +func TestHorizontalScaling_CrossPodHijackPrevention(t *testing.T) { |
| 211 | + t.Parallel() |
| 212 | + |
| 213 | + mr := miniredis.RunT(t) |
| 214 | + storage := newSharedRedisStorage(t, mr) |
| 215 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 216 | + |
| 217 | + identity := &auth.Identity{ |
| 218 | + PrincipalInfo: auth.PrincipalInfo{Subject: "alice"}, |
| 219 | + Token: "alice-bearer-token", |
| 220 | + } |
| 221 | + wrongCaller := &auth.Identity{ |
| 222 | + PrincipalInfo: auth.PrincipalInfo{Subject: "eve"}, |
| 223 | + Token: "eve-bearer-token", |
| 224 | + } |
| 225 | + |
| 226 | + // Pod A: create session bound to alice. |
| 227 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 228 | + sessionID := createSession(t, smA, identity) |
| 229 | + |
| 230 | + // Pod B: restore from Redis. |
| 231 | + smB := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 232 | + sess, ok := smB.GetMultiSession(sessionID) |
| 233 | + require.True(t, ok, "session must be restorable on pod B") |
| 234 | + require.NotNil(t, sess) |
| 235 | + |
| 236 | + ctx := context.Background() |
| 237 | + |
| 238 | + // Wrong caller must be rejected before any backend routing. |
| 239 | + _, err := sess.CallTool(ctx, wrongCaller, "echo", map[string]any{"input": "hi"}, nil) |
| 240 | + assert.ErrorIs(t, err, sessiontypes.ErrUnauthorizedCaller, "wrong token must be rejected") |
| 241 | + |
| 242 | + // Nil caller must be rejected. |
| 243 | + _, err = sess.CallTool(ctx, nil, "echo", map[string]any{"input": "hi"}, nil) |
| 244 | + assert.ErrorIs(t, err, sessiontypes.ErrNilCaller, "nil caller must be rejected") |
| 245 | + |
| 246 | + // Original caller must pass auth and successfully route to the backend. |
| 247 | + // The backend is still running, so the call must complete without error. |
| 248 | + _, err = sess.CallTool(ctx, identity, "echo", map[string]any{"input": "hi"}, nil) |
| 249 | + require.NoError(t, err, "correct caller must be able to invoke the tool after restore") |
| 250 | +} |
| 251 | + |
| 252 | +// --------------------------------------------------------------------------- |
| 253 | +// AC3 is intentionally omitted: LRU eviction (RC-10, issue #4221) was dropped |
| 254 | +// in favour of TTL-based Redis eviction. |
| 255 | +// --------------------------------------------------------------------------- |
| 256 | + |
| 257 | +// --------------------------------------------------------------------------- |
| 258 | +// AC4: All backends fail during RestoreSession → empty routing table |
| 259 | +// --------------------------------------------------------------------------- |
| 260 | + |
| 261 | +// TestHorizontalScaling_AllBackendsFailOnRestore verifies that when all |
| 262 | +// backends are unreachable at restore time, GetMultiSession still returns a |
| 263 | +// valid (non-nil) session with an empty routing table — consistent with the |
| 264 | +// makeSession partial-failure behaviour documented in the spec. |
| 265 | +func TestHorizontalScaling_AllBackendsFailOnRestore(t *testing.T) { |
| 266 | + t.Parallel() |
| 267 | + |
| 268 | + mr := miniredis.RunT(t) |
| 269 | + storage := newSharedRedisStorage(t, mr) |
| 270 | + |
| 271 | + // Use a stoppable backend so we can shut it down mid-test. |
| 272 | + backend, stopBackend := startStoppableMCPBackend(t, "backend-alpha", "echo") |
| 273 | + |
| 274 | + sm := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backend}) |
| 275 | + sessionID := createSession(t, sm, nil) |
| 276 | + |
| 277 | + // Stop the backend — RestoreSession will be unable to reconnect. |
| 278 | + stopBackend() |
| 279 | + |
| 280 | + // Evict from the local cache so the next Get takes the restore path. |
| 281 | + sm.sessions.Delete(sessionID) |
| 282 | + |
| 283 | + // GetMultiSession must return a session (not false/nil) even though the |
| 284 | + // backend is unreachable — the session comes back with an empty tool list. |
| 285 | + sess, ok := sm.GetMultiSession(sessionID) |
| 286 | + require.True(t, ok, "GetMultiSession must return ok=true even when backends are unreachable") |
| 287 | + require.NotNil(t, sess) |
| 288 | + assert.Empty(t, sess.Tools(), "routing table must be empty when no backend reconnected") |
| 289 | +} |
| 290 | + |
| 291 | +// --------------------------------------------------------------------------- |
| 292 | +// AC5: RC-16 backend expiry — NotifyBackendExpired removes metadata; |
| 293 | +// subsequent RestoreSession skips the expired backend. |
| 294 | +// --------------------------------------------------------------------------- |
| 295 | + |
| 296 | +// TestHorizontalScaling_BackendExpiry_SkipsExpiredOnRestore verifies that |
| 297 | +// after NotifyBackendExpired removes a backend from Redis metadata, a |
| 298 | +// subsequent RestoreSession on a different pod only connects to the remaining |
| 299 | +// backend and does not include the expired backend's tools. |
| 300 | +func TestHorizontalScaling_BackendExpiry_SkipsExpiredOnRestore(t *testing.T) { |
| 301 | + t.Parallel() |
| 302 | + |
| 303 | + mr := miniredis.RunT(t) |
| 304 | + storage := newSharedRedisStorage(t, mr) |
| 305 | + |
| 306 | + // Two backends with distinct tool names so we can tell them apart. |
| 307 | + backendA := startMCPBackend(t, "backend-alpha", "tool-alpha") |
| 308 | + backendB := startMCPBackend(t, "backend-beta", "tool-beta") |
| 309 | + |
| 310 | + // Pod A: create session connected to both backends. |
| 311 | + smA := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backendA, backendB}) |
| 312 | + sessionID := createSession(t, smA, nil) |
| 313 | + |
| 314 | + // Verify session A has tools from both backends before expiry. |
| 315 | + sessA, ok := smA.GetMultiSession(sessionID) |
| 316 | + require.True(t, ok) |
| 317 | + toolNames := make(map[string]bool) |
| 318 | + for _, tool := range sessA.Tools() { |
| 319 | + toolNames[tool.Name] = true |
| 320 | + } |
| 321 | + require.True(t, toolNames["tool-alpha"], "session A must have tool-alpha before expiry") |
| 322 | + require.True(t, toolNames["tool-beta"], "session A must have tool-beta before expiry") |
| 323 | + |
| 324 | + // NotifyBackendExpired updates Redis to remove backend-beta and evicts from pod A's cache. |
| 325 | + smA.NotifyBackendExpired(sessionID, backendB.ID) |
| 326 | + |
| 327 | + // Pod C: fresh Manager, same storage and both backends in registry. |
| 328 | + // (backendB is still running — we're testing that RestoreSession filters |
| 329 | + // it out based on the updated Redis metadata, not because it's unreachable.) |
| 330 | + smC := newTestManagerWithSharedStorage(t, storage, []*vmcp.Backend{backendA, backendB}) |
| 331 | + sessC, ok := smC.GetMultiSession(sessionID) |
| 332 | + require.True(t, ok, "session must be restorable after NotifyBackendExpired") |
| 333 | + require.NotNil(t, sessC) |
| 334 | + |
| 335 | + // Restored session must only have tool-alpha; tool-beta was filtered out. |
| 336 | + restoredTools := make(map[string]bool) |
| 337 | + for _, tool := range sessC.Tools() { |
| 338 | + restoredTools[tool.Name] = true |
| 339 | + } |
| 340 | + assert.True(t, restoredTools["tool-alpha"], "restored session must have tool-alpha") |
| 341 | + assert.False(t, restoredTools["tool-beta"], "restored session must NOT have tool-beta after expiry") |
| 342 | +} |
| 343 | + |
| 344 | +// --------------------------------------------------------------------------- |
| 345 | +// AC6: In-memory-only mode (no Redis) — no cross-pod sharing |
| 346 | +// --------------------------------------------------------------------------- |
| 347 | + |
| 348 | +// TestHorizontalScaling_InMemoryOnlyMode verifies that when Redis is not |
| 349 | +// configured (LocalSessionDataStorage), sessions are not visible to a second |
| 350 | +// Manager instance, and single-pod usage continues to work correctly. |
| 351 | +func TestHorizontalScaling_InMemoryOnlyMode(t *testing.T) { |
| 352 | + t.Parallel() |
| 353 | + |
| 354 | + backend := startMCPBackend(t, "backend-alpha", "echo") |
| 355 | + |
| 356 | + newLocalStorage := func(t *testing.T) transportsession.DataStorage { |
| 357 | + t.Helper() |
| 358 | + s, err := transportsession.NewLocalSessionDataStorage(time.Hour) |
| 359 | + require.NoError(t, err) |
| 360 | + t.Cleanup(func() { _ = s.Close() }) |
| 361 | + return s |
| 362 | + } |
| 363 | + |
| 364 | + // Pod A and pod B each have their own local storage — no sharing. |
| 365 | + storageA := newLocalStorage(t) |
| 366 | + storageB := newLocalStorage(t) |
| 367 | + |
| 368 | + smA := newTestManagerWithSharedStorage(t, storageA, []*vmcp.Backend{backend}) |
| 369 | + smB := newTestManagerWithSharedStorage(t, storageB, []*vmcp.Backend{backend}) |
| 370 | + |
| 371 | + sessionID := createSession(t, smA, nil) |
| 372 | + |
| 373 | + // Pod B must not be able to see pod A's session. |
| 374 | + _, ok := smB.GetMultiSession(sessionID) |
| 375 | + assert.False(t, ok, "in-memory-only: pod B must not see pod A's session") |
| 376 | + |
| 377 | + // Single-pod usage on pod A must still work. |
| 378 | + sess, ok := smA.GetMultiSession(sessionID) |
| 379 | + require.True(t, ok, "pod A must still serve its own session") |
| 380 | + require.NotNil(t, sess) |
| 381 | + assert.NotEmpty(t, sess.Tools(), "session on pod A must have tools") |
| 382 | +} |
0 commit comments