Skip to content

Commit 830deff

Browse files
committed
Add E2E test for stateless proxy mode
Verify that --stateless makes the proxy reject GET and DELETE with 405 and forward POST to the upstream MCP server. Uses an in-process mock server to avoid external dependencies. Signed-off-by: Greg Katz <gkatz@indeed.com>
1 parent af744f6 commit 830deff

1 file changed

Lines changed: 263 additions & 0 deletions

File tree

test/e2e/stateless_proxy_test.go

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package e2e_test
5+
6+
import (
7+
"encoding/json"
8+
"errors"
9+
"fmt"
10+
"io"
11+
"net"
12+
"net/http"
13+
"os"
14+
"os/exec"
15+
"strings"
16+
"sync/atomic"
17+
"time"
18+
19+
. "github.com/onsi/ginkgo/v2"
20+
. "github.com/onsi/gomega"
21+
22+
"github.com/stacklok/toolhive/test/e2e"
23+
)
24+
25+
var _ = Describe("Stateless Proxy Mode", Label("proxy", "stateless", "streamable-http", "e2e"), Serial, func() {
26+
var (
27+
config *e2e.TestConfig
28+
serverName string
29+
mockServer *statelessMockMCPServer
30+
)
31+
32+
BeforeEach(func() {
33+
config = e2e.NewTestConfig()
34+
serverName = e2e.GenerateUniqueServerName("stateless")
35+
36+
err := e2e.CheckTHVBinaryAvailable(config)
37+
Expect(err).ToNot(HaveOccurred(), "thv binary should be available")
38+
})
39+
40+
AfterEach(func() {
41+
if mockServer != nil {
42+
mockServer.Stop()
43+
mockServer = nil
44+
}
45+
46+
if config.CleanupAfter {
47+
err := e2e.StopAndRemoveMCPServer(config, serverName)
48+
Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server")
49+
}
50+
})
51+
52+
Describe("Method gating for stateless servers", func() {
53+
Context("when --stateless flag is set on a remote server", func() {
54+
It("should reject GET requests and forward POST requests", func() {
55+
By("Starting a stateless mock MCP server")
56+
var err error
57+
mockServer, err = newStatelessMockMCPServer()
58+
Expect(err).ToNot(HaveOccurred(), "Should be able to start mock server")
59+
60+
mockServerURL := mockServer.URL()
61+
GinkgoWriter.Printf("Mock server started at: %s\n", mockServerURL)
62+
63+
By("Starting thv with --stateless flag")
64+
thvCmd := exec.Command(config.THVBinary, "run",
65+
"--name", serverName,
66+
"--stateless",
67+
mockServerURL+"/mcp")
68+
thvCmd.Env = append(os.Environ(),
69+
"TOOLHIVE_REMOTE_HEALTHCHECKS=true",
70+
)
71+
thvCmd.Stdout = GinkgoWriter
72+
thvCmd.Stderr = GinkgoWriter
73+
74+
err = thvCmd.Start()
75+
Expect(err).ToNot(HaveOccurred(), "Should be able to start thv")
76+
77+
thvPID := thvCmd.Process.Pid
78+
GinkgoWriter.Printf("thv process started with PID: %d\n", thvPID)
79+
80+
defer func() {
81+
if proc, err := os.FindProcess(thvPID); err == nil {
82+
_ = proc.Kill()
83+
}
84+
}()
85+
86+
By("Waiting for thv to register as running")
87+
err = e2e.WaitForMCPServer(config, serverName, 60*time.Second)
88+
Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds")
89+
90+
By("Getting the proxy URL")
91+
proxyURL, err := e2e.GetMCPServerURL(config, serverName)
92+
Expect(err).ToNot(HaveOccurred(), "Should be able to get proxy URL")
93+
// Ensure URL has /mcp suffix
94+
if !strings.HasSuffix(proxyURL, "/mcp") {
95+
proxyURL += "/mcp"
96+
}
97+
GinkgoWriter.Printf("Proxy URL: %s\n", proxyURL)
98+
99+
By("Verifying GET requests are rejected with 405")
100+
resp, err := http.Get(proxyURL)
101+
Expect(err).ToNot(HaveOccurred(), "Should be able to connect to proxy")
102+
resp.Body.Close()
103+
Expect(resp.StatusCode).To(Equal(http.StatusMethodNotAllowed),
104+
"GET request should be rejected with 405 Method Not Allowed")
105+
106+
By("Verifying POST requests are forwarded successfully")
107+
initReq := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e-test","version":"1.0"}}}`
108+
postResp, err := http.Post(proxyURL, "application/json", strings.NewReader(initReq))
109+
Expect(err).ToNot(HaveOccurred(), "Should be able to POST to proxy")
110+
defer postResp.Body.Close()
111+
112+
Expect(postResp.StatusCode).To(Equal(http.StatusOK),
113+
"POST request should be forwarded and return 200")
114+
115+
body, err := io.ReadAll(postResp.Body)
116+
Expect(err).ToNot(HaveOccurred(), "Should be able to read response body")
117+
118+
var jsonRPC map[string]interface{}
119+
err = json.Unmarshal(body, &jsonRPC)
120+
Expect(err).ToNot(HaveOccurred(), "Response should be valid JSON-RPC")
121+
Expect(jsonRPC).To(HaveKey("result"), "Response should have a result field")
122+
123+
By("Verifying DELETE requests are also rejected")
124+
delReq, err := http.NewRequest(http.MethodDelete, proxyURL, nil)
125+
Expect(err).ToNot(HaveOccurred())
126+
delResp, err := http.DefaultClient.Do(delReq)
127+
Expect(err).ToNot(HaveOccurred(), "Should be able to send DELETE to proxy")
128+
delResp.Body.Close()
129+
Expect(delResp.StatusCode).To(Equal(http.StatusMethodNotAllowed),
130+
"DELETE request should be rejected with 405")
131+
132+
By("Verifying the mock server received POST requests through the proxy")
133+
Expect(mockServer.GetCount()).To(BeNumerically(">", 0),
134+
"Mock server should have received at least one POST request")
135+
})
136+
})
137+
})
138+
})
139+
140+
// statelessMockMCPServer is a minimal MCP server that only accepts POST.
141+
// It tracks whether any GET requests reached it (which would indicate
142+
// the proxy's method gate is not working).
143+
type statelessMockMCPServer struct {
144+
server *http.Server
145+
listener net.Listener
146+
port int
147+
gotGET atomic.Bool
148+
postHits atomic.Int32
149+
}
150+
151+
func newStatelessMockMCPServer() (*statelessMockMCPServer, error) {
152+
listener, err := net.Listen("tcp", "127.0.0.1:0")
153+
if err != nil {
154+
return nil, fmt.Errorf("failed to create listener: %w", err)
155+
}
156+
157+
port := listener.Addr().(*net.TCPAddr).Port
158+
159+
mock := &statelessMockMCPServer{
160+
listener: listener,
161+
port: port,
162+
}
163+
164+
mock.server = &http.Server{
165+
Handler: http.HandlerFunc(mock.handleRequest),
166+
}
167+
168+
go func() {
169+
if err := mock.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
170+
GinkgoWriter.Printf("Stateless mock server error: %v\n", err)
171+
}
172+
}()
173+
174+
time.Sleep(100 * time.Millisecond)
175+
176+
return mock, nil
177+
}
178+
179+
func (m *statelessMockMCPServer) handleRequest(w http.ResponseWriter, r *http.Request) {
180+
// Always return 404 for OAuth well-known URIs
181+
if strings.HasPrefix(r.URL.Path, "/.well-known/") {
182+
w.WriteHeader(http.StatusNotFound)
183+
return
184+
}
185+
186+
if r.Method == http.MethodGet {
187+
m.gotGET.Store(true)
188+
// A real stateless server would reject GETs, but we accept them here
189+
// so the test can detect if any GETs leaked through the proxy.
190+
w.WriteHeader(http.StatusMethodNotAllowed)
191+
return
192+
}
193+
194+
m.postHits.Add(1)
195+
196+
// Parse the JSON-RPC request to return appropriate responses
197+
body, err := io.ReadAll(r.Body)
198+
if err != nil {
199+
w.WriteHeader(http.StatusBadRequest)
200+
return
201+
}
202+
203+
var req map[string]interface{}
204+
if err := json.Unmarshal(body, &req); err != nil {
205+
w.WriteHeader(http.StatusBadRequest)
206+
return
207+
}
208+
209+
method, _ := req["method"].(string)
210+
id := req["id"]
211+
212+
w.Header().Set("Content-Type", "application/json")
213+
w.WriteHeader(http.StatusOK)
214+
215+
switch method {
216+
case "initialize":
217+
resp := map[string]interface{}{
218+
"jsonrpc": "2.0",
219+
"id": id,
220+
"result": map[string]interface{}{
221+
"protocolVersion": "2024-11-05",
222+
"capabilities": map[string]interface{}{},
223+
"serverInfo": map[string]interface{}{
224+
"name": "stateless-mock",
225+
"version": "1.0.0",
226+
},
227+
},
228+
}
229+
_ = json.NewEncoder(w).Encode(resp)
230+
case "ping":
231+
resp := map[string]interface{}{
232+
"jsonrpc": "2.0",
233+
"id": id,
234+
"result": map[string]interface{}{},
235+
}
236+
_ = json.NewEncoder(w).Encode(resp)
237+
default:
238+
resp := map[string]interface{}{
239+
"jsonrpc": "2.0",
240+
"id": id,
241+
"result": map[string]interface{}{},
242+
}
243+
_ = json.NewEncoder(w).Encode(resp)
244+
}
245+
}
246+
247+
func (m *statelessMockMCPServer) URL() string {
248+
return fmt.Sprintf("http://127.0.0.1:%d", m.port)
249+
}
250+
251+
func (m *statelessMockMCPServer) Stop() {
252+
if m.server != nil {
253+
_ = m.server.Close()
254+
}
255+
}
256+
257+
func (m *statelessMockMCPServer) GetCount() int32 {
258+
return m.postHits.Load()
259+
}
260+
261+
func (m *statelessMockMCPServer) GotGET() bool {
262+
return m.gotGET.Load()
263+
}

0 commit comments

Comments
 (0)