|
| 1 | +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package e2e_test |
| 5 | + |
| 6 | +import ( |
| 7 | + "encoding/json" |
| 8 | + "errors" |
| 9 | + "fmt" |
| 10 | + "io" |
| 11 | + "net" |
| 12 | + "net/http" |
| 13 | + "os" |
| 14 | + "os/exec" |
| 15 | + "strings" |
| 16 | + "sync/atomic" |
| 17 | + "time" |
| 18 | + |
| 19 | + . "github.com/onsi/ginkgo/v2" |
| 20 | + . "github.com/onsi/gomega" |
| 21 | + |
| 22 | + "github.com/stacklok/toolhive/test/e2e" |
| 23 | +) |
| 24 | + |
| 25 | +var _ = Describe("Stateless Proxy Mode", Label("proxy", "stateless", "streamable-http", "e2e"), Serial, func() { |
| 26 | + var ( |
| 27 | + config *e2e.TestConfig |
| 28 | + serverName string |
| 29 | + mockServer *statelessMockMCPServer |
| 30 | + ) |
| 31 | + |
| 32 | + BeforeEach(func() { |
| 33 | + config = e2e.NewTestConfig() |
| 34 | + serverName = e2e.GenerateUniqueServerName("stateless") |
| 35 | + |
| 36 | + err := e2e.CheckTHVBinaryAvailable(config) |
| 37 | + Expect(err).ToNot(HaveOccurred(), "thv binary should be available") |
| 38 | + }) |
| 39 | + |
| 40 | + AfterEach(func() { |
| 41 | + if mockServer != nil { |
| 42 | + mockServer.Stop() |
| 43 | + mockServer = nil |
| 44 | + } |
| 45 | + |
| 46 | + if config.CleanupAfter { |
| 47 | + err := e2e.StopAndRemoveMCPServer(config, serverName) |
| 48 | + Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server") |
| 49 | + } |
| 50 | + }) |
| 51 | + |
| 52 | + Describe("Method gating for stateless servers", func() { |
| 53 | + Context("when --stateless flag is set on a remote server", func() { |
| 54 | + It("should reject GET requests and forward POST requests", func() { |
| 55 | + By("Starting a stateless mock MCP server") |
| 56 | + var err error |
| 57 | + mockServer, err = newStatelessMockMCPServer() |
| 58 | + Expect(err).ToNot(HaveOccurred(), "Should be able to start mock server") |
| 59 | + |
| 60 | + mockServerURL := mockServer.URL() |
| 61 | + GinkgoWriter.Printf("Mock server started at: %s\n", mockServerURL) |
| 62 | + |
| 63 | + By("Starting thv with --stateless flag") |
| 64 | + thvCmd := exec.Command(config.THVBinary, "run", |
| 65 | + "--name", serverName, |
| 66 | + "--stateless", |
| 67 | + mockServerURL+"/mcp") |
| 68 | + thvCmd.Env = append(os.Environ(), |
| 69 | + "TOOLHIVE_REMOTE_HEALTHCHECKS=true", |
| 70 | + ) |
| 71 | + thvCmd.Stdout = GinkgoWriter |
| 72 | + thvCmd.Stderr = GinkgoWriter |
| 73 | + |
| 74 | + err = thvCmd.Start() |
| 75 | + Expect(err).ToNot(HaveOccurred(), "Should be able to start thv") |
| 76 | + |
| 77 | + thvPID := thvCmd.Process.Pid |
| 78 | + GinkgoWriter.Printf("thv process started with PID: %d\n", thvPID) |
| 79 | + |
| 80 | + defer func() { |
| 81 | + if proc, err := os.FindProcess(thvPID); err == nil { |
| 82 | + _ = proc.Kill() |
| 83 | + } |
| 84 | + }() |
| 85 | + |
| 86 | + By("Waiting for thv to register as running") |
| 87 | + err = e2e.WaitForMCPServer(config, serverName, 60*time.Second) |
| 88 | + Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds") |
| 89 | + |
| 90 | + By("Getting the proxy URL") |
| 91 | + proxyURL, err := e2e.GetMCPServerURL(config, serverName) |
| 92 | + Expect(err).ToNot(HaveOccurred(), "Should be able to get proxy URL") |
| 93 | + // Ensure URL has /mcp suffix |
| 94 | + if !strings.HasSuffix(proxyURL, "/mcp") { |
| 95 | + proxyURL += "/mcp" |
| 96 | + } |
| 97 | + GinkgoWriter.Printf("Proxy URL: %s\n", proxyURL) |
| 98 | + |
| 99 | + By("Verifying GET requests are rejected with 405") |
| 100 | + resp, err := http.Get(proxyURL) |
| 101 | + Expect(err).ToNot(HaveOccurred(), "Should be able to connect to proxy") |
| 102 | + resp.Body.Close() |
| 103 | + Expect(resp.StatusCode).To(Equal(http.StatusMethodNotAllowed), |
| 104 | + "GET request should be rejected with 405 Method Not Allowed") |
| 105 | + |
| 106 | + By("Verifying POST requests are forwarded successfully") |
| 107 | + initReq := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e-test","version":"1.0"}}}` |
| 108 | + postResp, err := http.Post(proxyURL, "application/json", strings.NewReader(initReq)) |
| 109 | + Expect(err).ToNot(HaveOccurred(), "Should be able to POST to proxy") |
| 110 | + defer postResp.Body.Close() |
| 111 | + |
| 112 | + Expect(postResp.StatusCode).To(Equal(http.StatusOK), |
| 113 | + "POST request should be forwarded and return 200") |
| 114 | + |
| 115 | + body, err := io.ReadAll(postResp.Body) |
| 116 | + Expect(err).ToNot(HaveOccurred(), "Should be able to read response body") |
| 117 | + |
| 118 | + var jsonRPC map[string]interface{} |
| 119 | + err = json.Unmarshal(body, &jsonRPC) |
| 120 | + Expect(err).ToNot(HaveOccurred(), "Response should be valid JSON-RPC") |
| 121 | + Expect(jsonRPC).To(HaveKey("result"), "Response should have a result field") |
| 122 | + |
| 123 | + By("Verifying DELETE requests are also rejected") |
| 124 | + delReq, err := http.NewRequest(http.MethodDelete, proxyURL, nil) |
| 125 | + Expect(err).ToNot(HaveOccurred()) |
| 126 | + delResp, err := http.DefaultClient.Do(delReq) |
| 127 | + Expect(err).ToNot(HaveOccurred(), "Should be able to send DELETE to proxy") |
| 128 | + delResp.Body.Close() |
| 129 | + Expect(delResp.StatusCode).To(Equal(http.StatusMethodNotAllowed), |
| 130 | + "DELETE request should be rejected with 405") |
| 131 | + |
| 132 | + By("Verifying the mock server received POST requests through the proxy") |
| 133 | + Expect(mockServer.GetCount()).To(BeNumerically(">", 0), |
| 134 | + "Mock server should have received at least one POST request") |
| 135 | + }) |
| 136 | + }) |
| 137 | + }) |
| 138 | +}) |
| 139 | + |
| 140 | +// statelessMockMCPServer is a minimal MCP server that only accepts POST. |
| 141 | +// It tracks whether any GET requests reached it (which would indicate |
| 142 | +// the proxy's method gate is not working). |
| 143 | +type statelessMockMCPServer struct { |
| 144 | + server *http.Server |
| 145 | + listener net.Listener |
| 146 | + port int |
| 147 | + gotGET atomic.Bool |
| 148 | + postHits atomic.Int32 |
| 149 | +} |
| 150 | + |
| 151 | +func newStatelessMockMCPServer() (*statelessMockMCPServer, error) { |
| 152 | + listener, err := net.Listen("tcp", "127.0.0.1:0") |
| 153 | + if err != nil { |
| 154 | + return nil, fmt.Errorf("failed to create listener: %w", err) |
| 155 | + } |
| 156 | + |
| 157 | + port := listener.Addr().(*net.TCPAddr).Port |
| 158 | + |
| 159 | + mock := &statelessMockMCPServer{ |
| 160 | + listener: listener, |
| 161 | + port: port, |
| 162 | + } |
| 163 | + |
| 164 | + mock.server = &http.Server{ |
| 165 | + Handler: http.HandlerFunc(mock.handleRequest), |
| 166 | + } |
| 167 | + |
| 168 | + go func() { |
| 169 | + if err := mock.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { |
| 170 | + GinkgoWriter.Printf("Stateless mock server error: %v\n", err) |
| 171 | + } |
| 172 | + }() |
| 173 | + |
| 174 | + time.Sleep(100 * time.Millisecond) |
| 175 | + |
| 176 | + return mock, nil |
| 177 | +} |
| 178 | + |
| 179 | +func (m *statelessMockMCPServer) handleRequest(w http.ResponseWriter, r *http.Request) { |
| 180 | + // Always return 404 for OAuth well-known URIs |
| 181 | + if strings.HasPrefix(r.URL.Path, "/.well-known/") { |
| 182 | + w.WriteHeader(http.StatusNotFound) |
| 183 | + return |
| 184 | + } |
| 185 | + |
| 186 | + if r.Method == http.MethodGet { |
| 187 | + m.gotGET.Store(true) |
| 188 | + // A real stateless server would reject GETs, but we accept them here |
| 189 | + // so the test can detect if any GETs leaked through the proxy. |
| 190 | + w.WriteHeader(http.StatusMethodNotAllowed) |
| 191 | + return |
| 192 | + } |
| 193 | + |
| 194 | + m.postHits.Add(1) |
| 195 | + |
| 196 | + // Parse the JSON-RPC request to return appropriate responses |
| 197 | + body, err := io.ReadAll(r.Body) |
| 198 | + if err != nil { |
| 199 | + w.WriteHeader(http.StatusBadRequest) |
| 200 | + return |
| 201 | + } |
| 202 | + |
| 203 | + var req map[string]interface{} |
| 204 | + if err := json.Unmarshal(body, &req); err != nil { |
| 205 | + w.WriteHeader(http.StatusBadRequest) |
| 206 | + return |
| 207 | + } |
| 208 | + |
| 209 | + method, _ := req["method"].(string) |
| 210 | + id := req["id"] |
| 211 | + |
| 212 | + w.Header().Set("Content-Type", "application/json") |
| 213 | + w.WriteHeader(http.StatusOK) |
| 214 | + |
| 215 | + switch method { |
| 216 | + case "initialize": |
| 217 | + resp := map[string]interface{}{ |
| 218 | + "jsonrpc": "2.0", |
| 219 | + "id": id, |
| 220 | + "result": map[string]interface{}{ |
| 221 | + "protocolVersion": "2024-11-05", |
| 222 | + "capabilities": map[string]interface{}{}, |
| 223 | + "serverInfo": map[string]interface{}{ |
| 224 | + "name": "stateless-mock", |
| 225 | + "version": "1.0.0", |
| 226 | + }, |
| 227 | + }, |
| 228 | + } |
| 229 | + _ = json.NewEncoder(w).Encode(resp) |
| 230 | + case "ping": |
| 231 | + resp := map[string]interface{}{ |
| 232 | + "jsonrpc": "2.0", |
| 233 | + "id": id, |
| 234 | + "result": map[string]interface{}{}, |
| 235 | + } |
| 236 | + _ = json.NewEncoder(w).Encode(resp) |
| 237 | + default: |
| 238 | + resp := map[string]interface{}{ |
| 239 | + "jsonrpc": "2.0", |
| 240 | + "id": id, |
| 241 | + "result": map[string]interface{}{}, |
| 242 | + } |
| 243 | + _ = json.NewEncoder(w).Encode(resp) |
| 244 | + } |
| 245 | +} |
| 246 | + |
| 247 | +func (m *statelessMockMCPServer) URL() string { |
| 248 | + return fmt.Sprintf("http://127.0.0.1:%d", m.port) |
| 249 | +} |
| 250 | + |
| 251 | +func (m *statelessMockMCPServer) Stop() { |
| 252 | + if m.server != nil { |
| 253 | + _ = m.server.Close() |
| 254 | + } |
| 255 | +} |
| 256 | + |
| 257 | +func (m *statelessMockMCPServer) GetCount() int32 { |
| 258 | + return m.postHits.Load() |
| 259 | +} |
| 260 | + |
| 261 | +func (m *statelessMockMCPServer) GotGET() bool { |
| 262 | + return m.gotGET.Load() |
| 263 | +} |
0 commit comments