Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type RunFlags struct {
// Remote MCP server support
RemoteURL string

// Stateless indicates the server is stateless (POST-only, no SSE)
Stateless bool

// Security and audit
AuthzConfig string
AuditConfig string
Expand Down Expand Up @@ -253,6 +256,9 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().BoolVar(&config.TrustProxyHeaders, "trust-proxy-headers", false,
"Trust X-Forwarded-* headers from reverse proxies (X-Forwarded-Proto, X-Forwarded-Host, X-Forwarded-Port, X-Forwarded-Prefix) "+
"(default false)")
cmd.Flags().BoolVar(&config.Stateless, "stateless", false,
"Declare the server as stateless (POST-only, no SSE). "+
"Use for MCP servers implementing streamable-HTTP stateless mode.")
cmd.Flags().StringVar(&config.EndpointPrefix, "endpoint-prefix", "",
"Path prefix to prepend to SSE endpoint URLs (e.g., /playwright)")
cmd.Flags().StringVar(&config.Network, "network", "",
Expand Down Expand Up @@ -611,6 +617,7 @@ func buildRunnerConfig(
runner.WithNetworkIsolation(runFlags.IsolateNetwork),
runner.WithAllowDockerGateway(runFlags.AllowDockerGateway),
runner.WithTrustProxyHeaders(runFlags.TrustProxyHeaders),
runner.WithStateless(runFlags.Stateless),
runner.WithEndpointPrefix(runFlags.EndpointPrefix),
runner.WithNetworkMode(runFlags.Network),
runner.WithK8sPodPatch(runFlags.K8sPodPatch),
Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ type RunConfig struct {
// TrustProxyHeaders indicates whether to trust X-Forwarded-* headers from reverse proxies
TrustProxyHeaders bool `json:"trust_proxy_headers,omitempty" yaml:"trust_proxy_headers,omitempty"`

// Stateless indicates the server only supports POST (no SSE/GET).
// When true, the proxy returns 405 for incoming GET requests and uses a
// POST-based health check instead of the default GET probe.
// Applies to both remote URLs and local container workloads.
Stateless bool `json:"stateless,omitempty" yaml:"stateless,omitempty"`

// ProxyMode is the effective HTTP protocol the proxy uses.
// For stdio transports, this is the configured mode (sse or streamable-http).
// For direct transports (sse/streamable-http), this matches the transport type.
Expand Down
8 changes: 8 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,14 @@ func WithTrustProxyHeaders(trust bool) RunConfigBuilderOption {
}
}

// WithStateless declares the server is stateless (POST-only, no SSE).
func WithStateless(stateless bool) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
b.config.Stateless = stateless
return nil
}
}

// WithEndpointPrefix sets the path prefix for SSE endpoint URLs
func WithEndpointPrefix(prefix string) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
Expand Down
11 changes: 11 additions & 0 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,17 @@ func (r *Runner) Run(ctx context.Context) error {
})
}

// Configure stateless mode if requested. Stateless mode applies to any
// streamable-HTTP server (remote or local container) where the upstream
// only accepts POST and does not support SSE-based sessions.
if r.Config.Stateless {
httpT, ok := transportHandler.(*transport.HTTPTransport)
if !ok {
return fmt.Errorf("--stateless requires streamable-HTTP or SSE transport, got %T", transportHandler)
}
httpT.SetStateless(true)
}

// Start the transport (which also starts the container and monitoring)
slog.Debug("starting transport", "transport", r.Config.Transport, "container", r.Config.ContainerName)
if err := transportHandler.Start(ctx); err != nil {
Expand Down
33 changes: 25 additions & 8 deletions pkg/transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type HTTPTransport struct {
// Remote MCP server support
remoteURL string

// stateless indicates the server is POST-only (no SSE/GET support)
stateless bool

// tokenSource is the OAuth token source for remote authentication
tokenSource oauth2.TokenSource

Expand Down Expand Up @@ -152,6 +155,11 @@ func (t *HTTPTransport) SetOnHealthCheckFailed(callback types.HealthCheckFailedC
t.onHealthCheckFailed = callback
}

// SetStateless configures the transport for a stateless server.
func (t *HTTPTransport) SetStateless(stateless bool) {
t.stateless = stateless
}

// SetOnUnauthorizedResponse sets the callback for 401 Unauthorized responses
// The callback is wrapped to check the unauthorized flag to prevent repeated status updates
func (t *HTTPTransport) SetOnUnauthorizedResponse(callback types.UnauthorizedResponseCallback) {
Expand Down Expand Up @@ -321,14 +329,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
enableHealthCheck := shouldEnableHealthCheck(isRemote)

// Build proxy options
var proxyOptions []transparent.Option
if remoteBasePath != "" {
proxyOptions = append(proxyOptions, transparent.WithRemoteBasePath(remoteBasePath))
}
proxyOptions = append(proxyOptions, transparent.WithRemoteRawQuery(remoteRawQuery))
if t.sessionStorage != nil {
proxyOptions = append(proxyOptions, transparent.WithSessionStorage(t.sessionStorage))
}
proxyOptions := t.buildProxyOptions(remoteBasePath, remoteRawQuery)

// Create the transparent proxy
t.proxy = transparent.NewTransparentProxyWithOptions(
Expand Down Expand Up @@ -421,6 +422,22 @@ func (t *HTTPTransport) Stop(ctx context.Context) error {
return nil
}

// buildProxyOptions constructs the transparent proxy options for this transport.
func (t *HTTPTransport) buildProxyOptions(remoteBasePath, remoteRawQuery string) []transparent.Option {
var opts []transparent.Option
if remoteBasePath != "" {
opts = append(opts, transparent.WithRemoteBasePath(remoteBasePath))
}
opts = append(opts, transparent.WithRemoteRawQuery(remoteRawQuery))
if t.stateless {
opts = append(opts, transparent.WithStateless())
}
if t.sessionStorage != nil {
opts = append(opts, transparent.WithSessionStorage(t.sessionStorage))
}
return opts
}

// handleContainerExit handles container exit events.
// It loops to support reconnecting the monitor when a container is restarted
// by Docker (e.g., via restart policy) rather than truly exiting.
Expand Down
73 changes: 73 additions & 0 deletions pkg/transport/proxy/transparent/method_gate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package transparent

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStatelessMethodGate(t *testing.T) {
t.Parallel()

inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

tests := []struct {
name string
method string
expectedStatus int
expectAllow bool
}{
{
name: "GET returns 405 with Allow header",
method: http.MethodGet,
expectedStatus: http.StatusMethodNotAllowed,
expectAllow: true,
},
{
name: "HEAD returns 405 with Allow header",
method: http.MethodHead,
expectedStatus: http.StatusMethodNotAllowed,
expectAllow: true,
},
{
name: "DELETE returns 405 with Allow header",
method: http.MethodDelete,
expectedStatus: http.StatusMethodNotAllowed,
expectAllow: true,
},
{
name: "POST is forwarded",
method: http.MethodPost,
expectedStatus: http.StatusOK,
},
{
name: "PUT is forwarded",
method: http.MethodPut,
expectedStatus: http.StatusOK,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handler := statelessMethodGate(inner)
rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.method, "/", nil)

handler.ServeHTTP(rec, req)

assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.expectAllow {
assert.Equal(t, "POST, OPTIONS", rec.Header().Get("Allow"))
}
})
}
}
66 changes: 66 additions & 0 deletions pkg/transport/proxy/transparent/pinger.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"

"github.com/stacklok/toolhive/pkg/healthcheck"
Expand Down Expand Up @@ -87,3 +88,68 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) {

return duration, fmt.Errorf("SSE server health check failed with status %d", resp.StatusCode)
}

// StatelessMCPPinger health-checks stateless streamable-HTTP servers via POST ping.
// Stateless servers don't support GET, so we send a minimal JSON-RPC ping instead.
type StatelessMCPPinger struct {
targetURL string
client *http.Client
}

// NewStatelessMCPPinger creates a pinger for stateless streamable-HTTP servers.
func NewStatelessMCPPinger(targetURL string) healthcheck.MCPPinger {
return NewStatelessMCPPingerWithTimeout(targetURL, DefaultPingerTimeout)
}

// NewStatelessMCPPingerWithTimeout creates a stateless pinger with a custom timeout.
func NewStatelessMCPPingerWithTimeout(targetURL string, timeout time.Duration) healthcheck.MCPPinger {
if timeout <= 0 {
timeout = DefaultPingerTimeout
}
return &StatelessMCPPinger{
targetURL: targetURL,
client: &http.Client{
Timeout: timeout,
},
}
}

// Ping sends a JSON-RPC ping POST to check if the stateless server is reachable.
// Accepts any 2xx-4xx response as healthy; only network errors and 5xx indicate failure.
func (p *StatelessMCPPinger) Ping(ctx context.Context) (time.Duration, error) {
start := time.Now()

body := `{"jsonrpc":"2.0","id":0,"method":"ping","params":{}}`
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.targetURL, strings.NewReader(body))
if err != nil {
return 0, fmt.Errorf("failed to create stateless ping request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")

//nolint:gosec // G706: logging target URL from config
slog.Debug("checking stateless MCP server health via POST ping", "target", p.targetURL)

resp, err := p.client.Do(req)
if err != nil {
return time.Since(start), fmt.Errorf("stateless ping failed to connect: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
slog.Debug("failed to close ping response body", "error", err)
}
}()

duration := time.Since(start)

// Accept 2xx-4xx: even 401/403 means the server is reachable.
// Only 5xx or network errors indicate the server is down.
if resp.StatusCode >= 200 && resp.StatusCode < 500 {
//nolint:gosec // G706: logging HTTP status code from health check response
slog.Debug("stateless MCP server health check successful",
"duration", duration, "status", resp.StatusCode)
return duration, nil
}

return duration, fmt.Errorf("stateless ping returned status %d", resp.StatusCode)
}
Loading
Loading