Skip to content

Commit 6b5cb78

Browse files
committed
Route MCP sessions to the originating backend pod using httptrace
When a proxy runner pod restarts it recovers sessions from Redis but backend_url stored the ClusterIP, so kube-proxy could send follow-up requests to a different backend pod that never handled initialize — causing JSON-RPC -32001 "session not found" errors on the first request. Use net/http/httptrace.GotConn to capture the actual backend pod IP after kube-proxy DNAT on every initialize request, and store that as backend_url instead of the ClusterIP URL. The existing Rewrite closure already reads backend_url and pins routing to the correct pod; no changes to that path are needed. When the backend pod is later replaced (rescheduled to a new IP or restarted in place and lost in-memory session state), the proxy now re-initializes the backend session transparently rather than returning 404 to the client: - Dial error (pod IP unreachable): re-init triggers on TCP failure - Backend 404 (session lost, same IP): re-init triggers on response In both cases the proxy replays the stored initialize body against the ClusterIP, captures the new pod IP via GotConn, stores the new backend session ID, rewrites outbound Mcp-Session-Id headers, and replays the original client request — the client sees no error. DELETE responses are excluded from the 404 re-init path since the session is intentionally torn down in that case. Closes #4575
1 parent 22a4203 commit 6b5cb78

File tree

2 files changed

+333
-8
lines changed

2 files changed

+333
-8
lines changed

pkg/transport/proxy/transparent/backend_routing_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,146 @@ func TestRoundTripStoresBackendURLOnInitialize(t *testing.T) {
289289
require.True(t, ok, "session should have backend_url metadata")
290290
assert.Equal(t, backend.URL, backendURL)
291291
}
292+
293+
// TestRoundTripStoresInitBodyOnInitialize verifies that the raw JSON-RPC initialize
294+
// request body is stored in session metadata so the proxy can transparently
295+
// re-initialize the backend session if the pod is later replaced.
296+
func TestRoundTripStoresInitBodyOnInitialize(t *testing.T) {
297+
t.Parallel()
298+
299+
sessionID := uuid.New().String()
300+
const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
301+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
302+
w.Header().Set("Mcp-Session-Id", sessionID)
303+
w.WriteHeader(http.StatusOK)
304+
}))
305+
defer backend.Close()
306+
307+
proxy, addr := startProxy(t, backend.URL)
308+
309+
ctx := context.Background()
310+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
311+
"http://"+addr+"/mcp",
312+
strings.NewReader(initBody))
313+
require.NoError(t, err)
314+
req.Header.Set("Content-Type", "application/json")
315+
316+
resp, err := http.DefaultClient.Do(req)
317+
require.NoError(t, err)
318+
_ = resp.Body.Close()
319+
320+
sess, ok := proxy.sessionManager.Get(normalizeSessionID(sessionID))
321+
require.True(t, ok, "session should have been created")
322+
stored, exists := sess.GetMetadataValue(sessionMetadataInitBody)
323+
require.True(t, exists, "init_body should be stored in session metadata")
324+
assert.Equal(t, initBody, stored)
325+
}
326+
327+
// TestRoundTripReinitializesOnBackend404 verifies that when the backend pod returns
328+
// 404 (session lost after restart on the same IP), the proxy transparently
329+
// re-initializes the backend session and replays the original request — client sees 200.
330+
func TestRoundTripReinitializesOnBackend404(t *testing.T) {
331+
t.Parallel()
332+
333+
// staleBackend simulates a pod that has lost its in-memory session state.
334+
var staleHit atomic.Int32
335+
staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
336+
staleHit.Add(1)
337+
w.WriteHeader(http.StatusNotFound)
338+
}))
339+
defer staleBackend.Close()
340+
341+
// freshBackend simulates a healthy pod: returns a session ID on initialize
342+
// and 200 for all other requests.
343+
freshSessionID := uuid.New().String()
344+
var freshHit atomic.Int32
345+
freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
346+
freshHit.Add(1)
347+
body, _ := io.ReadAll(r.Body)
348+
if strings.Contains(string(body), `"initialize"`) {
349+
w.Header().Set("Mcp-Session-Id", freshSessionID)
350+
}
351+
w.WriteHeader(http.StatusOK)
352+
}))
353+
defer freshBackend.Close()
354+
355+
// targetURI (ClusterIP) points to freshBackend; the session has staleBackend as backend_url.
356+
proxy, addr := startProxy(t, freshBackend.URL)
357+
358+
clientSessionID := uuid.New().String()
359+
sess := session.NewProxySession(clientSessionID)
360+
sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL)
361+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
362+
require.NoError(t, proxy.sessionManager.AddSession(sess))
363+
364+
ctx := context.Background()
365+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
366+
"http://"+addr+"/mcp",
367+
strings.NewReader(`{"method":"tools/list"}`))
368+
require.NoError(t, err)
369+
req.Header.Set("Content-Type", "application/json")
370+
req.Header.Set("Mcp-Session-Id", clientSessionID)
371+
372+
resp, err := http.DefaultClient.Do(req)
373+
require.NoError(t, err)
374+
_ = resp.Body.Close()
375+
376+
assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init")
377+
assert.GreaterOrEqual(t, staleHit.Load(), int32(1), "stale backend should have been hit")
378+
assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay")
379+
380+
// Session should now have backend_sid mapping to the new backend session.
381+
updated, ok := proxy.sessionManager.Get(normalizeSessionID(clientSessionID))
382+
require.True(t, ok, "session should still exist after re-init")
383+
backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID)
384+
require.True(t, exists, "backend_sid should be set after re-init")
385+
assert.Equal(t, normalizeSessionID(freshSessionID), backendSID)
386+
}
387+
388+
// TestRoundTripReinitializesOnDialError verifies that when the proxy cannot reach
389+
// the stored pod IP (dial error — pod rescheduled to a new IP), it transparently
390+
// re-initializes the backend session via the ClusterIP and replays the original
391+
// request — the client sees a 200.
392+
func TestRoundTripReinitializesOnDialError(t *testing.T) {
393+
t.Parallel()
394+
395+
// Create a server and immediately close it so its URL refuses connections.
396+
dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
397+
deadURL := dead.URL
398+
dead.Close()
399+
400+
freshSessionID := uuid.New().String()
401+
var freshHit atomic.Int32
402+
freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
403+
freshHit.Add(1)
404+
body, _ := io.ReadAll(r.Body)
405+
if strings.Contains(string(body), `"initialize"`) {
406+
w.Header().Set("Mcp-Session-Id", freshSessionID)
407+
}
408+
w.WriteHeader(http.StatusOK)
409+
}))
410+
defer freshBackend.Close()
411+
412+
proxy, addr := startProxy(t, freshBackend.URL)
413+
414+
clientSessionID := uuid.New().String()
415+
sess := session.NewProxySession(clientSessionID)
416+
sess.SetMetadata(sessionMetadataBackendURL, deadURL)
417+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
418+
require.NoError(t, proxy.sessionManager.AddSession(sess))
419+
420+
ctx := context.Background()
421+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
422+
"http://"+addr+"/mcp",
423+
strings.NewReader(`{"method":"tools/list"}`))
424+
require.NoError(t, err)
425+
req.Header.Set("Content-Type", "application/json")
426+
req.Header.Set("Mcp-Session-Id", clientSessionID)
427+
428+
resp, err := http.DefaultClient.Do(req)
429+
require.NoError(t, err)
430+
_ = resp.Body.Close()
431+
432+
assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init on dial error")
433+
assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay")
434+
}

pkg/transport/proxy/transparent/transparent_proxy.go

Lines changed: 190 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"log/slog"
1616
"net"
1717
"net/http"
18+
"net/http/httptrace"
1819
"net/http/httputil"
1920
"net/url"
2021
"os"
@@ -153,6 +154,17 @@ const (
153154
// It is written on initialize and read in the Rewrite closure to route follow-up requests
154155
// to the same backend pod that handled the session's initialize request.
155156
sessionMetadataBackendURL = "backend_url"
157+
158+
// sessionMetadataInitBody stores the raw JSON-RPC initialize request body.
159+
// It is used to transparently re-initialize a backend session when the pod that
160+
// originally handled initialize has been replaced (new IP or lost in-memory state).
161+
sessionMetadataInitBody = "init_body"
162+
163+
// sessionMetadataBackendSID stores the backend's assigned Mcp-Session-Id when it
164+
// diverges from the client-facing session ID after a transparent re-initialization.
165+
// The Rewrite closure rewrites the outbound Mcp-Session-Id header to this value so
166+
// the backend sees its own session ID while the client keeps its original one.
167+
sessionMetadataBackendSID = "backend_sid"
156168
)
157169

158170
// Option is a functional option for configuring TransparentProxy
@@ -436,12 +448,33 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
436448
}
437449
}
438450

451+
// Attach an httptrace to capture the actual backend pod IP after kube-proxy
452+
// DNAT resolves the ClusterIP to a specific pod. The captured address is stored
453+
// as backend_url so follow-up requests always reach the same pod, even after a
454+
// proxy runner restart that would otherwise lose the in-memory routing state.
455+
var capturedPodAddr string
456+
if sawInitialize {
457+
trace := &httptrace.ClientTrace{
458+
GotConn: func(info httptrace.GotConnInfo) {
459+
capturedPodAddr = info.Conn.RemoteAddr().String()
460+
},
461+
}
462+
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
463+
}
464+
439465
resp, err := t.forward(req)
440466
if err != nil {
441467
if errors.Is(err, context.Canceled) {
442468
// Expected during shutdown or client disconnect—silently ignore
443469
return nil, err
444470
}
471+
// Dial error against a stored pod IP means the pod has been replaced.
472+
// Attempt transparent re-initialization so the client sees no error.
473+
if isDialError(err) {
474+
if reInitResp, reInitErr := t.reinitializeAndReplay(req, reqBody); reInitResp != nil || reInitErr != nil {
475+
return reInitResp, reInitErr
476+
}
477+
}
445478
slog.Error("failed to forward request", "error", err)
446479
return nil, err
447480
}
@@ -471,6 +504,20 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
471504
}
472505
}
473506

507+
// Backend returned 404 for a non-initialize, non-DELETE request whose session IS
508+
// known to the proxy. This means the backend pod lost its in-memory session state
509+
// (e.g. it was restarted but got the same IP). Attempt transparent re-initialization
510+
// so the client sees no error. DELETE is excluded because the session has already
511+
// been cleaned up above and the 404 is the expected terminal response.
512+
if resp.StatusCode == http.StatusNotFound && !sawInitialize && req.Method != http.MethodDelete {
513+
if sid := req.Header.Get("Mcp-Session-Id"); sid != "" {
514+
if reInitResp, reInitErr := t.reinitializeAndReplay(req, reqBody); reInitResp != nil || reInitErr != nil {
515+
_ = resp.Body.Close()
516+
return reInitResp, reInitErr
517+
}
518+
}
519+
}
520+
474521
if resp.StatusCode == http.StatusOK {
475522
// check if we saw a valid mcp header
476523
ct := resp.Header.Get("Mcp-Session-Id")
@@ -480,14 +527,15 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
480527
internalID := normalizeSessionID(ct)
481528
if _, ok := t.p.sessionManager.Get(internalID); !ok {
482529
sess := session.NewProxySession(internalID)
483-
// Store targetURI as the default backend_url for this session.
484-
// In single-replica deployments targetURI is already the pod address,
485-
// so no override is needed. In multi-replica deployments the
486-
// vMCP/operator layer is responsible for setting backend_url to the
487-
// actual pod DNS name (e.g. http://mcp-server-0.mcp-server.default.svc:8080)
488-
// before the request reaches this proxy; the Rewrite closure then reads
489-
// that value and routes follow-up requests to the correct pod.
490-
sess.SetMetadata(sessionMetadataBackendURL, t.p.targetURI)
530+
// Store the actual pod IP (captured via GotConn) as backend_url so that
531+
// after a proxy runner restart the session is routed to the same backend
532+
// pod that handled initialize, not a random pod via ClusterIP.
533+
sess.SetMetadata(sessionMetadataBackendURL, t.podBackendURL(capturedPodAddr))
534+
// Store the initialize body so we can transparently re-initialize the
535+
// backend session if the pod is later replaced or loses session state.
536+
if len(reqBody) > 0 {
537+
sess.SetMetadata(sessionMetadataInitBody, string(reqBody))
538+
}
491539
if err := t.p.sessionManager.AddSession(sess); err != nil {
492540
//nolint:gosec // G706: session ID from HTTP response header
493541
slog.Error("failed to create session from header",
@@ -553,6 +601,133 @@ func (t *tracingTransport) detectInitialize(body []byte) bool {
553601
return false
554602
}
555603

604+
// podBackendURL constructs a backend URL that targets the specific pod IP captured
605+
// via httptrace.GotConn, using the scheme from targetURI. Falls back to targetURI
606+
// when no address was captured (e.g. single-replica, connection reuse without a new conn).
607+
func (t *tracingTransport) podBackendURL(capturedAddr string) string {
608+
if capturedAddr == "" {
609+
return t.p.targetURI
610+
}
611+
parsed, err := url.Parse(t.p.targetURI)
612+
if err != nil {
613+
return t.p.targetURI
614+
}
615+
parsed.Host = capturedAddr
616+
return parsed.String()
617+
}
618+
619+
// isDialError reports whether err is a TCP dial failure, indicating that the
620+
// target host is unreachable (pod has been terminated or rescheduled).
621+
func isDialError(err error) bool {
622+
var opErr *net.OpError
623+
return errors.As(err, &opErr) && opErr.Op == "dial"
624+
}
625+
626+
// reinitializeAndReplay is called when the proxy detects that the backend pod
627+
// that owned a session is no longer reachable (dial error) or has lost its
628+
// in-memory session state (backend returned 404). It transparently:
629+
// 1. Re-sends the stored initialize body to the ClusterIP service so kube-proxy
630+
// selects a healthy pod and the backend creates a new session.
631+
// 2. Captures the new pod IP via httptrace.GotConn and stores it as backend_url.
632+
// 3. Maps the client's original session ID to the new backend session ID.
633+
// 4. Replays the original client request so the client sees no error.
634+
//
635+
// Returns (nil, nil) when re-initialization is not applicable (no stored init
636+
// body, session unknown, or already routing via ClusterIP).
637+
func (t *tracingTransport) reinitializeAndReplay(req *http.Request, origBody []byte) (*http.Response, error) {
638+
sid := req.Header.Get("Mcp-Session-Id")
639+
if sid == "" {
640+
return nil, nil
641+
}
642+
internalSID := normalizeSessionID(sid)
643+
sess, ok := t.p.sessionManager.Get(internalSID)
644+
if !ok {
645+
return nil, nil
646+
}
647+
648+
initBody, hasInit := sess.GetMetadataValue(sessionMetadataInitBody)
649+
if !hasInit || initBody == "" {
650+
// No stored init body — cannot re-initialize transparently.
651+
// Reset backend_url to ClusterIP so the next request goes through
652+
// kube-proxy and lets the client receive a clean 404 to re-initialize.
653+
sess.SetMetadata(sessionMetadataBackendURL, t.p.targetURI)
654+
_ = t.p.sessionManager.UpsertSession(sess)
655+
return nil, nil
656+
}
657+
658+
slog.Debug("backend session lost; transparently re-initializing",
659+
"session_id", sid, "target", t.p.targetURI)
660+
661+
// Capture the new pod IP via GotConn on the re-initialize connection.
662+
var capturedPodAddr string
663+
trace := &httptrace.ClientTrace{
664+
GotConn: func(info httptrace.GotConnInfo) {
665+
capturedPodAddr = info.Conn.RemoteAddr().String()
666+
},
667+
}
668+
initCtx := httptrace.WithClientTrace(req.Context(), trace)
669+
670+
// Build a fresh initialize request to the ClusterIP (no Mcp-Session-Id —
671+
// the backend assigns a new session ID in the response).
672+
parsedTarget, err := url.Parse(t.p.targetURI)
673+
if err != nil {
674+
return nil, nil
675+
}
676+
initURL := *req.URL
677+
initURL.Scheme = parsedTarget.Scheme
678+
initURL.Host = parsedTarget.Host
679+
680+
initReq, err := http.NewRequestWithContext(initCtx, http.MethodPost, initURL.String(), bytes.NewReader([]byte(initBody)))
681+
if err != nil {
682+
return nil, nil
683+
}
684+
initReq.Header.Set("Content-Type", "application/json")
685+
686+
initResp, err := t.forward(initReq)
687+
if err != nil {
688+
slog.Error("transparent re-initialize failed", "error", err)
689+
return nil, err
690+
}
691+
_, _ = io.Copy(io.Discard, initResp.Body)
692+
_ = initResp.Body.Close()
693+
694+
newBackendSID := initResp.Header.Get("Mcp-Session-Id")
695+
if newBackendSID == "" {
696+
slog.Debug("re-initialize response contained no Mcp-Session-Id; falling back to ClusterIP")
697+
sess.SetMetadata(sessionMetadataBackendURL, t.p.targetURI)
698+
_ = t.p.sessionManager.UpsertSession(sess)
699+
return nil, nil
700+
}
701+
702+
// Update session: point backend_url at the newly-discovered pod and record
703+
// the backend session ID so Rewrite rewrites Mcp-Session-Id on outbound requests.
704+
newPodURL := t.podBackendURL(capturedPodAddr)
705+
sess.SetMetadata(sessionMetadataBackendURL, newPodURL)
706+
sess.SetMetadata(sessionMetadataBackendSID, normalizeSessionID(newBackendSID))
707+
if upsertErr := t.p.sessionManager.UpsertSession(sess); upsertErr != nil {
708+
slog.Debug("failed to update session after re-initialize", "error", upsertErr)
709+
}
710+
711+
// Replay the original client request to the new pod with the new backend SID.
712+
// Use the captured pod address directly so we bypass the Rewrite closure
713+
// (which still holds the old backend_url until the next session load).
714+
replayHost := capturedPodAddr
715+
if replayHost == "" {
716+
replayHost = parsedTarget.Host
717+
}
718+
replayReq := req.Clone(req.Context())
719+
replayReq.URL.Scheme = parsedTarget.Scheme
720+
replayReq.URL.Host = replayHost
721+
replayReq.Host = replayHost // keep Host header consistent with URL to avoid backend validation errors
722+
replayReq.Header.Set("Mcp-Session-Id", newBackendSID)
723+
replayReq.Body = io.NopCloser(bytes.NewReader(origBody))
724+
replayReq.ContentLength = int64(len(origBody))
725+
726+
slog.Debug("replaying original request after transparent re-initialization",
727+
"new_pod_url", newPodURL, "new_backend_sid", newBackendSID)
728+
return t.forward(replayReq)
729+
}
730+
556731
// modifyResponse modifies HTTP responses based on transport-specific requirements.
557732
// Delegates to the appropriate ResponseProcessor based on transport type.
558733
func (p *TransparentProxy) modifyResponse(resp *http.Response) error {
@@ -601,6 +776,13 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
601776
pr.Out.URL.Host = parsed.Host
602777
}
603778
}
779+
// After a transparent re-initialization the proxy maps the client's
780+
// session ID to the backend's newly-assigned session ID. Rewrite the
781+
// outbound header so the backend sees its own ID while the client
782+
// continues to use its original session ID unchanged.
783+
if backendSID, exists := sess.GetMetadataValue(sessionMetadataBackendSID); exists && backendSID != "" {
784+
pr.Out.Header.Set("Mcp-Session-Id", backendSID)
785+
}
604786
}
605787
}
606788

0 commit comments

Comments
 (0)