|
9 | 9 | "net/http" |
10 | 10 | "net/http/httptest" |
11 | 11 | "strings" |
| 12 | + "sync" |
12 | 13 | "sync/atomic" |
13 | 14 | "testing" |
14 | 15 | "time" |
@@ -289,3 +290,230 @@ func TestRoundTripStoresBackendURLOnInitialize(t *testing.T) { |
289 | 290 | require.True(t, ok, "session should have backend_url metadata") |
290 | 291 | assert.Equal(t, backend.URL, backendURL) |
291 | 292 | } |
| 293 | + |
| 294 | +// TestRoundTripStoresInitBodyOnInitialize verifies that the raw JSON-RPC initialize |
| 295 | +// request body is stored in session metadata so the proxy can transparently |
| 296 | +// re-initialize the backend session if the pod is later replaced. |
| 297 | +func TestRoundTripStoresInitBodyOnInitialize(t *testing.T) { |
| 298 | + t.Parallel() |
| 299 | + |
| 300 | + sessionID := uuid.New().String() |
| 301 | + const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}` |
| 302 | + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { |
| 303 | + w.Header().Set("Mcp-Session-Id", sessionID) |
| 304 | + w.WriteHeader(http.StatusOK) |
| 305 | + })) |
| 306 | + defer backend.Close() |
| 307 | + |
| 308 | + proxy, addr := startProxy(t, backend.URL) |
| 309 | + |
| 310 | + ctx := context.Background() |
| 311 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, |
| 312 | + "http://"+addr+"/mcp", |
| 313 | + strings.NewReader(initBody)) |
| 314 | + require.NoError(t, err) |
| 315 | + req.Header.Set("Content-Type", "application/json") |
| 316 | + |
| 317 | + resp, err := http.DefaultClient.Do(req) |
| 318 | + require.NoError(t, err) |
| 319 | + _ = resp.Body.Close() |
| 320 | + |
| 321 | + sess, ok := proxy.sessionManager.Get(normalizeSessionID(sessionID)) |
| 322 | + require.True(t, ok, "session should have been created") |
| 323 | + stored, exists := sess.GetMetadataValue(sessionMetadataInitBody) |
| 324 | + require.True(t, exists, "init_body should be stored in session metadata") |
| 325 | + assert.Equal(t, initBody, stored) |
| 326 | +} |
| 327 | + |
| 328 | +// TestRoundTripReinitializesOnBackend404 verifies that when the backend pod returns |
| 329 | +// 404 (session lost after restart on the same IP), the proxy transparently |
| 330 | +// re-initializes the backend session and replays the original request — client sees 200. |
| 331 | +func TestRoundTripReinitializesOnBackend404(t *testing.T) { |
| 332 | + t.Parallel() |
| 333 | + |
| 334 | + // staleBackend simulates a pod that has lost its in-memory session state. |
| 335 | + var staleHit atomic.Int32 |
| 336 | + staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { |
| 337 | + staleHit.Add(1) |
| 338 | + w.WriteHeader(http.StatusNotFound) |
| 339 | + })) |
| 340 | + defer staleBackend.Close() |
| 341 | + |
| 342 | + // freshBackend simulates a healthy pod: returns a session ID on initialize |
| 343 | + // and 200 for all other requests. |
| 344 | + freshSessionID := uuid.New().String() |
| 345 | + var freshHit atomic.Int32 |
| 346 | + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 347 | + freshHit.Add(1) |
| 348 | + body, _ := io.ReadAll(r.Body) |
| 349 | + if strings.Contains(string(body), `"initialize"`) { |
| 350 | + w.Header().Set("Mcp-Session-Id", freshSessionID) |
| 351 | + } |
| 352 | + w.WriteHeader(http.StatusOK) |
| 353 | + })) |
| 354 | + defer freshBackend.Close() |
| 355 | + |
| 356 | + // targetURI (ClusterIP) points to freshBackend; the session has staleBackend as backend_url. |
| 357 | + proxy, addr := startProxy(t, freshBackend.URL) |
| 358 | + |
| 359 | + clientSessionID := uuid.New().String() |
| 360 | + sess := session.NewProxySession(clientSessionID) |
| 361 | + sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL) |
| 362 | + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) |
| 363 | + require.NoError(t, proxy.sessionManager.AddSession(sess)) |
| 364 | + |
| 365 | + ctx := context.Background() |
| 366 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, |
| 367 | + "http://"+addr+"/mcp", |
| 368 | + strings.NewReader(`{"method":"tools/list"}`)) |
| 369 | + require.NoError(t, err) |
| 370 | + req.Header.Set("Content-Type", "application/json") |
| 371 | + req.Header.Set("Mcp-Session-Id", clientSessionID) |
| 372 | + |
| 373 | + resp, err := http.DefaultClient.Do(req) |
| 374 | + require.NoError(t, err) |
| 375 | + _ = resp.Body.Close() |
| 376 | + |
| 377 | + assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init") |
| 378 | + assert.GreaterOrEqual(t, staleHit.Load(), int32(1), "stale backend should have been hit") |
| 379 | + assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay") |
| 380 | + |
| 381 | + // Session should now have backend_sid mapping to the new backend session. |
| 382 | + updated, ok := proxy.sessionManager.Get(normalizeSessionID(clientSessionID)) |
| 383 | + require.True(t, ok, "session should still exist after re-init") |
| 384 | + backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID) |
| 385 | + require.True(t, exists, "backend_sid should be set after re-init") |
| 386 | + assert.Equal(t, freshSessionID, backendSID, "backend_sid must be the raw value the backend issued, not normalized") |
| 387 | +} |
| 388 | + |
| 389 | +// TestRoundTripReinitializesPreservesNonUUIDBackendSessionID verifies that when the |
| 390 | +// backend issues a non-UUID Mcp-Session-Id on re-initialization, the proxy stores |
| 391 | +// and forwards the raw value — not a UUID v5 hash of it — on all subsequent requests. |
| 392 | +// |
| 393 | +// The normalization bug only manifests on the request AFTER the replay: the replay |
| 394 | +// sets Mcp-Session-Id directly from newBackendSID (bypassing Rewrite), but subsequent |
| 395 | +// requests go through the Rewrite closure which reads backend_sid from session metadata. |
| 396 | +// If backend_sid was stored as normalizeSessionID(newBackendSID), Rewrite would send |
| 397 | +// the wrong (hashed) value and the backend would reject every subsequent request. |
| 398 | +func TestRoundTripReinitializesPreservesNonUUIDBackendSessionID(t *testing.T) { |
| 399 | + t.Parallel() |
| 400 | + |
| 401 | + // Non-UUID opaque token, as some MCP servers issue. |
| 402 | + const nonUUIDSessionID = "opaque-session-token-abc123" |
| 403 | + |
| 404 | + staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { |
| 405 | + w.WriteHeader(http.StatusNotFound) |
| 406 | + })) |
| 407 | + defer staleBackend.Close() |
| 408 | + |
| 409 | + // receivedSIDs tracks Mcp-Session-Id values arriving on non-initialize requests, |
| 410 | + // in order. Index 0 = replay (direct from reinitializeAndReplay), index 1 = second |
| 411 | + // client request (routed through Rewrite reading backend_sid from session metadata). |
| 412 | + var ( |
| 413 | + receivedMu sync.Mutex |
| 414 | + receivedSIDs []string |
| 415 | + ) |
| 416 | + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 417 | + body, _ := io.ReadAll(r.Body) |
| 418 | + if strings.Contains(string(body), `"initialize"`) { |
| 419 | + w.Header().Set("Mcp-Session-Id", nonUUIDSessionID) |
| 420 | + w.WriteHeader(http.StatusOK) |
| 421 | + return |
| 422 | + } |
| 423 | + receivedMu.Lock() |
| 424 | + receivedSIDs = append(receivedSIDs, r.Header.Get("Mcp-Session-Id")) |
| 425 | + receivedMu.Unlock() |
| 426 | + w.WriteHeader(http.StatusOK) |
| 427 | + })) |
| 428 | + defer freshBackend.Close() |
| 429 | + |
| 430 | + proxy, addr := startProxy(t, freshBackend.URL) |
| 431 | + |
| 432 | + clientSessionID := uuid.New().String() |
| 433 | + sess := session.NewProxySession(clientSessionID) |
| 434 | + sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL) |
| 435 | + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) |
| 436 | + require.NoError(t, proxy.sessionManager.AddSession(sess)) |
| 437 | + |
| 438 | + doRequest := func() *http.Response { |
| 439 | + ctx := context.Background() |
| 440 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, |
| 441 | + "http://"+addr+"/mcp", |
| 442 | + strings.NewReader(`{"method":"tools/list"}`)) |
| 443 | + require.NoError(t, err) |
| 444 | + req.Header.Set("Content-Type", "application/json") |
| 445 | + req.Header.Set("Mcp-Session-Id", clientSessionID) |
| 446 | + resp, err := http.DefaultClient.Do(req) |
| 447 | + require.NoError(t, err) |
| 448 | + return resp |
| 449 | + } |
| 450 | + |
| 451 | + // First request: triggers re-init. The replay (inside reinitializeAndReplay) sets |
| 452 | + // Mcp-Session-Id directly, so receivedSIDs[0] is always the raw value regardless |
| 453 | + // of what is stored in session metadata. |
| 454 | + resp1 := doRequest() |
| 455 | + _ = resp1.Body.Close() |
| 456 | + require.Equal(t, http.StatusOK, resp1.StatusCode) |
| 457 | + |
| 458 | + // Second request: goes through the Rewrite closure, which reads backend_sid from |
| 459 | + // session metadata. This is where the normalization bug manifests — if backend_sid |
| 460 | + // was stored as normalizeSessionID(nonUUIDSessionID), Rewrite would forward the |
| 461 | + // wrong hashed value and receivedSIDs[1] would not equal nonUUIDSessionID. |
| 462 | + resp2 := doRequest() |
| 463 | + _ = resp2.Body.Close() |
| 464 | + require.Equal(t, http.StatusOK, resp2.StatusCode) |
| 465 | + |
| 466 | + receivedMu.Lock() |
| 467 | + defer receivedMu.Unlock() |
| 468 | + require.Len(t, receivedSIDs, 2, "fresh backend should have received replay + second request") |
| 469 | + assert.Equal(t, nonUUIDSessionID, receivedSIDs[0], "replay must forward raw non-UUID session ID") |
| 470 | + assert.Equal(t, nonUUIDSessionID, receivedSIDs[1], "subsequent request via Rewrite must forward raw non-UUID session ID") |
| 471 | +} |
| 472 | + |
| 473 | +// TestRoundTripReinitializesOnDialError verifies that when the proxy cannot reach |
| 474 | +// the stored pod IP (dial error — pod rescheduled to a new IP), it transparently |
| 475 | +// re-initializes the backend session via the ClusterIP and replays the original |
| 476 | +// request — the client sees a 200. |
| 477 | +func TestRoundTripReinitializesOnDialError(t *testing.T) { |
| 478 | + t.Parallel() |
| 479 | + |
| 480 | + // Create a server and immediately close it so its URL refuses connections. |
| 481 | + dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) |
| 482 | + deadURL := dead.URL |
| 483 | + dead.Close() |
| 484 | + |
| 485 | + freshSessionID := uuid.New().String() |
| 486 | + var freshHit atomic.Int32 |
| 487 | + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 488 | + freshHit.Add(1) |
| 489 | + body, _ := io.ReadAll(r.Body) |
| 490 | + if strings.Contains(string(body), `"initialize"`) { |
| 491 | + w.Header().Set("Mcp-Session-Id", freshSessionID) |
| 492 | + } |
| 493 | + w.WriteHeader(http.StatusOK) |
| 494 | + })) |
| 495 | + defer freshBackend.Close() |
| 496 | + |
| 497 | + proxy, addr := startProxy(t, freshBackend.URL) |
| 498 | + |
| 499 | + clientSessionID := uuid.New().String() |
| 500 | + sess := session.NewProxySession(clientSessionID) |
| 501 | + sess.SetMetadata(sessionMetadataBackendURL, deadURL) |
| 502 | + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) |
| 503 | + require.NoError(t, proxy.sessionManager.AddSession(sess)) |
| 504 | + |
| 505 | + ctx := context.Background() |
| 506 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, |
| 507 | + "http://"+addr+"/mcp", |
| 508 | + strings.NewReader(`{"method":"tools/list"}`)) |
| 509 | + require.NoError(t, err) |
| 510 | + req.Header.Set("Content-Type", "application/json") |
| 511 | + req.Header.Set("Mcp-Session-Id", clientSessionID) |
| 512 | + |
| 513 | + resp, err := http.DefaultClient.Do(req) |
| 514 | + require.NoError(t, err) |
| 515 | + _ = resp.Body.Close() |
| 516 | + |
| 517 | + assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init on dial error") |
| 518 | + assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay") |
| 519 | +} |
0 commit comments