Skip to content

Commit ce9d3ca

Browse files
committed
Reimplement tracingTransport logic + add e2e tests
Extract re-initialization logic from tracingTransport into a dedicated backendRecovery type backed by a narrow recoverySessionStore interface (Get + UpsertSession) and a forward func. tracingTransport now owns only the request lifecycle (session guard, initialize detection, httptrace, session creation) and delegates all forwarding and recovery to backendRecovery. This makes reinitializeAndReplay and podBackendURL testable without standing up a full TransparentProxy: backend_recovery_test.go covers all recovery paths (no session, no init body, happy path, forward error, missing new session ID) using a stubSessionStore and inline httptest servers. tracingTransport.forward() and the base field are removed; all network I/O goes through recovery.forward — a single source of truth for the underlying transport. Also integrates the E2E acceptance test from #4574 that exercises backendReplicas=2 + proxy runner restart, verifying that sessions are routed to the correct backend pod after re-initialization.
1 parent 32960b1 commit ce9d3ca

File tree

5 files changed

+614
-57
lines changed

5 files changed

+614
-57
lines changed
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package transparent
5+
6+
import (
7+
"bytes"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"sync"
13+
"testing"
14+
15+
"github.com/google/uuid"
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
19+
"github.com/stacklok/toolhive/pkg/transport/session"
20+
)
21+
22+
// stubSessionStore is a minimal in-memory recoverySessionStore for unit tests.
23+
type stubSessionStore struct {
24+
sessions map[string]session.Session
25+
}
26+
27+
func newStubStore(sessions ...session.Session) *stubSessionStore {
28+
m := make(map[string]session.Session)
29+
for _, s := range sessions {
30+
m[s.ID()] = s
31+
}
32+
return &stubSessionStore{sessions: m}
33+
}
34+
35+
func (s *stubSessionStore) Get(id string) (session.Session, bool) {
36+
sess, ok := s.sessions[id]
37+
return sess, ok
38+
}
39+
40+
func (s *stubSessionStore) UpsertSession(sess session.Session) error {
41+
s.sessions[sess.ID()] = sess
42+
return nil
43+
}
44+
45+
// newRecovery builds a backendRecovery backed by the given store and forward func.
46+
func newRecovery(targetURL string, store recoverySessionStore, fwd func(*http.Request) (*http.Response, error)) *backendRecovery {
47+
return &backendRecovery{
48+
targetURI: targetURL,
49+
forward: fwd,
50+
sessions: store,
51+
}
52+
}
53+
54+
// TestBackendRecoveryNoSession verifies that reinitializeAndReplay returns
55+
// (nil, nil) when the request carries no Mcp-Session-Id.
56+
func TestBackendRecoveryNoSession(t *testing.T) {
57+
t.Parallel()
58+
59+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
60+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
61+
strings.NewReader(`{"method":"tools/list"}`))
62+
require.NoError(t, err)
63+
64+
resp, err := r.reinitializeAndReplay(req, nil)
65+
assert.Nil(t, resp)
66+
assert.NoError(t, err)
67+
}
68+
69+
// TestBackendRecoveryUnknownSession verifies that reinitializeAndReplay returns
70+
// (nil, nil) when the session ID is not in the store.
71+
func TestBackendRecoveryUnknownSession(t *testing.T) {
72+
t.Parallel()
73+
74+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
75+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
76+
strings.NewReader(`{"method":"tools/list"}`))
77+
require.NoError(t, err)
78+
req.Header.Set("Mcp-Session-Id", uuid.New().String())
79+
80+
resp, err := r.reinitializeAndReplay(req, nil)
81+
assert.Nil(t, resp)
82+
assert.NoError(t, err)
83+
}
84+
85+
// TestBackendRecoveryNoInitBody verifies that when the session has no stored
86+
// init body, reinitializeAndReplay resets backend_url to the ClusterIP and
87+
// returns (nil, nil) so the caller falls through to a 404 the client can handle.
88+
func TestBackendRecoveryNoInitBody(t *testing.T) {
89+
t.Parallel()
90+
91+
const clusterIP = "http://cluster-ip:8080"
92+
clientSID := uuid.New().String()
93+
sess := session.NewProxySession(clientSID)
94+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080") // stale pod IP
95+
store := newStubStore(sess)
96+
97+
r := newRecovery(clusterIP, store, nil)
98+
req, err := http.NewRequest(http.MethodPost, clusterIP+"/mcp",
99+
strings.NewReader(`{"method":"tools/list"}`))
100+
require.NoError(t, err)
101+
req.Header.Set("Mcp-Session-Id", clientSID)
102+
103+
resp, err := r.reinitializeAndReplay(req, nil)
104+
assert.Nil(t, resp)
105+
assert.NoError(t, err)
106+
107+
// backend_url should be reset to ClusterIP so the next request routes correctly.
108+
updated, ok := store.Get(clientSID)
109+
require.True(t, ok)
110+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
111+
require.True(t, exists)
112+
assert.Equal(t, clusterIP, backendURL, "backend_url should be reset to ClusterIP when no init body")
113+
}
114+
115+
// TestBackendRecoveryHappyPath verifies the full re-init flow: the stored
116+
// initialize body is replayed to the ClusterIP, the new backend session ID is
117+
// captured, the session is updated, and the original request is replayed — all
118+
// without standing up a full TransparentProxy.
119+
func TestBackendRecoveryHappyPath(t *testing.T) {
120+
t.Parallel()
121+
122+
const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
123+
newBackendSID := uuid.New().String()
124+
var (
125+
forwardMu sync.Mutex
126+
forwardCalls []string
127+
)
128+
129+
// Backend: returns a session ID on initialize, 200 otherwise.
130+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131+
body, _ := io.ReadAll(r.Body)
132+
forwardMu.Lock()
133+
forwardCalls = append(forwardCalls, r.Header.Get("Mcp-Session-Id"))
134+
forwardMu.Unlock()
135+
if strings.Contains(string(body), `"initialize"`) {
136+
w.Header().Set("Mcp-Session-Id", newBackendSID)
137+
}
138+
w.WriteHeader(http.StatusOK)
139+
}))
140+
defer backend.Close()
141+
142+
clientSID := uuid.New().String()
143+
sess := session.NewProxySession(clientSID)
144+
sess.SetMetadata(sessionMetadataInitBody, initBody)
145+
store := newStubStore(sess)
146+
147+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
148+
149+
origBody := []byte(`{"method":"tools/list"}`)
150+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
151+
bytes.NewReader(origBody))
152+
require.NoError(t, err)
153+
req.Header.Set("Mcp-Session-Id", clientSID)
154+
req.Header.Set("Content-Type", "application/json")
155+
156+
resp, err := r.reinitializeAndReplay(req, origBody)
157+
require.NoError(t, err)
158+
require.NotNil(t, resp)
159+
assert.Equal(t, http.StatusOK, resp.StatusCode)
160+
_ = resp.Body.Close()
161+
162+
// Verify session was updated with new backend SID and a pod URL.
163+
updated, ok := store.Get(clientSID)
164+
require.True(t, ok)
165+
backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID)
166+
require.True(t, exists)
167+
assert.Equal(t, newBackendSID, backendSID)
168+
169+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
170+
require.True(t, exists)
171+
assert.NotEmpty(t, backendURL)
172+
173+
// Two forward calls: initialize + replay. The initialize must not carry
174+
// a session ID; the replay must carry the new backend SID.
175+
forwardMu.Lock()
176+
defer forwardMu.Unlock()
177+
require.Len(t, forwardCalls, 2, "forward should be called for initialize and replay")
178+
assert.Empty(t, forwardCalls[0], "initialize request must not carry Mcp-Session-Id")
179+
assert.Equal(t, newBackendSID, forwardCalls[1], "replay must carry the new backend SID")
180+
}
181+
182+
// TestBackendRecoveryReinitForwardError verifies that a forward error during
183+
// re-initialization is returned to the caller.
184+
func TestBackendRecoveryReinitForwardError(t *testing.T) {
185+
t.Parallel()
186+
187+
// Server that is immediately closed — all connections will be refused.
188+
dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
189+
deadURL := dead.URL
190+
dead.Close()
191+
192+
clientSID := uuid.New().String()
193+
sess := session.NewProxySession(clientSID)
194+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
195+
store := newStubStore(sess)
196+
197+
r := newRecovery(deadURL, store, http.DefaultTransport.RoundTrip)
198+
199+
req, err := http.NewRequest(http.MethodPost, deadURL+"/mcp",
200+
strings.NewReader(`{"method":"tools/list"}`))
201+
require.NoError(t, err)
202+
req.Header.Set("Mcp-Session-Id", clientSID)
203+
204+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
205+
assert.Nil(t, resp)
206+
assert.Error(t, err, "forward error during re-init should be returned")
207+
}
208+
209+
// TestBackendRecoveryNoNewSessionID verifies that when the re-initialize
210+
// response carries no Mcp-Session-Id, reinitializeAndReplay resets backend_url
211+
// to ClusterIP and returns (nil, nil).
212+
func TestBackendRecoveryNoNewSessionID(t *testing.T) {
213+
t.Parallel()
214+
215+
// Backend that returns no Mcp-Session-Id on initialize.
216+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
217+
w.WriteHeader(http.StatusOK) // no Mcp-Session-Id header
218+
}))
219+
defer backend.Close()
220+
221+
clientSID := uuid.New().String()
222+
sess := session.NewProxySession(clientSID)
223+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
224+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080")
225+
store := newStubStore(sess)
226+
227+
// targetURI points to backend (so the init request succeeds), but we verify
228+
// that backend_url is reset to targetURI when no session ID comes back.
229+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
230+
231+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
232+
strings.NewReader(`{"method":"tools/list"}`))
233+
require.NoError(t, err)
234+
req.Header.Set("Mcp-Session-Id", clientSID)
235+
236+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
237+
assert.Nil(t, resp)
238+
assert.NoError(t, err)
239+
240+
updated, ok := store.Get(clientSID)
241+
require.True(t, ok)
242+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
243+
require.True(t, exists)
244+
assert.Equal(t, backend.URL, backendURL, "backend_url should fall back to targetURI when no new session ID")
245+
}
246+
247+
// TestPodBackendURLWithCapturedAddr verifies that a captured pod IP replaces the
248+
// host in targetURI while preserving the scheme.
249+
func TestPodBackendURLWithCapturedAddr(t *testing.T) {
250+
t.Parallel()
251+
252+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
253+
got := r.podBackendURL("10.0.0.5:8080")
254+
assert.Equal(t, "http://10.0.0.5:8080", got)
255+
}
256+
257+
// TestPodBackendURLFallback verifies that an empty captured address falls back
258+
// to targetURI unchanged.
259+
func TestPodBackendURLFallback(t *testing.T) {
260+
t.Parallel()
261+
262+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
263+
got := r.podBackendURL("")
264+
assert.Equal(t, "http://cluster-ip:8080", got)
265+
}
266+
267+
// TestPodBackendURLHTTPSFallback verifies that an HTTPS targetURI is never
268+
// rewritten to a pod IP. IP-literal HTTPS URLs fail TLS verification because
269+
// server certificates are issued for hostnames, not pod IPs.
270+
func TestPodBackendURLHTTPSFallback(t *testing.T) {
271+
t.Parallel()
272+
273+
r := &backendRecovery{targetURI: "https://mcp.example.com/mcp"}
274+
got := r.podBackendURL("1.2.3.4:443")
275+
assert.Equal(t, "https://mcp.example.com/mcp", got,
276+
"HTTPS target must not be rewritten to a pod IP")
277+
}

pkg/transport/proxy/transparent/backend_routing_test.go

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ func TestRoundTripReturns404ForUnknownSession(t *testing.T) {
172172
}))
173173
defer backend.Close()
174174

175-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
175+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
176176
"localhost", 0, backend.URL,
177177
nil, nil, nil,
178178
false, false, "sse",
179179
nil, nil, "", false,
180180
nil,
181-
)}
181+
))
182182

183183
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
184184
strings.NewReader(`{"method":"tools/list"}`))
@@ -207,13 +207,13 @@ func TestRoundTripAllowsInitializeWithUnknownSession(t *testing.T) {
207207
}))
208208
defer backend.Close()
209209

210-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
210+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
211211
"localhost", 0, backend.URL,
212212
nil, nil, nil,
213213
false, false, "sse",
214214
nil, nil, "", false,
215215
nil,
216-
)}
216+
))
217217

218218
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
219219
strings.NewReader(`{"method":"initialize"}`))
@@ -239,13 +239,13 @@ func TestRoundTripAllowsBatchInitializeWithUnknownSession(t *testing.T) {
239239
}))
240240
defer backend.Close()
241241

242-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
242+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
243243
"localhost", 0, backend.URL,
244244
nil, nil, nil,
245245
false, false, "sse",
246246
nil, nil, "", false,
247247
nil,
248-
)}
248+
))
249249

250250
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
251251
strings.NewReader(`[{"method":"initialize"},{"method":"tools/list"}]`))
@@ -470,6 +470,63 @@ func TestRoundTripReinitializesPreservesNonUUIDBackendSessionID(t *testing.T) {
470470
assert.Equal(t, nonUUIDSessionID, receivedSIDs[1], "subsequent request via Rewrite must forward raw non-UUID session ID")
471471
}
472472

473+
// TestRoundTripReinitializesAfterPriorReinit verifies that re-initialization
474+
// triggers correctly on a second failure when the session already has a
475+
// backend_sid from a prior re-init. Without the clientSID capture fix,
476+
// RoundTrip rewrites the header to backend_sid before calling reinitializeAndReplay,
477+
// which then looks up the session by the (wrong) backend SID and finds nothing.
478+
func TestRoundTripReinitializesAfterPriorReinit(t *testing.T) {
479+
t.Parallel()
480+
481+
firstBackendSID := uuid.New().String()
482+
secondBackendSID := uuid.New().String()
483+
484+
// staleBackend: returns 404 to trigger re-init.
485+
staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
486+
w.WriteHeader(http.StatusNotFound)
487+
}))
488+
defer staleBackend.Close()
489+
490+
var freshHit atomic.Int32
491+
freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
492+
freshHit.Add(1)
493+
body, _ := io.ReadAll(r.Body)
494+
if strings.Contains(string(body), `"initialize"`) {
495+
w.Header().Set("Mcp-Session-Id", secondBackendSID)
496+
}
497+
w.WriteHeader(http.StatusOK)
498+
}))
499+
defer freshBackend.Close()
500+
501+
proxy, addr := startProxy(t, freshBackend.URL)
502+
503+
// Session pre-populated as if a prior re-init already happened:
504+
// backend_url points to staleBackend, backend_sid is set to firstBackendSID.
505+
clientSessionID := uuid.New().String()
506+
sess := session.NewProxySession(clientSessionID)
507+
sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL)
508+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
509+
sess.SetMetadata(sessionMetadataBackendSID, firstBackendSID)
510+
require.NoError(t, proxy.sessionManager.AddSession(sess))
511+
512+
ctx := context.Background()
513+
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
514+
"http://"+addr+"/mcp",
515+
strings.NewReader(`{"method":"tools/list"}`))
516+
require.NoError(t, err)
517+
req.Header.Set("Content-Type", "application/json")
518+
req.Header.Set("Mcp-Session-Id", clientSessionID)
519+
520+
resp, err := http.DefaultClient.Do(req)
521+
require.NoError(t, err)
522+
_ = resp.Body.Close()
523+
524+
assert.Equal(t, http.StatusOK, resp.StatusCode,
525+
"client should see 200: re-init must use client SID for session lookup, not backend SID")
526+
assert.GreaterOrEqual(t, freshHit.Load(), int32(2),
527+
"fresh backend should receive re-initialize + replay")
528+
}
529+
473530
// TestRoundTripReinitializesOnDialError verifies that when the proxy cannot reach
474531
// the stored pod IP (dial error — pod rescheduled to a new IP), it transparently
475532
// re-initializes the backend session via the ClusterIP and replays the original

0 commit comments

Comments
 (0)