Skip to content

Commit ce1e8c8

Browse files
authored
ASI-07: Add mTLS and HMAC request signing for agent↔gateway communication (#4679)
Per OWASP Agentic Top 10 ASI-07, the MCP gateway communicated over plain HTTP with only an API key — no transport encryption, no message signing, no replay protection. This adds opt-in mTLS and HMAC-SHA256 request signing. ## Phase 1: mTLS New flags/env vars: `--tls-cert` / `MCP_GATEWAY_TLS_CERT`, `--tls-key` / `MCP_GATEWAY_TLS_KEY`, `--tls-ca` / `MCP_GATEWAY_CA_CERT`. - `internal/server/gateway_tls.go` — `LoadGatewayTLS(cert, key, ca)`: loads PEM files, sets `tls.RequireAndVerifyClientCert` when `ca` is non-empty (mTLS) - `internal/cmd/root.go` — wraps the TCP listener with TLS post-bind; plain HTTP remains default (no certs = no change in behaviour) - Partial TLS flag validation: `--tls-cert` and `--tls-key` must both be provided together; `--tls-ca` requires cert+key to also be set — prevents silent plaintext fallback from incomplete config - `writeGatewayConfig` now receives a `tlsEnabled` flag so emitted server URLs use `https://` when TLS is active ## Phase 2: HMAC Request Signing + Replay Protection New flag/env var: `--hmac-secret` / `MCP_GATEWAY_HMAC_SECRET`. - `internal/server/hmac.go` — `hmacMiddleware` injected into `wrapWithMiddleware` (both routed and unified modes); API key auth runs **before** HMAC so unauthenticated requests are rejected before paying the body-read cost - Validates three headers per request: - `X-MCP-Timestamp` — must be within ±30 s of server clock - `X-MCP-Nonce` — must not have been seen before (in-process cache, 60 s TTL) - `X-MCP-Signature` — `HMAC-SHA256(secret, "timestamp\nnonce\npath\nhex(sha256(body))")` - Nonce is recorded **only after** successful signature verification — prevents DoS cache poisoning via requests with invalid signatures - A read-only `seenNonce` pre-check provides fast replay rejection before the body-read step - HMAC applies to `/mcp` handlers only; common endpoints (`/health`, `/close`) are not HMAC-protected ``` # Enable one-way TLS awmg --tls-cert server.crt --tls-key server.key --config config.toml --routed # Enable mTLS (client cert required) awmg --tls-cert server.crt --tls-key server.key --tls-ca ca.crt ... # Enable HMAC signing (can combine with TLS) awmg --hmac-secret $(openssl rand -hex 32) ... ``` Both features are **opt-in and backward compatible** — all flags default to empty/disabled, so omitting them leaves plain-HTTP + API-key behaviour completely unchanged. ## Test coverage - `gateway_tls_test.go`: server-only TLS, mTLS config, live mTLS handshake (proper CA + client cert with `ExtKeyUsageClientAuth`), error paths - `hmac_test.go`: valid signature, missing headers, stale/future timestamps, replay detection, wrong secret, nonce cache concurrency, nonce cache DoS prevention (`TestHMACMiddleware_InvalidSigDoesNotPoisonNonceCache`), `seenNonce` read-only pre-check (`TestNonceCache_SeenNonce`) - `stdout_config_test.go`: `TestWriteGatewayConfig_TLSScheme` verifies `http://` vs `https://` URL emission based on TLS state
2 parents 7b46f70 + 8f56e9f commit ce1e8c8

20 files changed

Lines changed: 1111 additions & 78 deletions

internal/cmd/flags_tls.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package cmd
2+
3+
// TLS and HMAC security flags (ASI-07: Secure Agent↔Gateway Communication)
4+
5+
import (
6+
"github.com/github/gh-aw-mcpg/internal/envutil"
7+
"github.com/spf13/cobra"
8+
)
9+
10+
// TLS/HMAC flag variables
11+
var (
12+
tlsCertPath string
13+
tlsKeyPath string
14+
tlsCAPath string
15+
hmacSecret string
16+
)
17+
18+
func init() {
19+
RegisterFlag(func(cmd *cobra.Command) {
20+
cmd.Flags().StringVar(&tlsCertPath, "tls-cert", envutil.GetEnvString("MCP_GATEWAY_TLS_CERT", ""), "Path to TLS server certificate PEM file (enables HTTPS)")
21+
cmd.Flags().StringVar(&tlsKeyPath, "tls-key", envutil.GetEnvString("MCP_GATEWAY_TLS_KEY", ""), "Path to TLS server private key PEM file (enables HTTPS)")
22+
cmd.Flags().StringVar(&tlsCAPath, "tls-ca", envutil.GetEnvString("MCP_GATEWAY_CA_CERT", ""), "Path to CA certificate PEM file for client certificate verification (enables mTLS)")
23+
cmd.Flags().StringVar(&hmacSecret, "hmac-secret", envutil.GetEnvString("MCP_GATEWAY_HMAC_SECRET", ""), "Shared HMAC-SHA256 secret for request signing and replay protection")
24+
})
25+
}

internal/cmd/root.go

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cmd
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"encoding/json"
78
"fmt"
89
"io"
@@ -386,29 +387,69 @@ func run(cmd *cobra.Command, args []string) error {
386387
// Extract API key from gateway config (spec 7.1)
387388
apiKey := cfg.GetAPIKey()
388389

389-
httpServer = server.CreateHTTPServerForRoutedMode(listenAddr, unifiedServer, apiKey)
390+
httpServer = server.CreateHTTPServerForRoutedMode(listenAddr, unifiedServer, apiKey, hmacSecret)
390391
} else {
391392
logger.StartupInfo("Starting MCPG in UNIFIED mode on %s", listenAddr)
392393
logger.StartupInfo("Endpoint: /mcp")
393394

394395
// Extract API key from gateway config (spec 7.1)
395396
apiKey := cfg.GetAPIKey()
396397

397-
httpServer = server.CreateHTTPServerForMCP(listenAddr, unifiedServer, apiKey)
398+
httpServer = server.CreateHTTPServerForMCP(listenAddr, unifiedServer, apiKey, hmacSecret)
398399
}
399400
// Register the HTTP server shutdown function so the /close handler can drain
400401
// in-flight requests before exiting (spec 5.1.3)
401402
unifiedServer.SetHTTPShutdown(httpServer.Shutdown)
403+
404+
// Build net.Listener — optionally wrapping with TLS (ASI-07 Phase 1).
405+
// Plain HTTP is still used when no TLS certificate is configured (backward compatible).
406+
// Validate that TLS flags are consistent: cert+key must both be provided together,
407+
// and CA cert requires cert+key to be set.
408+
hasCert := tlsCertPath != ""
409+
hasKey := tlsKeyPath != ""
410+
hasCA := tlsCAPath != ""
411+
if hasCert != hasKey {
412+
return fmt.Errorf("--tls-cert and --tls-key must both be provided together")
413+
}
414+
if hasCA && !hasCert {
415+
return fmt.Errorf("--tls-ca requires --tls-cert and --tls-key to also be set")
416+
}
417+
418+
listener, err := net.Listen("tcp", listenAddr)
419+
if err != nil {
420+
return fmt.Errorf("failed to listen on %s: %w", listenAddr, err)
421+
}
422+
tlsEnabled := hasCert && hasKey
423+
var tlsCfg *tls.Config
424+
if tlsEnabled {
425+
tlsCfg, err = server.LoadGatewayTLS(tlsCertPath, tlsKeyPath, tlsCAPath)
426+
if err != nil {
427+
_ = listener.Close()
428+
return fmt.Errorf("failed to configure TLS: %w", err)
429+
}
430+
listener = tls.NewListener(listener, tlsCfg)
431+
mtlsNote := ""
432+
if tlsCAPath != "" {
433+
mtlsNote = " (mTLS enabled)"
434+
}
435+
logger.StartupInfo("TLS enabled: cert=%s, key=%s, ca=%s — listening on https://%s%s", tlsCertPath, tlsKeyPath, tlsCAPath, listenAddr, mtlsNote)
436+
} else {
437+
logger.StartupInfo("TLS not configured — listening on http://%s (set --tls-cert/--tls-key to enable)", listenAddr)
438+
}
439+
if hmacSecret != "" {
440+
logger.StartupInfo("HMAC request signing enabled (ASI-07)")
441+
}
442+
402443
// Start HTTP server in background
403444
go func() {
404-
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
445+
if err := httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
405446
log.Printf("HTTP server error: %v", err)
406447
cancel()
407448
}
408449
}()
409450

410451
// Write gateway configuration to stdout per spec section 5.4
411-
if err := writeGatewayConfigToStdout(cfg, listenAddr, mode); err != nil {
452+
if err := writeGatewayConfigToStdout(cfg, listenAddr, mode, tlsEnabled); err != nil {
412453
log.Printf("Warning: failed to write gateway configuration to stdout: %v", err)
413454
}
414455

@@ -479,11 +520,11 @@ func resolveGuardPolicyOverride(cmd *cobra.Command) (*config.GuardPolicy, string
479520

480521
// writeGatewayConfigToStdout writes the rewritten gateway configuration to stdout
481522
// per MCP Gateway Specification Section 5.4
482-
func writeGatewayConfigToStdout(cfg *config.Config, listenAddr, mode string) error {
483-
return writeGatewayConfig(cfg, listenAddr, mode, os.Stdout)
523+
func writeGatewayConfigToStdout(cfg *config.Config, listenAddr, mode string, tlsEnabled bool) error {
524+
return writeGatewayConfig(cfg, listenAddr, mode, tlsEnabled, os.Stdout)
484525
}
485526

486-
func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, w io.Writer) error {
527+
func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, tlsEnabled bool, w io.Writer) error {
487528
debugLog.Printf("Writing gateway config: listenAddr=%s, mode=%s, serverCount=%d", listenAddr, mode, len(cfg.Servers))
488529

489530
// Parse listen address to extract host and port
@@ -527,17 +568,22 @@ func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, w io.Writer
527568

528569
servers := outputConfig["mcpServers"].(map[string]interface{})
529570

571+
scheme := "http"
572+
if tlsEnabled {
573+
scheme = "https"
574+
}
575+
530576
for name, server := range cfg.Servers {
531577
serverConfig := map[string]interface{}{
532578
"type": "http",
533579
}
534580

535581
var serverURL string
536582
if mode == "routed" {
537-
serverURL = fmt.Sprintf("http://%s:%s/mcp/%s", domain, port, name)
583+
serverURL = fmt.Sprintf("%s://%s:%s/mcp/%s", scheme, domain, port, name)
538584
} else {
539585
// Unified mode - all servers use /mcp endpoint
540-
serverURL = fmt.Sprintf("http://%s:%s/mcp", domain, port)
586+
serverURL = fmt.Sprintf("%s://%s:%s/mcp", scheme, domain, port)
541587
}
542588
serverConfig["url"] = serverURL
543589
debugLog.Printf("Writing server config: name=%s, url=%s, toolCount=%d", name, serverURL, len(server.Tools))

internal/cmd/root_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ func TestWriteGatewayConfig(t *testing.T) {
292292
}
293293

294294
var buf bytes.Buffer
295-
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
295+
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
296296
require.NoError(t, err)
297297

298298
// Parse JSON output
@@ -329,7 +329,7 @@ func TestWriteGatewayConfig(t *testing.T) {
329329
}
330330

331331
var buf bytes.Buffer
332-
err := writeGatewayConfig(cfg, "localhost:8080", "routed", &buf)
332+
err := writeGatewayConfig(cfg, "localhost:8080", "routed", false, &buf)
333333
require.NoError(t, err)
334334

335335
// Parse JSON output
@@ -371,7 +371,7 @@ func TestWriteGatewayConfig(t *testing.T) {
371371
}
372372

373373
var buf bytes.Buffer
374-
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
374+
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
375375
require.NoError(t, err)
376376

377377
// Parse JSON output
@@ -408,7 +408,7 @@ func TestWriteGatewayConfig(t *testing.T) {
408408
}
409409

410410
var buf bytes.Buffer
411-
err := writeGatewayConfig(cfg, "[::1]:3000", "unified", &buf)
411+
err := writeGatewayConfig(cfg, "[::1]:3000", "unified", false, &buf)
412412
require.NoError(t, err)
413413

414414
// Parse JSON output
@@ -436,7 +436,7 @@ func TestWriteGatewayConfig(t *testing.T) {
436436
}
437437

438438
var buf bytes.Buffer
439-
err := writeGatewayConfig(cfg, "invalid-address", "unified", &buf)
439+
err := writeGatewayConfig(cfg, "invalid-address", "unified", false, &buf)
440440
require.NoError(t, err)
441441

442442
output := buf.String()
@@ -560,7 +560,7 @@ func TestWriteGatewayConfig_WildcardAddresses(t *testing.T) {
560560
}
561561

562562
var buf bytes.Buffer
563-
err := writeGatewayConfig(cfg, tt.listenAddr, tt.mode, &buf)
563+
err := writeGatewayConfig(cfg, tt.listenAddr, tt.mode, false, &buf)
564564
require.NoError(t, err)
565565

566566
var result map[string]interface{}
@@ -597,7 +597,7 @@ func TestWriteGatewayConfig_EmptyServerList(t *testing.T) {
597597
}
598598

599599
var buf bytes.Buffer
600-
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
600+
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
601601
require.NoError(t, err)
602602

603603
var result map[string]interface{}
@@ -623,7 +623,7 @@ func TestWriteGatewayConfig_FileSync(t *testing.T) {
623623
require.NoError(t, err)
624624
defer tmpFile.Close()
625625

626-
err = writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", tmpFile)
626+
err = writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, tmpFile)
627627
require.NoError(t, err)
628628

629629
// Re-read and verify the file was written correctly

internal/cmd/stdout_config_test.go

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func TestWriteGatewayConfigToStdout(t *testing.T) {
153153
for _, tt := range tests {
154154
t.Run(tt.name, func(t *testing.T) {
155155
var buf bytes.Buffer
156-
err := writeGatewayConfig(tt.cfg, tt.listenAddr, tt.mode, &buf)
156+
err := writeGatewayConfig(tt.cfg, tt.listenAddr, tt.mode, false, &buf)
157157
require.NoError(t, err)
158158

159159
var result map[string]interface{}
@@ -200,7 +200,7 @@ func TestWriteGatewayConfigToStdout_EmptyConfig(t *testing.T) {
200200
}
201201

202202
var buf bytes.Buffer
203-
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", &buf)
203+
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", false, &buf)
204204
require.NoError(t, err)
205205

206206
var result map[string]interface{}
@@ -221,7 +221,7 @@ func TestWriteGatewayConfigToStdout_JSONFormat(t *testing.T) {
221221
}
222222

223223
var buf bytes.Buffer
224-
err := writeGatewayConfig(cfg, "localhost:3000", "routed", &buf)
224+
err := writeGatewayConfig(cfg, "localhost:3000", "routed", false, &buf)
225225
require.NoError(t, err)
226226

227227
// Verify it's valid JSON
@@ -251,7 +251,7 @@ func TestWriteGatewayConfigToStdout_WithPipe(t *testing.T) {
251251
// Write configuration to pipe in a goroutine
252252
errCh := make(chan error, 1)
253253
go func() {
254-
writeErr := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", w)
254+
writeErr := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, w)
255255
w.Close() // Close writer to signal EOF
256256
errCh <- writeErr
257257
}()
@@ -283,7 +283,7 @@ func TestWriteGatewayConfig_WriteError(t *testing.T) {
283283
},
284284
}
285285

286-
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", errWriter{})
286+
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", false, errWriter{})
287287
require.Error(t, err)
288288
assert.Contains(t, err.Error(), "failed to encode configuration")
289289
}
@@ -298,7 +298,7 @@ func TestWriteGatewayConfig_PortOnlyAddress(t *testing.T) {
298298
}
299299

300300
var buf bytes.Buffer
301-
err := writeGatewayConfig(cfg, ":8080", "unified", &buf)
301+
err := writeGatewayConfig(cfg, ":8080", "unified", false, &buf)
302302
require.NoError(t, err)
303303

304304
var result map[string]interface{}
@@ -317,3 +317,43 @@ func TestWriteGatewayConfig_PortOnlyAddress(t *testing.T) {
317317
assert.Contains(t, url, DefaultListenIPv4, "Should use default IPv4 host when address has no host")
318318
assert.Contains(t, url, "8080", "Should preserve the port")
319319
}
320+
321+
// TestWriteGatewayConfig_TLSScheme verifies that https:// URLs are emitted when
322+
// tlsEnabled=true and http:// URLs are emitted otherwise.
323+
func TestWriteGatewayConfig_TLSScheme(t *testing.T) {
324+
cfg := &config.Config{
325+
Servers: map[string]*config.ServerConfig{
326+
"github": {Command: "echo"},
327+
},
328+
}
329+
330+
tests := []struct {
331+
name string
332+
tlsEnabled bool
333+
wantScheme string
334+
}{
335+
{"plain HTTP", false, "http://"},
336+
{"HTTPS (TLS)", true, "https://"},
337+
}
338+
339+
for _, tt := range tests {
340+
t.Run(tt.name, func(t *testing.T) {
341+
var buf bytes.Buffer
342+
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "routed", tt.tlsEnabled, &buf)
343+
require.NoError(t, err)
344+
345+
var result map[string]interface{}
346+
require.NoError(t, json.Unmarshal(buf.Bytes(), &result))
347+
348+
mcpServers := result["mcpServers"].(map[string]interface{})
349+
serverConfig := mcpServers["github"].(map[string]interface{})
350+
url := serverConfig["url"].(string)
351+
352+
assert.True(t, len(url) > 0, "URL should not be empty")
353+
assert.True(t,
354+
(tt.wantScheme == "https://" && url[:8] == "https://") ||
355+
(tt.wantScheme == "http://" && url[:7] == "http://"),
356+
"URL %q should start with %s", url, tt.wantScheme)
357+
})
358+
}
359+
}

internal/server/allowed_tools_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func TestAllowedTools_RoutedMode_ToolsListFiltered(t *testing.T) {
347347
defer us.Close()
348348

349349
// Create routed HTTP server and verify the filtered server only exposes allowed tools.
350-
httpSrv := CreateHTTPServerForRoutedMode("127.0.0.1:0", us, "") // no API key for test
350+
httpSrv := CreateHTTPServerForRoutedMode("127.0.0.1:0", us, "", "") // no API key for test
351351
ts := httptest.NewServer(httpSrv.Handler)
352352
defer ts.Close()
353353

internal/server/gateway_tls.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package server
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"fmt"
7+
"os"
8+
9+
"github.com/github/gh-aw-mcpg/internal/logger"
10+
)
11+
12+
var logGatewayTLS = logger.New("server:tls")
13+
14+
// LoadGatewayTLS loads a TLS configuration for the gateway HTTP server from PEM
15+
// certificate and key files. When caPath is non-empty the returned config
16+
// requires client certificates signed by that CA (mutual TLS / mTLS).
17+
//
18+
// Pass an empty caPath to use one-way TLS (server-only authentication).
19+
//
20+
// Example — one-way TLS (server cert only):
21+
//
22+
// tlsCfg, err := LoadGatewayTLS("/path/server.crt", "/path/server.key", "")
23+
//
24+
// Example — mutual TLS (client certs required):
25+
//
26+
// tlsCfg, err := LoadGatewayTLS("/path/server.crt", "/path/server.key", "/path/ca.crt")
27+
func LoadGatewayTLS(certPath, keyPath, caPath string) (*tls.Config, error) {
28+
logGatewayTLS.Printf("loading gateway TLS: cert=%s, key=%s, ca=%s", certPath, keyPath, caPath)
29+
30+
serverCert, err := tls.LoadX509KeyPair(certPath, keyPath)
31+
if err != nil {
32+
return nil, fmt.Errorf("failed to load server TLS certificate/key: %w", err)
33+
}
34+
logGatewayTLS.Print("server TLS key pair loaded")
35+
36+
cfg := &tls.Config{
37+
Certificates: []tls.Certificate{serverCert},
38+
MinVersion: tls.VersionTLS12,
39+
}
40+
41+
if caPath != "" {
42+
caPEM, err := os.ReadFile(caPath)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
45+
}
46+
47+
caPool := x509.NewCertPool()
48+
if !caPool.AppendCertsFromPEM(caPEM) {
49+
return nil, fmt.Errorf("failed to parse CA certificate from %s", caPath)
50+
}
51+
52+
// Require and verify client certificates signed by the provided CA.
53+
cfg.ClientCAs = caPool
54+
cfg.ClientAuth = tls.RequireAndVerifyClientCert
55+
logGatewayTLS.Printf("mTLS enabled: client certificates required, CA=%s", caPath)
56+
}
57+
58+
return cfg, nil
59+
}

0 commit comments

Comments
 (0)