Skip to content

Commit 5ead2ff

Browse files
committed
Wire HeaderForward into vMCP session HTTP client
PR #5239 added HeaderForward support to the startup capability-discovery client at pkg/vmcp/client/client.go, but the per-session MCP HTTP client in pkg/vmcp/session/internal/backend builds a parallel transport chain (DefaultTransport -> auth -> identity) that never reads target.HeaderForward. Every post-initialize MCP call (tools/list, tools/call, ...) therefore reaches the upstream without user-configured headers, leaving features like GitHub Copilot's X-MCP-Toolsets filter silently broken in v0.27.2. Construct a single secrets.EnvironmentProvider at connector build time, plumb it into createMCPClient, and wrap the shared chain with BuildHeaderForwardTripper as the outermost stage so vMCP auth/identity headers still win on overlapping names. Export the existing helper from pkg/vmcp/client so the session backend can reuse it without duplication. Fixes #5289
1 parent f78bb42 commit 5ead2ff

4 files changed

Lines changed: 221 additions & 3 deletions

File tree

pkg/vmcp/headerforward/transport.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response
6464
// backend's pre-resolved HeaderForwardConfig. Returns base unchanged when no
6565
// header injection is configured or the effective header set is empty.
6666
//
67+
// Used by both the vMCP backend client (startup capability discovery) and the
68+
// per-session backend connector (long-lived MCP traffic). Exported so the
69+
// session backend in pkg/vmcp/session/internal/backend can share the same
70+
// transport-chain wiring.
71+
//
6772
// Fails loudly (constructor validation, per go-style.md) when a secret identifier
6873
// cannot be resolved through the provider, so a misconfigured backend surfaces
6974
// at pod startup — not as a silent missing-header on every request.

pkg/vmcp/session/internal/backend/mcp_session.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ import (
1717
"github.com/mark3labs/mcp-go/mcp"
1818

1919
"github.com/stacklok/toolhive/pkg/auth"
20+
"github.com/stacklok/toolhive/pkg/secrets"
2021
"github.com/stacklok/toolhive/pkg/versions"
2122
"github.com/stacklok/toolhive/pkg/vmcp"
2223
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
2324
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
2425
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
26+
"github.com/stacklok/toolhive/pkg/vmcp/headerforward"
2527
)
2628

2729
const (
@@ -196,19 +198,25 @@ func (c *mcpSession) GetPrompt(
196198
//
197199
// registry provides the authentication strategy for outgoing backend requests.
198200
// Pass a registry configured with the "unauthenticated" strategy to disable auth.
201+
//
202+
// A single secrets.EnvironmentProvider is constructed once per connector and
203+
// shared across every session it creates; its lifetime matches the connector's.
204+
// It is consumed by BuildHeaderForwardTripper to resolve secret-backed entries
205+
// in target.HeaderForward.
199206
func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func(
200207
ctx context.Context,
201208
target *vmcp.BackendTarget,
202209
identity *auth.Identity,
203210
sessionHint string,
204211
) (Session, *vmcp.CapabilityList, error) {
212+
provider := secrets.NewEnvironmentProvider()
205213
return func(
206214
ctx context.Context,
207215
target *vmcp.BackendTarget,
208216
identity *auth.Identity,
209217
sessionHint string,
210218
) (Session, *vmcp.CapabilityList, error) {
211-
c, err := createMCPClient(target, identity, registry, sessionHint)
219+
c, err := createMCPClient(ctx, target, identity, registry, sessionHint, provider)
212220
if err != nil {
213221
return nil, nil, fmt.Errorf("failed to create MCP client for backend %s: %w", target.WorkloadID, err)
214222
}
@@ -238,11 +246,17 @@ func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func(
238246
// to client.Close(), not to any caller-supplied init context.
239247
// sessionHint, when non-empty, is passed as the initial Mcp-Session-Id for
240248
// streamable-HTTP transports so the backend can resume an existing session.
249+
//
250+
// ctx is used only to resolve secret-backed entries in target.HeaderForward at
251+
// client-creation time; the transport itself is started with context.Background()
252+
// as described above. provider supplies values for those secret-backed headers.
241253
func createMCPClient(
254+
ctx context.Context,
242255
target *vmcp.BackendTarget,
243256
identity *auth.Identity,
244257
registry vmcpauth.OutgoingAuthRegistry,
245258
sessionHint string,
259+
provider secrets.Provider,
246260
) (*mcpclient.Client, error) {
247261
// Resolve and validate the auth strategy once at client creation time.
248262
strategyName := authtypes.StrategyTypeUnauthenticated
@@ -259,7 +273,10 @@ func createMCPClient(
259273

260274
slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID)
261275

262-
// Build shared transport chain: auth → identity propagation.
276+
// Build shared transport chain: auth → identity propagation → header forward.
277+
// HeaderForward is the outermost stage so inner stages (auth, identity) win
278+
// on any overlapping header name — matching the ordering in
279+
// headerForwardRoundTripper.RoundTrip, which skips names already set.
263280
// The per-transport sections below may add a size-limiting wrapper on top.
264281
base := http.RoundTripper(http.DefaultTransport)
265282
base = &authRoundTripper{
@@ -269,6 +286,10 @@ func createMCPClient(
269286
target: target,
270287
}
271288
base = &identityRoundTripper{base: base, identity: identity}
289+
base, err = headerforward.BuildHeaderForwardTripper(ctx, base, target.HeaderForward, provider, target.WorkloadID)
290+
if err != nil {
291+
return nil, fmt.Errorf("failed to build header-forward transport for backend %s: %w", target.WorkloadID, err)
292+
}
272293

273294
var c *mcpclient.Client
274295
switch target.TransportType {
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package backend
5+
6+
import (
7+
"context"
8+
"encoding/json"
9+
"io"
10+
"net/http"
11+
"net/http/httptest"
12+
"sync"
13+
"testing"
14+
"time"
15+
16+
"github.com/mark3labs/mcp-go/mcp"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
19+
20+
"github.com/stacklok/toolhive/pkg/vmcp"
21+
)
22+
23+
// headerCapturingBackend is a minimal streamable-HTTP MCP fake that records
24+
// inbound request headers keyed by JSON-RPC method. The test asserts that a
25+
// user-configured HeaderForward header reaches the backend on POST-INITIALIZE
26+
// traffic — see issue #5289. The startup capability-discovery path was fixed
27+
// in PR #5239; per-session HTTP traffic is still missing the wrap.
28+
type headerCapturingBackend struct {
29+
t *testing.T
30+
31+
mu sync.Mutex
32+
headersByMethod map[string]http.Header
33+
}
34+
35+
func newHeaderCapturingBackend(t *testing.T) (*headerCapturingBackend, string) {
36+
t.Helper()
37+
fb := &headerCapturingBackend{
38+
t: t,
39+
headersByMethod: make(map[string]http.Header),
40+
}
41+
mux := http.NewServeMux()
42+
mux.HandleFunc("/mcp", fb.handle)
43+
ts := httptest.NewServer(mux)
44+
t.Cleanup(ts.Close)
45+
return fb, ts.URL + "/mcp"
46+
}
47+
48+
func (f *headerCapturingBackend) headersFor(method string) http.Header {
49+
f.mu.Lock()
50+
defer f.mu.Unlock()
51+
return f.headersByMethod[method]
52+
}
53+
54+
func (f *headerCapturingBackend) handle(w http.ResponseWriter, r *http.Request) {
55+
if r.Method != http.MethodPost {
56+
// Streamable-HTTP transports may open a GET for server-pushed
57+
// notifications; rejecting it cleanly is fine for this test.
58+
w.WriteHeader(http.StatusMethodNotAllowed)
59+
return
60+
}
61+
62+
body, err := io.ReadAll(r.Body)
63+
if err != nil {
64+
f.t.Errorf("headerCapturingBackend: read body: %v", err)
65+
w.WriteHeader(http.StatusBadRequest)
66+
return
67+
}
68+
defer func() { _ = r.Body.Close() }()
69+
70+
var msg struct {
71+
JSONRPC string `json:"jsonrpc"`
72+
ID json.RawMessage `json:"id"`
73+
Method string `json:"method"`
74+
}
75+
if err := json.Unmarshal(body, &msg); err != nil {
76+
f.t.Errorf("headerCapturingBackend: decode: %v body=%s", err, string(body))
77+
w.WriteHeader(http.StatusBadRequest)
78+
return
79+
}
80+
81+
f.mu.Lock()
82+
f.headersByMethod[msg.Method] = r.Header.Clone()
83+
f.mu.Unlock()
84+
85+
// Notifications (no id, e.g. notifications/initialized) get an empty 202.
86+
if len(msg.ID) == 0 || string(msg.ID) == "null" {
87+
w.WriteHeader(http.StatusAccepted)
88+
return
89+
}
90+
91+
switch msg.Method {
92+
case string(mcp.MethodInitialize):
93+
w.Header().Set("Mcp-Session-Id", "header-forward-test-session")
94+
f.writeResult(w, msg.ID, map[string]any{
95+
"protocolVersion": mcp.LATEST_PROTOCOL_VERSION,
96+
"capabilities": map[string]any{
97+
"tools": map[string]any{},
98+
},
99+
"serverInfo": map[string]any{"name": "header-forward-fake", "version": "0.0.0"},
100+
})
101+
case string(mcp.MethodToolsList):
102+
f.writeResult(w, msg.ID, map[string]any{
103+
"tools": []mcp.Tool{{Name: "echo", Description: "echo tool"}},
104+
})
105+
case string(mcp.MethodToolsCall):
106+
f.writeResult(w, msg.ID, map[string]any{
107+
"content": []map[string]any{
108+
{"type": "text", "text": "ok"},
109+
},
110+
"isError": false,
111+
})
112+
default:
113+
f.writeResult(w, msg.ID, map[string]any{})
114+
}
115+
}
116+
117+
func (f *headerCapturingBackend) writeResult(w http.ResponseWriter, id json.RawMessage, result any) {
118+
w.Header().Set("Content-Type", "application/json")
119+
w.WriteHeader(http.StatusOK)
120+
if err := json.NewEncoder(w).Encode(map[string]any{
121+
"jsonrpc": "2.0",
122+
"id": json.RawMessage(id),
123+
"result": result,
124+
}); err != nil {
125+
f.t.Errorf("headerCapturingBackend: encode result: %v", err)
126+
}
127+
}
128+
129+
// TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests is the red-phase
130+
// regression test for issue #5289. PR #5239 fixed HeaderForward for the vMCP
131+
// backend client (used for startup capability discovery) but did not extend the
132+
// fix to the session-side connector at pkg/vmcp/session/internal/backend.
133+
// As a result, user-configured headers (e.g. X-MCP-Toolsets for GitHub MCP)
134+
// never reach the backend on per-session requests like tools/call.
135+
//
136+
// The test asserts that, after the connector completes Initialize, a
137+
// subsequent CallTool carries the configured plaintext header on the wire.
138+
// On main today it fails because the connector's transport chain does not
139+
// include a header-forward round-tripper — see createMCPClient in
140+
// mcp_session.go (the chain is http.DefaultTransport → authRoundTripper →
141+
// identityRoundTripper, with no HeaderForward stage).
142+
func TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests(t *testing.T) {
143+
t.Parallel()
144+
145+
const (
146+
headerName = "X-MCP-Toolsets"
147+
headerValue = "projects,issues,pull_requests,users,repos"
148+
)
149+
150+
fb, url := newHeaderCapturingBackend(t)
151+
152+
target := &vmcp.BackendTarget{
153+
WorkloadID: "header-forward-backend",
154+
WorkloadName: "header-forward-backend",
155+
BaseURL: url,
156+
TransportType: "streamable-http",
157+
HeaderForward: &vmcp.HeaderForwardConfig{
158+
AddPlaintextHeaders: map[string]string{
159+
headerName: headerValue,
160+
},
161+
},
162+
}
163+
164+
registry := newTestRegistry(t)
165+
connector := NewHTTPConnector(registry)
166+
167+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
168+
defer cancel()
169+
170+
sess, caps, err := connector(ctx, target, nil, "")
171+
require.NoError(t, err, "connector must initialise the backend successfully")
172+
require.NotNil(t, sess, "connector returned nil session")
173+
require.NotNil(t, caps, "connector returned nil capability list")
174+
t.Cleanup(func() { _ = sess.Close() })
175+
176+
// Make a single MCP call AFTER initialize completes. tools/call exercises
177+
// the same transport chain as initialize but is unambiguously a
178+
// post-handshake request — which is exactly where the regression lives.
179+
_, err = sess.CallTool(ctx, "echo", map[string]any{}, nil)
180+
require.NoError(t, err, "post-initialize CallTool must succeed")
181+
182+
// The recorded inbound headers for the tools/call request must include the
183+
// user-configured forward header. This is the single assertion target:
184+
// the test fails for exactly one reason — header missing on the recorded
185+
// post-initialize request.
186+
gotHeaders := fb.headersFor(string(mcp.MethodToolsCall))
187+
require.NotNil(t, gotHeaders, "backend never received a tools/call request")
188+
assert.Equal(t, headerValue, gotHeaders.Get(headerName),
189+
"HeaderForward.AddPlaintextHeaders must reach the backend on post-initialize requests")
190+
}

pkg/vmcp/session/internal/backend/mcp_session_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
package backend
55

66
import (
7+
"context"
78
"testing"
89

910
"github.com/stretchr/testify/assert"
1011
"github.com/stretchr/testify/require"
1112

13+
"github.com/stacklok/toolhive/pkg/secrets"
1214
"github.com/stacklok/toolhive/pkg/vmcp"
1315
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
1416
"github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
@@ -40,7 +42,7 @@ func TestCreateMCPClient_UnsupportedTransport(t *testing.T) {
4042
TransportType: transport,
4143
}
4244

43-
_, err := createMCPClient(target, nil, newTestRegistry(t), "")
45+
_, err := createMCPClient(context.Background(), target, nil, newTestRegistry(t), "", secrets.NewEnvironmentProvider())
4446
require.Error(t, err)
4547
assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport,
4648
"transport %q should return ErrUnsupportedTransport", transport)

0 commit comments

Comments
 (0)