diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 7aef54bfcc..3e3560306b 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -58,6 +58,9 @@ type RunFlags struct { // Remote MCP server support RemoteURL string + // Stateless indicates the server is stateless (POST-only, no SSE) + Stateless bool + // Security and audit AuthzConfig string AuditConfig string @@ -253,6 +256,9 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { cmd.Flags().BoolVar(&config.TrustProxyHeaders, "trust-proxy-headers", false, "Trust X-Forwarded-* headers from reverse proxies (X-Forwarded-Proto, X-Forwarded-Host, X-Forwarded-Port, X-Forwarded-Prefix) "+ "(default false)") + cmd.Flags().BoolVar(&config.Stateless, "stateless", false, + "Declare the server as stateless (POST-only, no SSE). "+ + "Use for MCP servers implementing streamable-HTTP stateless mode.") cmd.Flags().StringVar(&config.EndpointPrefix, "endpoint-prefix", "", "Path prefix to prepend to SSE endpoint URLs (e.g., /playwright)") cmd.Flags().StringVar(&config.Network, "network", "", @@ -611,6 +617,7 @@ func buildRunnerConfig( runner.WithNetworkIsolation(runFlags.IsolateNetwork), runner.WithAllowDockerGateway(runFlags.AllowDockerGateway), runner.WithTrustProxyHeaders(runFlags.TrustProxyHeaders), + runner.WithStateless(runFlags.Stateless), runner.WithEndpointPrefix(runFlags.EndpointPrefix), runner.WithNetworkMode(runFlags.Network), runner.WithK8sPodPatch(runFlags.K8sPodPatch), diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index d0e94e015b..7d9408283a 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -177,6 +177,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] --runtime-add-package stringArray Add additional packages to install in the builder and runtime stages (can be repeated) --runtime-image string Override the default base image for protocol schemes (e.g., golang:1.24-alpine, node:20-alpine, python:3.11-slim) --secret stringArray Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET) + --stateless Declare the server as stateless (POST-only, no SSE). Use for MCP servers implementing streamable-HTTP stateless mode. --target-host string Host to forward traffic to (only applicable to SSE or Streamable HTTP transport) (default "127.0.0.1") --target-port int Port for the container to expose (only applicable to SSE or Streamable HTTP transport) --thv-ca-bundle string Path to CA certificate bundle for ToolHive HTTP operations (JWKS, OIDC discovery, etc.) diff --git a/docs/server/docs.go b/docs/server/docs.go index 0d84615d0a..04cbdc5e11 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1283,6 +1283,10 @@ const docTemplate = `{ "type": "array", "uniqueItems": false }, + "stateless": { + "description": "Stateless indicates the server only supports POST (no SSE/GET).\nWhen true, the proxy returns 405 for incoming GET requests and uses a\nPOST-based health check instead of the default GET probe.\nApplies to both remote URLs and local container workloads.", + "type": "boolean" + }, "target_host": { "description": "TargetHost is the host to forward traffic to (only applicable to SSE transport)", "type": "string" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 131683db86..dfc5051e02 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1276,6 +1276,10 @@ "type": "array", "uniqueItems": false }, + "stateless": { + "description": "Stateless indicates the server only supports POST (no SSE/GET).\nWhen true, the proxy returns 405 for incoming GET requests and uses a\nPOST-based health check instead of the default GET probe.\nApplies to both remote URLs and local container workloads.", + "type": "boolean" + }, "target_host": { "description": "TargetHost is the host to forward traffic to (only applicable to SSE transport)", "type": "string" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 42cd734008..08fc84d1ba 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1210,6 +1210,13 @@ components: type: string type: array uniqueItems: false + stateless: + description: |- + Stateless indicates the server only supports POST (no SSE/GET). + When true, the proxy returns 405 for incoming GET requests and uses a + POST-based health check instead of the default GET probe. + Applies to both remote URLs and local container workloads. + type: boolean target_host: description: TargetHost is the host to forward traffic to (only applicable to SSE transport) diff --git a/pkg/runner/config.go b/pkg/runner/config.go index c0f651f04d..519f277abd 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -191,6 +191,12 @@ type RunConfig struct { // TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies TrustProxyHeaders bool `json:"trust_proxy_headers,omitempty" yaml:"trust_proxy_headers,omitempty"` + // Stateless indicates the server only supports POST (no SSE/GET). + // When true, the proxy returns 405 for incoming GET requests and uses a + // POST-based health check instead of the default GET probe. + // Applies to both remote URLs and local container workloads. + Stateless bool `json:"stateless,omitempty" yaml:"stateless,omitempty"` + // ProxyMode is the effective HTTP protocol the proxy uses. // For stdio transports, this is the configured mode (sse or streamable-http). // For direct transports (sse/streamable-http), this matches the transport type. diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 2efd933634..195398b65a 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -319,6 +319,14 @@ func WithTrustProxyHeaders(trust bool) RunConfigBuilderOption { } } +// WithStateless declares the server is stateless (POST-only, no SSE). +func WithStateless(stateless bool) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + b.config.Stateless = stateless + return nil + } +} + // WithEndpointPrefix sets the path prefix for SSE endpoint URLs func WithEndpointPrefix(prefix string) RunConfigBuilderOption { return func(b *runConfigBuilder) error { diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 42be99f7fb..aab3a8cbcc 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -443,6 +443,17 @@ func (r *Runner) Run(ctx context.Context) error { }) } + // Configure stateless mode if requested. Stateless mode applies to any + // streamable-HTTP server (remote or local container) where the upstream + // only accepts POST and does not support SSE-based sessions. + if r.Config.Stateless { + httpT, ok := transportHandler.(*transport.HTTPTransport) + if !ok { + return fmt.Errorf("--stateless requires streamable-HTTP or SSE transport, got %T", transportHandler) + } + httpT.SetStateless(true) + } + // Start the transport (which also starts the container and monitoring) slog.Debug("starting transport", "transport", r.Config.Transport, "container", r.Config.ContainerName) if err := transportHandler.Start(ctx); err != nil { diff --git a/pkg/transport/http.go b/pkg/transport/http.go index a9c81bd343..53b9af6d12 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -58,6 +58,9 @@ type HTTPTransport struct { // Remote MCP server support remoteURL string + // stateless indicates the server is POST-only (no SSE/GET support) + stateless bool + // tokenSource is the OAuth token source for remote authentication tokenSource oauth2.TokenSource @@ -152,6 +155,11 @@ func (t *HTTPTransport) SetOnHealthCheckFailed(callback types.HealthCheckFailedC t.onHealthCheckFailed = callback } +// SetStateless configures the transport for a stateless server. +func (t *HTTPTransport) SetStateless(stateless bool) { + t.stateless = stateless +} + // SetOnUnauthorizedResponse sets the callback for 401 Unauthorized responses // The callback is wrapped to check the unauthorized flag to prevent repeated status updates func (t *HTTPTransport) SetOnUnauthorizedResponse(callback types.UnauthorizedResponseCallback) { @@ -321,14 +329,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error { enableHealthCheck := shouldEnableHealthCheck(isRemote) // Build proxy options - var proxyOptions []transparent.Option - if remoteBasePath != "" { - proxyOptions = append(proxyOptions, transparent.WithRemoteBasePath(remoteBasePath)) - } - proxyOptions = append(proxyOptions, transparent.WithRemoteRawQuery(remoteRawQuery)) - if t.sessionStorage != nil { - proxyOptions = append(proxyOptions, transparent.WithSessionStorage(t.sessionStorage)) - } + proxyOptions := t.buildProxyOptions(remoteBasePath, remoteRawQuery) // Create the transparent proxy t.proxy = transparent.NewTransparentProxyWithOptions( @@ -421,6 +422,22 @@ func (t *HTTPTransport) Stop(ctx context.Context) error { return nil } +// buildProxyOptions constructs the transparent proxy options for this transport. +func (t *HTTPTransport) buildProxyOptions(remoteBasePath, remoteRawQuery string) []transparent.Option { + var opts []transparent.Option + if remoteBasePath != "" { + opts = append(opts, transparent.WithRemoteBasePath(remoteBasePath)) + } + opts = append(opts, transparent.WithRemoteRawQuery(remoteRawQuery)) + if t.stateless { + opts = append(opts, transparent.WithStateless()) + } + if t.sessionStorage != nil { + opts = append(opts, transparent.WithSessionStorage(t.sessionStorage)) + } + return opts +} + // handleContainerExit handles container exit events. // It loops to support reconnecting the monitor when a container is restarted // by Docker (e.g., via restart policy) rather than truly exiting. diff --git a/pkg/transport/proxy/transparent/method_gate_test.go b/pkg/transport/proxy/transparent/method_gate_test.go new file mode 100644 index 0000000000..9bfc2e35c5 --- /dev/null +++ b/pkg/transport/proxy/transparent/method_gate_test.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transparent + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStatelessMethodGate(t *testing.T) { + t.Parallel() + + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + method string + expectedStatus int + expectAllow bool + }{ + { + name: "GET returns 405 with Allow header", + method: http.MethodGet, + expectedStatus: http.StatusMethodNotAllowed, + expectAllow: true, + }, + { + name: "HEAD returns 405 with Allow header", + method: http.MethodHead, + expectedStatus: http.StatusMethodNotAllowed, + expectAllow: true, + }, + { + name: "DELETE returns 405 with Allow header", + method: http.MethodDelete, + expectedStatus: http.StatusMethodNotAllowed, + expectAllow: true, + }, + { + name: "POST is forwarded", + method: http.MethodPost, + expectedStatus: http.StatusOK, + }, + { + name: "PUT is forwarded", + method: http.MethodPut, + expectedStatus: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler := statelessMethodGate(inner) + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.method, "/", nil) + + handler.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Code) + if tc.expectAllow { + assert.Equal(t, "POST, OPTIONS", rec.Header().Get("Allow")) + } + }) + } +} diff --git a/pkg/transport/proxy/transparent/pinger.go b/pkg/transport/proxy/transparent/pinger.go index 008024dc8c..9144bff464 100644 --- a/pkg/transport/proxy/transparent/pinger.go +++ b/pkg/transport/proxy/transparent/pinger.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/http" + "strings" "time" "github.com/stacklok/toolhive/pkg/healthcheck" @@ -87,3 +88,68 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) { return duration, fmt.Errorf("SSE server health check failed with status %d", resp.StatusCode) } + +// StatelessMCPPinger health-checks stateless streamable-HTTP servers via POST ping. +// Stateless servers don't support GET, so we send a minimal JSON-RPC ping instead. +type StatelessMCPPinger struct { + targetURL string + client *http.Client +} + +// NewStatelessMCPPinger creates a pinger for stateless streamable-HTTP servers. +func NewStatelessMCPPinger(targetURL string) healthcheck.MCPPinger { + return NewStatelessMCPPingerWithTimeout(targetURL, DefaultPingerTimeout) +} + +// NewStatelessMCPPingerWithTimeout creates a stateless pinger with a custom timeout. +func NewStatelessMCPPingerWithTimeout(targetURL string, timeout time.Duration) healthcheck.MCPPinger { + if timeout <= 0 { + timeout = DefaultPingerTimeout + } + return &StatelessMCPPinger{ + targetURL: targetURL, + client: &http.Client{ + Timeout: timeout, + }, + } +} + +// Ping sends a JSON-RPC ping POST to check if the stateless server is reachable. +// Accepts any 2xx-4xx response as healthy; only network errors and 5xx indicate failure. +func (p *StatelessMCPPinger) Ping(ctx context.Context) (time.Duration, error) { + start := time.Now() + + body := `{"jsonrpc":"2.0","id":0,"method":"ping","params":{}}` + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.targetURL, strings.NewReader(body)) + if err != nil { + return 0, fmt.Errorf("failed to create stateless ping request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + //nolint:gosec // G706: logging target URL from config + slog.Debug("checking stateless MCP server health via POST ping", "target", p.targetURL) + + resp, err := p.client.Do(req) + if err != nil { + return time.Since(start), fmt.Errorf("stateless ping failed to connect: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Debug("failed to close ping response body", "error", err) + } + }() + + duration := time.Since(start) + + // Accept 2xx-4xx: even 401/403 means the server is reachable. + // Only 5xx or network errors indicate the server is down. + if resp.StatusCode >= 200 && resp.StatusCode < 500 { + //nolint:gosec // G706: logging HTTP status code from health check response + slog.Debug("stateless MCP server health check successful", + "duration", duration, "status", resp.StatusCode) + return duration, nil + } + + return duration, fmt.Errorf("stateless ping returned status %d", resp.StatusCode) +} diff --git a/pkg/transport/proxy/transparent/pinger_test.go b/pkg/transport/proxy/transparent/pinger_test.go new file mode 100644 index 0000000000..1c675c3f5a --- /dev/null +++ b/pkg/transport/proxy/transparent/pinger_test.go @@ -0,0 +1,154 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transparent + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatelessMCPPinger_Ping(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverFunc func(w http.ResponseWriter, r *http.Request) + wantErr bool + wantHealthy bool // true = nil error, positive duration + }{ + { + name: "200 OK is healthy", + serverFunc: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }, + wantErr: false, + wantHealthy: true, + }, + { + name: "401 unauthorized is treated as healthy", + serverFunc: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }, + wantErr: false, + wantHealthy: true, + }, + { + name: "403 forbidden is treated as healthy", + serverFunc: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + }, + wantErr: false, + wantHealthy: true, + }, + { + name: "500 server error returns an error", + serverFunc: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(tc.serverFunc)) + defer srv.Close() + + pinger := NewStatelessMCPPinger(srv.URL) + duration, err := pinger.Ping(context.Background()) + + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Positive(t, duration, "duration should be positive on success") + }) + } +} + +func TestStatelessMCPPinger_Ping_ConnectionRefused(t *testing.T) { + t.Parallel() + + // Point at a port where nothing is listening. Use a server, start it, + // close it immediately so the port is definitely not in use. + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + addr := srv.URL + srv.Close() + + pinger := NewStatelessMCPPingerWithTimeout(addr, 2*time.Second) + _, err := pinger.Ping(context.Background()) + require.Error(t, err, "should return error when connection is refused") +} + +func TestStatelessMCPPinger_Ping_UsesPost(t *testing.T) { + t.Parallel() + + var receivedMethod string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + pinger := NewStatelessMCPPinger(srv.URL) + _, err := pinger.Ping(context.Background()) + require.NoError(t, err) + + assert.Equal(t, http.MethodPost, receivedMethod, "pinger should use POST method") +} + +func TestStatelessMCPPinger_Ping_SendsJsonBody(t *testing.T) { + t.Parallel() + + var body map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + err = json.Unmarshal(raw, &body) + require.NoError(t, err) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + pinger := NewStatelessMCPPinger(srv.URL) + _, err := pinger.Ping(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "2.0", body["jsonrpc"], "body should contain jsonrpc field") + assert.Equal(t, "ping", body["method"], "body should contain method field") + _, hasID := body["id"] + assert.True(t, hasID, "body should contain id field") +} + +func TestNewStatelessMCPPingerWithTimeout_ZeroTimeout(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // Zero timeout should be replaced by DefaultPingerTimeout — the pinger + // must still work (i.e., not time out immediately on a live server). + pinger := NewStatelessMCPPingerWithTimeout(srv.URL, 0) + _, err := pinger.Ping(context.Background()) + require.NoError(t, err, "pinger with zero timeout should default to DefaultPingerTimeout and succeed") + + // Verify the underlying client has the default timeout set. + sp, ok := pinger.(*StatelessMCPPinger) + require.True(t, ok, "pinger should be *StatelessMCPPinger") + assert.Equal(t, DefaultPingerTimeout, sp.client.Timeout) +} diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index f60f32798c..05d6bc6cae 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -89,6 +89,9 @@ type TransparentProxy struct { // Transport type (sse, streamable-http) transportType string + // stateless indicates the server is POST-only (no SSE/GET support) + stateless bool + // Callback when health check fails (for remote servers) onHealthCheckFailed types.HealthCheckFailedCallback @@ -213,6 +216,15 @@ func WithRemoteRawQuery(rawQuery string) Option { } } +// WithStateless configures the proxy for stateless streamable-HTTP servers. +// In stateless mode, incoming GET and DELETE requests receive 405 Method Not Allowed +// instead of being forwarded, and health checks use POST ping instead of GET. +func WithStateless() Option { + return func(p *TransparentProxy) { + p.stateless = true + } +} + // withHealthCheckPingTimeout sets the health check ping timeout. // This is primarily useful for testing with shorter timeouts. // Ignores non-positive timeouts; default will be used. @@ -365,7 +377,11 @@ func NewTransparentProxyWithOptions( if pingTimeout == 0 { pingTimeout = DefaultPingerTimeout } - mcpPinger = NewMCPPingerWithTimeout(targetURI, pingTimeout) + if proxy.stateless { + mcpPinger = NewStatelessMCPPingerWithTimeout(targetURI, pingTimeout) + } else { + mcpPinger = NewMCPPingerWithTimeout(targetURI, pingTimeout) + } } proxy.healthChecker = healthcheck.NewHealthChecker(transportType, mcpPinger) @@ -934,7 +950,11 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // 5. Catch-all proxy handler (least specific - ServeMux routing handles precedence) // Note: No manual path checking needed - ServeMux longest-match routing ensures - // more specific paths registered above take precedence over this catch-all + // more specific paths registered above take precedence over this catch-all. + // In stateless mode, wrap with a method gate that rejects GET/DELETE with 405. + if p.stateless { + finalHandler = statelessMethodGate(finalHandler) + } mux.Handle("/", finalHandler) // Use ListenConfig with SO_REUSEADDR to allow port reuse after unclean shutdown @@ -1207,3 +1227,18 @@ func (*TransparentProxy) SendMessageToDestination(_ jsonrpc2.Message) error { func (*TransparentProxy) ForwardResponseToClients(_ context.Context, _ jsonrpc2.Message) error { return fmt.Errorf("ForwardResponseToClients not implemented for TransparentProxy") } + +// statelessMethodGate wraps a handler to reject GET, HEAD, and DELETE requests with 405. +// Used in stateless mode where the server only supports POST. +// HEAD is blocked alongside GET because HEAD is semantically a GET without a response body; +// a server that cannot handle GET will not handle HEAD either. +func statelessMethodGate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodDelete { + w.Header().Set("Allow", "POST, OPTIONS") + http.Error(w, "method not allowed: server is stateless (POST only)", http.StatusMethodNotAllowed) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/test/e2e/stateless_proxy_test.go b/test/e2e/stateless_proxy_test.go new file mode 100644 index 0000000000..3b9c64baf8 --- /dev/null +++ b/test/e2e/stateless_proxy_test.go @@ -0,0 +1,263 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "strings" + "sync/atomic" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/stacklok/toolhive/test/e2e" +) + +var _ = Describe("Stateless Proxy Mode", Label("proxy", "stateless", "streamable-http", "e2e"), Serial, func() { + var ( + config *e2e.TestConfig + serverName string + mockServer *statelessMockMCPServer + ) + + BeforeEach(func() { + config = e2e.NewTestConfig() + serverName = e2e.GenerateUniqueServerName("stateless") + + err := e2e.CheckTHVBinaryAvailable(config) + Expect(err).ToNot(HaveOccurred(), "thv binary should be available") + }) + + AfterEach(func() { + if mockServer != nil { + mockServer.Stop() + mockServer = nil + } + + if config.CleanupAfter { + err := e2e.StopAndRemoveMCPServer(config, serverName) + Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server") + } + }) + + Describe("Method gating for stateless servers", func() { + Context("when --stateless flag is set on a remote server", func() { + It("should reject GET requests and forward POST requests", func() { + By("Starting a stateless mock MCP server") + var err error + mockServer, err = newStatelessMockMCPServer() + Expect(err).ToNot(HaveOccurred(), "Should be able to start mock server") + + mockServerURL := mockServer.URL() + GinkgoWriter.Printf("Mock server started at: %s\n", mockServerURL) + + By("Starting thv with --stateless flag") + thvCmd := exec.Command(config.THVBinary, "run", + "--name", serverName, + "--stateless", + mockServerURL+"/mcp") + thvCmd.Env = append(os.Environ(), + "TOOLHIVE_REMOTE_HEALTHCHECKS=true", + ) + thvCmd.Stdout = GinkgoWriter + thvCmd.Stderr = GinkgoWriter + + err = thvCmd.Start() + Expect(err).ToNot(HaveOccurred(), "Should be able to start thv") + + thvPID := thvCmd.Process.Pid + GinkgoWriter.Printf("thv process started with PID: %d\n", thvPID) + + defer func() { + if proc, err := os.FindProcess(thvPID); err == nil { + _ = proc.Kill() + } + }() + + By("Waiting for thv to register as running") + err = e2e.WaitForMCPServer(config, serverName, 60*time.Second) + Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds") + + By("Getting the proxy URL") + proxyURL, err := e2e.GetMCPServerURL(config, serverName) + Expect(err).ToNot(HaveOccurred(), "Should be able to get proxy URL") + // Ensure URL has /mcp suffix + if !strings.HasSuffix(proxyURL, "/mcp") { + proxyURL += "/mcp" + } + GinkgoWriter.Printf("Proxy URL: %s\n", proxyURL) + + By("Verifying GET requests are rejected with 405") + resp, err := http.Get(proxyURL) + Expect(err).ToNot(HaveOccurred(), "Should be able to connect to proxy") + resp.Body.Close() + Expect(resp.StatusCode).To(Equal(http.StatusMethodNotAllowed), + "GET request should be rejected with 405 Method Not Allowed") + + By("Verifying POST requests are forwarded successfully") + initReq := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e-test","version":"1.0"}}}` + postResp, err := http.Post(proxyURL, "application/json", strings.NewReader(initReq)) + Expect(err).ToNot(HaveOccurred(), "Should be able to POST to proxy") + defer postResp.Body.Close() + + Expect(postResp.StatusCode).To(Equal(http.StatusOK), + "POST request should be forwarded and return 200") + + body, err := io.ReadAll(postResp.Body) + Expect(err).ToNot(HaveOccurred(), "Should be able to read response body") + + var jsonRPC map[string]interface{} + err = json.Unmarshal(body, &jsonRPC) + Expect(err).ToNot(HaveOccurred(), "Response should be valid JSON-RPC") + Expect(jsonRPC).To(HaveKey("result"), "Response should have a result field") + + By("Verifying DELETE requests are also rejected") + delReq, err := http.NewRequest(http.MethodDelete, proxyURL, nil) + Expect(err).ToNot(HaveOccurred()) + delResp, err := http.DefaultClient.Do(delReq) + Expect(err).ToNot(HaveOccurred(), "Should be able to send DELETE to proxy") + delResp.Body.Close() + Expect(delResp.StatusCode).To(Equal(http.StatusMethodNotAllowed), + "DELETE request should be rejected with 405") + + By("Verifying the mock server received POST requests through the proxy") + Expect(mockServer.GetCount()).To(BeNumerically(">", 0), + "Mock server should have received at least one POST request") + }) + }) + }) +}) + +// statelessMockMCPServer is a minimal MCP server that only accepts POST. +// It tracks whether any GET requests reached it (which would indicate +// the proxy's method gate is not working). +type statelessMockMCPServer struct { + server *http.Server + listener net.Listener + port int + gotGET atomic.Bool + postHits atomic.Int32 +} + +func newStatelessMockMCPServer() (*statelessMockMCPServer, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to create listener: %w", err) + } + + port := listener.Addr().(*net.TCPAddr).Port + + mock := &statelessMockMCPServer{ + listener: listener, + port: port, + } + + mock.server = &http.Server{ + Handler: http.HandlerFunc(mock.handleRequest), + } + + go func() { + if err := mock.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + GinkgoWriter.Printf("Stateless mock server error: %v\n", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + return mock, nil +} + +func (m *statelessMockMCPServer) handleRequest(w http.ResponseWriter, r *http.Request) { + // Always return 404 for OAuth well-known URIs + if strings.HasPrefix(r.URL.Path, "/.well-known/") { + w.WriteHeader(http.StatusNotFound) + return + } + + if r.Method == http.MethodGet { + m.gotGET.Store(true) + // A real stateless server would reject GETs, but we accept them here + // so the test can detect if any GETs leaked through the proxy. + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + m.postHits.Add(1) + + // Parse the JSON-RPC request to return appropriate responses + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + method, _ := req["method"].(string) + id := req["id"] + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + switch method { + case "initialize": + resp := map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "serverInfo": map[string]interface{}{ + "name": "stateless-mock", + "version": "1.0.0", + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + case "ping": + resp := map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{}, + } + _ = json.NewEncoder(w).Encode(resp) + default: + resp := map[string]interface{}{ + "jsonrpc": "2.0", + "id": id, + "result": map[string]interface{}{}, + } + _ = json.NewEncoder(w).Encode(resp) + } +} + +func (m *statelessMockMCPServer) URL() string { + return fmt.Sprintf("http://127.0.0.1:%d", m.port) +} + +func (m *statelessMockMCPServer) Stop() { + if m.server != nil { + _ = m.server.Close() + } +} + +func (m *statelessMockMCPServer) GetCount() int32 { + return m.postHits.Load() +} + +func (m *statelessMockMCPServer) GotGET() bool { + return m.gotGET.Load() +}