Skip to content

Commit d851c69

Browse files
yrobla and taskbot authored
Route MCP sessions to the originating backend pod using httptrace (#4673)
* Route MCP sessions to the originating backend pod using httptrace When a proxy runner pod restarts it recovers sessions from Redis but backend_url stored the ClusterIP, so kube-proxy could send follow-up requests to a different backend pod that never handled initialize — causing JSON-RPC -32001 "session not found" errors on the first request. Use net/http/httptrace.GotConn to capture the actual backend pod IP after kube-proxy DNAT on every initialize request, and store that as backend_url instead of the ClusterIP URL. The existing Rewrite closure already reads backend_url and pins routing to the correct pod; no changes to that path are needed. When the backend pod is later replaced (rescheduled to a new IP or restarted in place and lost in-memory session state), the proxy now re-initializes the backend session transparently rather than returning 404 to the client: - Dial error (pod IP unreachable): re-init triggers on TCP failure - Backend 404 (session lost, same IP): re-init triggers on response In both cases the proxy replays the stored initialize body against the ClusterIP, captures the new pod IP via GotConn, stores the new backend session ID, rewrites outbound Mcp-Session-Id headers, and replays the original client request — the client sees no error. DELETE responses are excluded from the 404 re-init path since the session is intentionally torn down in that case. Closes #4575 * Reimplement tracingTransport logic + add e2e tests Extract re-initialization logic from tracingTransport into a dedicated backendRecovery type backed by a narrow recoverySessionStore interface (Get + UpsertSession) and a forward func. tracingTransport now owns only the request lifecycle (session guard, initialize detection, httptrace, session creation) and delegates all forwarding and recovery to backendRecovery. 
This makes reinitializeAndReplay and podBackendURL testable without standing up a full TransparentProxy: backend_recovery_test.go covers all recovery paths (no session, no init body, happy path, forward error, missing new session ID) using a stubSessionStore and inline httptest servers. tracingTransport.forward() and the base field are removed; all network I/O goes through recovery.forward — a single source of truth for the underlying transport. Also integrates the E2E acceptance test from #4574 that exercises backendReplicas=2 + proxy runner restart, verifying that sessions are routed to the correct backend pod after re-initialization. --------- Co-authored-by: taskbot <taskbot@users.noreply.github.com>
1 parent 65a78f4 commit d851c69

5 files changed

Lines changed: 1009 additions & 40 deletions

File tree

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package transparent
5+
6+
import (
7+
"bytes"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"sync"
13+
"testing"
14+
15+
"github.com/google/uuid"
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
19+
"github.com/stacklok/toolhive/pkg/transport/session"
20+
)
21+
22+
// stubSessionStore is a minimal in-memory recoverySessionStore for unit tests.
23+
type stubSessionStore struct {
24+
sessions map[string]session.Session
25+
}
26+
27+
func newStubStore(sessions ...session.Session) *stubSessionStore {
28+
m := make(map[string]session.Session)
29+
for _, s := range sessions {
30+
m[s.ID()] = s
31+
}
32+
return &stubSessionStore{sessions: m}
33+
}
34+
35+
func (s *stubSessionStore) Get(id string) (session.Session, bool) {
36+
sess, ok := s.sessions[id]
37+
return sess, ok
38+
}
39+
40+
func (s *stubSessionStore) UpsertSession(sess session.Session) error {
41+
s.sessions[sess.ID()] = sess
42+
return nil
43+
}
44+
45+
// newRecovery builds a backendRecovery backed by the given store and forward func.
46+
func newRecovery(targetURL string, store recoverySessionStore, fwd func(*http.Request) (*http.Response, error)) *backendRecovery {
47+
return &backendRecovery{
48+
targetURI: targetURL,
49+
forward: fwd,
50+
sessions: store,
51+
}
52+
}
53+
54+
// TestBackendRecoveryNoSession verifies that reinitializeAndReplay returns
55+
// (nil, nil) when the request carries no Mcp-Session-Id.
56+
func TestBackendRecoveryNoSession(t *testing.T) {
57+
t.Parallel()
58+
59+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
60+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
61+
strings.NewReader(`{"method":"tools/list"}`))
62+
require.NoError(t, err)
63+
64+
resp, err := r.reinitializeAndReplay(req, nil)
65+
assert.Nil(t, resp)
66+
assert.NoError(t, err)
67+
}
68+
69+
// TestBackendRecoveryUnknownSession verifies that reinitializeAndReplay returns
70+
// (nil, nil) when the session ID is not in the store.
71+
func TestBackendRecoveryUnknownSession(t *testing.T) {
72+
t.Parallel()
73+
74+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
75+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
76+
strings.NewReader(`{"method":"tools/list"}`))
77+
require.NoError(t, err)
78+
req.Header.Set("Mcp-Session-Id", uuid.New().String())
79+
80+
resp, err := r.reinitializeAndReplay(req, nil)
81+
assert.Nil(t, resp)
82+
assert.NoError(t, err)
83+
}
84+
85+
// TestBackendRecoveryNoInitBody verifies that when the session has no stored
86+
// init body, reinitializeAndReplay resets backend_url to the ClusterIP and
87+
// returns (nil, nil) so the caller falls through to a 404 the client can handle.
88+
func TestBackendRecoveryNoInitBody(t *testing.T) {
89+
t.Parallel()
90+
91+
const clusterIP = "http://cluster-ip:8080"
92+
clientSID := uuid.New().String()
93+
sess := session.NewProxySession(clientSID)
94+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080") // stale pod IP
95+
store := newStubStore(sess)
96+
97+
r := newRecovery(clusterIP, store, nil)
98+
req, err := http.NewRequest(http.MethodPost, clusterIP+"/mcp",
99+
strings.NewReader(`{"method":"tools/list"}`))
100+
require.NoError(t, err)
101+
req.Header.Set("Mcp-Session-Id", clientSID)
102+
103+
resp, err := r.reinitializeAndReplay(req, nil)
104+
assert.Nil(t, resp)
105+
assert.NoError(t, err)
106+
107+
// backend_url should be reset to ClusterIP so the next request routes correctly.
108+
updated, ok := store.Get(clientSID)
109+
require.True(t, ok)
110+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
111+
require.True(t, exists)
112+
assert.Equal(t, clusterIP, backendURL, "backend_url should be reset to ClusterIP when no init body")
113+
}
114+
115+
// TestBackendRecoveryHappyPath verifies the full re-init flow: the stored
116+
// initialize body is replayed to the ClusterIP, the new backend session ID is
117+
// captured, the session is updated, and the original request is replayed — all
118+
// without standing up a full TransparentProxy.
119+
func TestBackendRecoveryHappyPath(t *testing.T) {
120+
t.Parallel()
121+
122+
const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
123+
newBackendSID := uuid.New().String()
124+
var (
125+
forwardMu sync.Mutex
126+
forwardCalls []string
127+
)
128+
129+
// Backend: returns a session ID on initialize, 200 otherwise.
130+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131+
body, _ := io.ReadAll(r.Body)
132+
forwardMu.Lock()
133+
forwardCalls = append(forwardCalls, r.Header.Get("Mcp-Session-Id"))
134+
forwardMu.Unlock()
135+
if strings.Contains(string(body), `"initialize"`) {
136+
w.Header().Set("Mcp-Session-Id", newBackendSID)
137+
}
138+
w.WriteHeader(http.StatusOK)
139+
}))
140+
defer backend.Close()
141+
142+
clientSID := uuid.New().String()
143+
sess := session.NewProxySession(clientSID)
144+
sess.SetMetadata(sessionMetadataInitBody, initBody)
145+
store := newStubStore(sess)
146+
147+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
148+
149+
origBody := []byte(`{"method":"tools/list"}`)
150+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
151+
bytes.NewReader(origBody))
152+
require.NoError(t, err)
153+
req.Header.Set("Mcp-Session-Id", clientSID)
154+
req.Header.Set("Content-Type", "application/json")
155+
156+
resp, err := r.reinitializeAndReplay(req, origBody)
157+
require.NoError(t, err)
158+
require.NotNil(t, resp)
159+
assert.Equal(t, http.StatusOK, resp.StatusCode)
160+
_ = resp.Body.Close()
161+
162+
// Verify session was updated with new backend SID and a pod URL.
163+
updated, ok := store.Get(clientSID)
164+
require.True(t, ok)
165+
backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID)
166+
require.True(t, exists)
167+
assert.Equal(t, newBackendSID, backendSID)
168+
169+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
170+
require.True(t, exists)
171+
assert.NotEmpty(t, backendURL)
172+
173+
// Two forward calls: initialize + replay. The initialize must not carry
174+
// a session ID; the replay must carry the new backend SID.
175+
forwardMu.Lock()
176+
defer forwardMu.Unlock()
177+
require.Len(t, forwardCalls, 2, "forward should be called for initialize and replay")
178+
assert.Empty(t, forwardCalls[0], "initialize request must not carry Mcp-Session-Id")
179+
assert.Equal(t, newBackendSID, forwardCalls[1], "replay must carry the new backend SID")
180+
}
181+
182+
// TestBackendRecoveryReinitForwardError verifies that a forward error during
183+
// re-initialization is returned to the caller.
184+
func TestBackendRecoveryReinitForwardError(t *testing.T) {
185+
t.Parallel()
186+
187+
// Server that is immediately closed — all connections will be refused.
188+
dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
189+
deadURL := dead.URL
190+
dead.Close()
191+
192+
clientSID := uuid.New().String()
193+
sess := session.NewProxySession(clientSID)
194+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
195+
store := newStubStore(sess)
196+
197+
r := newRecovery(deadURL, store, http.DefaultTransport.RoundTrip)
198+
199+
req, err := http.NewRequest(http.MethodPost, deadURL+"/mcp",
200+
strings.NewReader(`{"method":"tools/list"}`))
201+
require.NoError(t, err)
202+
req.Header.Set("Mcp-Session-Id", clientSID)
203+
204+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
205+
assert.Nil(t, resp)
206+
assert.Error(t, err, "forward error during re-init should be returned")
207+
}
208+
209+
// TestBackendRecoveryNoNewSessionID verifies that when the re-initialize
210+
// response carries no Mcp-Session-Id, reinitializeAndReplay resets backend_url
211+
// to ClusterIP and returns (nil, nil).
212+
func TestBackendRecoveryNoNewSessionID(t *testing.T) {
213+
t.Parallel()
214+
215+
// Backend that returns no Mcp-Session-Id on initialize.
216+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
217+
w.WriteHeader(http.StatusOK) // no Mcp-Session-Id header
218+
}))
219+
defer backend.Close()
220+
221+
clientSID := uuid.New().String()
222+
sess := session.NewProxySession(clientSID)
223+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
224+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080")
225+
store := newStubStore(sess)
226+
227+
// targetURI points to backend (so the init request succeeds), but we verify
228+
// that backend_url is reset to targetURI when no session ID comes back.
229+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
230+
231+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
232+
strings.NewReader(`{"method":"tools/list"}`))
233+
require.NoError(t, err)
234+
req.Header.Set("Mcp-Session-Id", clientSID)
235+
236+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
237+
assert.Nil(t, resp)
238+
assert.NoError(t, err)
239+
240+
updated, ok := store.Get(clientSID)
241+
require.True(t, ok)
242+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
243+
require.True(t, exists)
244+
assert.Equal(t, backend.URL, backendURL, "backend_url should fall back to targetURI when no new session ID")
245+
}
246+
247+
// TestPodBackendURLWithCapturedAddr verifies that a captured pod IP replaces the
248+
// host in targetURI while preserving the scheme.
249+
func TestPodBackendURLWithCapturedAddr(t *testing.T) {
250+
t.Parallel()
251+
252+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
253+
got := r.podBackendURL("10.0.0.5:8080")
254+
assert.Equal(t, "http://10.0.0.5:8080", got)
255+
}
256+
257+
// TestPodBackendURLFallback verifies that an empty captured address falls back
258+
// to targetURI unchanged.
259+
func TestPodBackendURLFallback(t *testing.T) {
260+
t.Parallel()
261+
262+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
263+
got := r.podBackendURL("")
264+
assert.Equal(t, "http://cluster-ip:8080", got)
265+
}
266+
267+
// TestPodBackendURLHTTPSFallback verifies that an HTTPS targetURI is never
268+
// rewritten to a pod IP. IP-literal HTTPS URLs fail TLS verification because
269+
// server certificates are issued for hostnames, not pod IPs.
270+
func TestPodBackendURLHTTPSFallback(t *testing.T) {
271+
t.Parallel()
272+
273+
r := &backendRecovery{targetURI: "https://mcp.example.com/mcp"}
274+
got := r.podBackendURL("1.2.3.4:443")
275+
assert.Equal(t, "https://mcp.example.com/mcp", got,
276+
"HTTPS target must not be rewritten to a pod IP")
277+
}

0 commit comments

Comments
 (0)