Skip to content

Commit 8dda48d

Browse files
committed
Reimplement tracingTransport logic + add e2e tests
Extract re-initialization logic from tracingTransport into a dedicated backendRecovery type backed by a narrow recoverySessionStore interface (Get + UpsertSession) and a forward func. tracingTransport now owns only the request lifecycle (session guard, initialize detection, httptrace, session creation) and delegates all forwarding and recovery to backendRecovery. This makes reinitializeAndReplay and podBackendURL testable without standing up a full TransparentProxy: backend_recovery_test.go covers all recovery paths (no session, no init body, happy path, forward error, missing new session ID) using a stubSessionStore and inline httptest servers. tracingTransport.forward() and the base field are removed; all network I/O goes through recovery.forward — a single source of truth for the underlying transport. Also integrates the E2E acceptance test from #4574 that exercises backendReplicas=2 + proxy runner restart, verifying that sessions are routed to the correct backend pod after re-initialization.
1 parent 32960b1 commit 8dda48d

File tree

5 files changed

+514
-47
lines changed

5 files changed

+514
-47
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package transparent
5+
6+
import (
7+
"bytes"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"testing"
13+
14+
"github.com/google/uuid"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
18+
"github.com/stacklok/toolhive/pkg/transport/session"
19+
)
20+
21+
// stubSessionStore is a minimal in-memory recoverySessionStore for unit tests.
22+
type stubSessionStore struct {
23+
sessions map[string]session.Session
24+
}
25+
26+
func newStubStore(sessions ...session.Session) *stubSessionStore {
27+
m := make(map[string]session.Session)
28+
for _, s := range sessions {
29+
m[s.ID()] = s
30+
}
31+
return &stubSessionStore{sessions: m}
32+
}
33+
34+
func (s *stubSessionStore) Get(id string) (session.Session, bool) {
35+
sess, ok := s.sessions[id]
36+
return sess, ok
37+
}
38+
39+
func (s *stubSessionStore) UpsertSession(sess session.Session) error {
40+
s.sessions[sess.ID()] = sess
41+
return nil
42+
}
43+
44+
// newRecovery builds a backendRecovery backed by the given store and forward func.
45+
func newRecovery(targetURL string, store recoverySessionStore, fwd func(*http.Request) (*http.Response, error)) *backendRecovery {
46+
return &backendRecovery{
47+
targetURI: targetURL,
48+
forward: fwd,
49+
sessions: store,
50+
}
51+
}
52+
53+
// TestBackendRecoveryNoSession verifies that reinitializeAndReplay returns
54+
// (nil, nil) when the request carries no Mcp-Session-Id.
55+
func TestBackendRecoveryNoSession(t *testing.T) {
56+
t.Parallel()
57+
58+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
59+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
60+
strings.NewReader(`{"method":"tools/list"}`))
61+
require.NoError(t, err)
62+
63+
resp, err := r.reinitializeAndReplay(req, nil)
64+
assert.Nil(t, resp)
65+
assert.NoError(t, err)
66+
}
67+
68+
// TestBackendRecoveryUnknownSession verifies that reinitializeAndReplay returns
69+
// (nil, nil) when the session ID is not in the store.
70+
func TestBackendRecoveryUnknownSession(t *testing.T) {
71+
t.Parallel()
72+
73+
r := newRecovery("http://cluster-ip:8080", newStubStore(), nil)
74+
req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp",
75+
strings.NewReader(`{"method":"tools/list"}`))
76+
require.NoError(t, err)
77+
req.Header.Set("Mcp-Session-Id", uuid.New().String())
78+
79+
resp, err := r.reinitializeAndReplay(req, nil)
80+
assert.Nil(t, resp)
81+
assert.NoError(t, err)
82+
}
83+
84+
// TestBackendRecoveryNoInitBody verifies that when the session has no stored
85+
// init body, reinitializeAndReplay resets backend_url to the ClusterIP and
86+
// returns (nil, nil) so the caller falls through to a 404 the client can handle.
87+
func TestBackendRecoveryNoInitBody(t *testing.T) {
88+
t.Parallel()
89+
90+
const clusterIP = "http://cluster-ip:8080"
91+
clientSID := uuid.New().String()
92+
sess := session.NewProxySession(clientSID)
93+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080") // stale pod IP
94+
store := newStubStore(sess)
95+
96+
r := newRecovery(clusterIP, store, nil)
97+
req, err := http.NewRequest(http.MethodPost, clusterIP+"/mcp",
98+
strings.NewReader(`{"method":"tools/list"}`))
99+
require.NoError(t, err)
100+
req.Header.Set("Mcp-Session-Id", clientSID)
101+
102+
resp, err := r.reinitializeAndReplay(req, nil)
103+
assert.Nil(t, resp)
104+
assert.NoError(t, err)
105+
106+
// backend_url should be reset to ClusterIP so the next request routes correctly.
107+
updated, ok := store.Get(clientSID)
108+
require.True(t, ok)
109+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
110+
require.True(t, exists)
111+
assert.Equal(t, clusterIP, backendURL, "backend_url should be reset to ClusterIP when no init body")
112+
}
113+
114+
// TestBackendRecoveryHappyPath verifies the full re-init flow: the stored
115+
// initialize body is replayed to the ClusterIP, the new backend session ID is
116+
// captured, the session is updated, and the original request is replayed — all
117+
// without standing up a full TransparentProxy.
118+
func TestBackendRecoveryHappyPath(t *testing.T) {
119+
t.Parallel()
120+
121+
const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
122+
newBackendSID := uuid.New().String()
123+
var forwardCalls []string
124+
125+
// Backend: returns a session ID on initialize, 200 otherwise.
126+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
127+
body, _ := io.ReadAll(r.Body)
128+
forwardCalls = append(forwardCalls, r.Header.Get("Mcp-Session-Id"))
129+
if strings.Contains(string(body), `"initialize"`) {
130+
w.Header().Set("Mcp-Session-Id", newBackendSID)
131+
}
132+
w.WriteHeader(http.StatusOK)
133+
}))
134+
defer backend.Close()
135+
136+
clientSID := uuid.New().String()
137+
sess := session.NewProxySession(clientSID)
138+
sess.SetMetadata(sessionMetadataInitBody, initBody)
139+
store := newStubStore(sess)
140+
141+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
142+
143+
origBody := []byte(`{"method":"tools/list"}`)
144+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
145+
bytes.NewReader(origBody))
146+
require.NoError(t, err)
147+
req.Header.Set("Mcp-Session-Id", clientSID)
148+
req.Header.Set("Content-Type", "application/json")
149+
150+
resp, err := r.reinitializeAndReplay(req, origBody)
151+
require.NoError(t, err)
152+
require.NotNil(t, resp)
153+
assert.Equal(t, http.StatusOK, resp.StatusCode)
154+
_ = resp.Body.Close()
155+
156+
// Verify session was updated with new backend SID and a pod URL.
157+
updated, ok := store.Get(clientSID)
158+
require.True(t, ok)
159+
backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID)
160+
require.True(t, exists)
161+
assert.Equal(t, newBackendSID, backendSID)
162+
163+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
164+
require.True(t, exists)
165+
assert.NotEmpty(t, backendURL)
166+
167+
// Two forward calls: initialize + replay. The initialize must not carry
168+
// a session ID; the replay must carry the new backend SID.
169+
require.Len(t, forwardCalls, 2, "forward should be called for initialize and replay")
170+
assert.Empty(t, forwardCalls[0], "initialize request must not carry Mcp-Session-Id")
171+
assert.Equal(t, newBackendSID, forwardCalls[1], "replay must carry the new backend SID")
172+
}
173+
174+
// TestBackendRecoveryReinitForwardError verifies that a forward error during
175+
// re-initialization is returned to the caller.
176+
func TestBackendRecoveryReinitForwardError(t *testing.T) {
177+
t.Parallel()
178+
179+
// Server that is immediately closed — all connections will be refused.
180+
dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
181+
deadURL := dead.URL
182+
dead.Close()
183+
184+
clientSID := uuid.New().String()
185+
sess := session.NewProxySession(clientSID)
186+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
187+
store := newStubStore(sess)
188+
189+
r := newRecovery(deadURL, store, http.DefaultTransport.RoundTrip)
190+
191+
req, err := http.NewRequest(http.MethodPost, deadURL+"/mcp",
192+
strings.NewReader(`{"method":"tools/list"}`))
193+
require.NoError(t, err)
194+
req.Header.Set("Mcp-Session-Id", clientSID)
195+
196+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
197+
assert.Nil(t, resp)
198+
assert.Error(t, err, "forward error during re-init should be returned")
199+
}
200+
201+
// TestBackendRecoveryNoNewSessionID verifies that when the re-initialize
202+
// response carries no Mcp-Session-Id, reinitializeAndReplay resets backend_url
203+
// to ClusterIP and returns (nil, nil).
204+
func TestBackendRecoveryNoNewSessionID(t *testing.T) {
205+
t.Parallel()
206+
207+
// Backend that returns no Mcp-Session-Id on initialize.
208+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
209+
w.WriteHeader(http.StatusOK) // no Mcp-Session-Id header
210+
}))
211+
defer backend.Close()
212+
213+
clientSID := uuid.New().String()
214+
sess := session.NewProxySession(clientSID)
215+
sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
216+
sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080")
217+
store := newStubStore(sess)
218+
219+
// targetURI points to backend (so the init request succeeds), but we verify
220+
// that backend_url is reset to targetURI when no session ID comes back.
221+
r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip)
222+
223+
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
224+
strings.NewReader(`{"method":"tools/list"}`))
225+
require.NoError(t, err)
226+
req.Header.Set("Mcp-Session-Id", clientSID)
227+
228+
resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`))
229+
assert.Nil(t, resp)
230+
assert.NoError(t, err)
231+
232+
updated, ok := store.Get(clientSID)
233+
require.True(t, ok)
234+
backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL)
235+
require.True(t, exists)
236+
assert.Equal(t, backend.URL, backendURL, "backend_url should fall back to targetURI when no new session ID")
237+
}
238+
239+
// TestPodBackendURLWithCapturedAddr verifies that a captured pod IP replaces the
240+
// host in targetURI while preserving the scheme.
241+
func TestPodBackendURLWithCapturedAddr(t *testing.T) {
242+
t.Parallel()
243+
244+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
245+
got := r.podBackendURL("10.0.0.5:8080")
246+
assert.Equal(t, "http://10.0.0.5:8080", got)
247+
}
248+
249+
// TestPodBackendURLFallback verifies that an empty captured address falls back
250+
// to targetURI unchanged.
251+
func TestPodBackendURLFallback(t *testing.T) {
252+
t.Parallel()
253+
254+
r := &backendRecovery{targetURI: "http://cluster-ip:8080"}
255+
got := r.podBackendURL("")
256+
assert.Equal(t, "http://cluster-ip:8080", got)
257+
}

pkg/transport/proxy/transparent/backend_routing_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ func TestRoundTripReturns404ForUnknownSession(t *testing.T) {
172172
}))
173173
defer backend.Close()
174174

175-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
175+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
176176
"localhost", 0, backend.URL,
177177
nil, nil, nil,
178178
false, false, "sse",
179179
nil, nil, "", false,
180180
nil,
181-
)}
181+
))
182182

183183
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
184184
strings.NewReader(`{"method":"tools/list"}`))
@@ -207,13 +207,13 @@ func TestRoundTripAllowsInitializeWithUnknownSession(t *testing.T) {
207207
}))
208208
defer backend.Close()
209209

210-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
210+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
211211
"localhost", 0, backend.URL,
212212
nil, nil, nil,
213213
false, false, "sse",
214214
nil, nil, "", false,
215215
nil,
216-
)}
216+
))
217217

218218
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
219219
strings.NewReader(`{"method":"initialize"}`))
@@ -239,13 +239,13 @@ func TestRoundTripAllowsBatchInitializeWithUnknownSession(t *testing.T) {
239239
}))
240240
defer backend.Close()
241241

242-
tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions(
242+
tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions(
243243
"localhost", 0, backend.URL,
244244
nil, nil, nil,
245245
false, false, "sse",
246246
nil, nil, "", false,
247247
nil,
248-
)}
248+
))
249249

250250
req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp",
251251
strings.NewReader(`[{"method":"initialize"},{"method":"tools/list"}]`))

0 commit comments

Comments
 (0)