Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions internal/cmd/flags_tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package cmd

// TLS and HMAC security flags (ASI-07: Secure Agent↔Gateway Communication)

import (
"github.com/github/gh-aw-mcpg/internal/envutil"
"github.com/spf13/cobra"
)

// TLS/HMAC flag variables
var (
tlsCertPath string
tlsKeyPath string
tlsCAPath string
hmacSecret string
)

func init() {
RegisterFlag(func(cmd *cobra.Command) {
cmd.Flags().StringVar(&tlsCertPath, "tls-cert", envutil.GetEnvString("MCP_GATEWAY_TLS_CERT", ""), "Path to TLS server certificate PEM file (enables HTTPS)")
cmd.Flags().StringVar(&tlsKeyPath, "tls-key", envutil.GetEnvString("MCP_GATEWAY_TLS_KEY", ""), "Path to TLS server private key PEM file (enables HTTPS)")
cmd.Flags().StringVar(&tlsCAPath, "tls-ca", envutil.GetEnvString("MCP_GATEWAY_CA_CERT", ""), "Path to CA certificate PEM file for client certificate verification (enables mTLS)")
cmd.Flags().StringVar(&hmacSecret, "hmac-secret", envutil.GetEnvString("MCP_GATEWAY_HMAC_SECRET", ""), "Shared HMAC-SHA256 secret for request signing and replay protection")
})
}
64 changes: 55 additions & 9 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"bufio"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -386,29 +387,69 @@ func run(cmd *cobra.Command, args []string) error {
// Extract API key from gateway config (spec 7.1)
apiKey := cfg.GetAPIKey()

httpServer = server.CreateHTTPServerForRoutedMode(listenAddr, unifiedServer, apiKey)
httpServer = server.CreateHTTPServerForRoutedMode(listenAddr, unifiedServer, apiKey, hmacSecret)
} else {
logger.StartupInfo("Starting MCPG in UNIFIED mode on %s", listenAddr)
logger.StartupInfo("Endpoint: /mcp")

// Extract API key from gateway config (spec 7.1)
apiKey := cfg.GetAPIKey()

httpServer = server.CreateHTTPServerForMCP(listenAddr, unifiedServer, apiKey)
httpServer = server.CreateHTTPServerForMCP(listenAddr, unifiedServer, apiKey, hmacSecret)
}
// Register the HTTP server shutdown function so the /close handler can drain
// in-flight requests before exiting (spec 5.1.3)
unifiedServer.SetHTTPShutdown(httpServer.Shutdown)

// Build net.Listener — optionally wrapping with TLS (ASI-07 Phase 1).
// Plain HTTP is still used when no TLS certificate is configured (backward compatible).
// Validate that TLS flags are consistent: cert+key must both be provided together,
// and CA cert requires cert+key to be set.
hasCert := tlsCertPath != ""
hasKey := tlsKeyPath != ""
hasCA := tlsCAPath != ""
if hasCert != hasKey {
return fmt.Errorf("--tls-cert and --tls-key must both be provided together")
}
if hasCA && !hasCert {
return fmt.Errorf("--tls-ca requires --tls-cert and --tls-key to also be set")
}

listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", listenAddr, err)
}
tlsEnabled := hasCert && hasKey
var tlsCfg *tls.Config
if tlsEnabled {
tlsCfg, err = server.LoadGatewayTLS(tlsCertPath, tlsKeyPath, tlsCAPath)
if err != nil {
_ = listener.Close()
return fmt.Errorf("failed to configure TLS: %w", err)
}
listener = tls.NewListener(listener, tlsCfg)
mtlsNote := ""
if tlsCAPath != "" {
mtlsNote = " (mTLS enabled)"
}
logger.StartupInfo("TLS enabled: cert=%s, key=%s, ca=%s — listening on https://%s%s", tlsCertPath, tlsKeyPath, tlsCAPath, listenAddr, mtlsNote)
} else {
logger.StartupInfo("TLS not configured — listening on http://%s (set --tls-cert/--tls-key to enable)", listenAddr)
}
if hmacSecret != "" {
logger.StartupInfo("HMAC request signing enabled (ASI-07)")
}

// Start HTTP server in background
go func() {
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if err := httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
Comment on lines +404 to +445
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When TLS is enabled via --tls-cert/--tls-key, the server will speak HTTPS on the listener, but writeGatewayConfigToStdout still emits http://... URLs. This will cause clients that consume the emitted config to attempt plaintext connections and fail. Consider threading the effective scheme (http/https) into the config writer (or deriving it from the TLS flags) so the emitted URLs match the actual listener.

Copilot uses AI. Check for mistakes.
log.Printf("HTTP server error: %v", err)
cancel()
}
}()

// Write gateway configuration to stdout per spec section 5.4
if err := writeGatewayConfigToStdout(cfg, listenAddr, mode); err != nil {
if err := writeGatewayConfigToStdout(cfg, listenAddr, mode, tlsEnabled); err != nil {
log.Printf("Warning: failed to write gateway configuration to stdout: %v", err)
}

Expand Down Expand Up @@ -479,11 +520,11 @@ func resolveGuardPolicyOverride(cmd *cobra.Command) (*config.GuardPolicy, string

// writeGatewayConfigToStdout writes the rewritten gateway configuration to stdout
// per MCP Gateway Specification Section 5.4
func writeGatewayConfigToStdout(cfg *config.Config, listenAddr, mode string) error {
return writeGatewayConfig(cfg, listenAddr, mode, os.Stdout)
func writeGatewayConfigToStdout(cfg *config.Config, listenAddr, mode string, tlsEnabled bool) error {
return writeGatewayConfig(cfg, listenAddr, mode, tlsEnabled, os.Stdout)
}

func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, w io.Writer) error {
func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, tlsEnabled bool, w io.Writer) error {
debugLog.Printf("Writing gateway config: listenAddr=%s, mode=%s, serverCount=%d", listenAddr, mode, len(cfg.Servers))

// Parse listen address to extract host and port
Expand Down Expand Up @@ -527,17 +568,22 @@ func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, w io.Writer

servers := outputConfig["mcpServers"].(map[string]interface{})

scheme := "http"
if tlsEnabled {
scheme = "https"
}

for name, server := range cfg.Servers {
serverConfig := map[string]interface{}{
"type": "http",
}

var serverURL string
if mode == "routed" {
serverURL = fmt.Sprintf("http://%s:%s/mcp/%s", domain, port, name)
serverURL = fmt.Sprintf("%s://%s:%s/mcp/%s", scheme, domain, port, name)
} else {
// Unified mode - all servers use /mcp endpoint
serverURL = fmt.Sprintf("http://%s:%s/mcp", domain, port)
serverURL = fmt.Sprintf("%s://%s:%s/mcp", scheme, domain, port)
}
serverConfig["url"] = serverURL
debugLog.Printf("Writing server config: name=%s, url=%s, toolCount=%d", name, serverURL, len(server.Tools))
Expand Down
16 changes: 8 additions & 8 deletions internal/cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func TestWriteGatewayConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
require.NoError(t, err)

// Parse JSON output
Expand Down Expand Up @@ -329,7 +329,7 @@ func TestWriteGatewayConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "localhost:8080", "routed", &buf)
err := writeGatewayConfig(cfg, "localhost:8080", "routed", false, &buf)
require.NoError(t, err)

// Parse JSON output
Expand Down Expand Up @@ -371,7 +371,7 @@ func TestWriteGatewayConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
require.NoError(t, err)

// Parse JSON output
Expand Down Expand Up @@ -408,7 +408,7 @@ func TestWriteGatewayConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "[::1]:3000", "unified", &buf)
err := writeGatewayConfig(cfg, "[::1]:3000", "unified", false, &buf)
require.NoError(t, err)

// Parse JSON output
Expand Down Expand Up @@ -436,7 +436,7 @@ func TestWriteGatewayConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "invalid-address", "unified", &buf)
err := writeGatewayConfig(cfg, "invalid-address", "unified", false, &buf)
require.NoError(t, err)

output := buf.String()
Expand Down Expand Up @@ -560,7 +560,7 @@ func TestWriteGatewayConfig_WildcardAddresses(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, tt.listenAddr, tt.mode, &buf)
err := writeGatewayConfig(cfg, tt.listenAddr, tt.mode, false, &buf)
require.NoError(t, err)

var result map[string]interface{}
Expand Down Expand Up @@ -597,7 +597,7 @@ func TestWriteGatewayConfig_EmptyServerList(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", &buf)
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, &buf)
require.NoError(t, err)

var result map[string]interface{}
Expand All @@ -623,7 +623,7 @@ func TestWriteGatewayConfig_FileSync(t *testing.T) {
require.NoError(t, err)
defer tmpFile.Close()

err = writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", tmpFile)
err = writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, tmpFile)
require.NoError(t, err)

// Re-read and verify the file was written correctly
Expand Down
52 changes: 46 additions & 6 deletions internal/cmd/stdout_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestWriteGatewayConfigToStdout(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := writeGatewayConfig(tt.cfg, tt.listenAddr, tt.mode, &buf)
err := writeGatewayConfig(tt.cfg, tt.listenAddr, tt.mode, false, &buf)
require.NoError(t, err)

var result map[string]interface{}
Expand Down Expand Up @@ -200,7 +200,7 @@ func TestWriteGatewayConfigToStdout_EmptyConfig(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", &buf)
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", false, &buf)
require.NoError(t, err)

var result map[string]interface{}
Expand All @@ -221,7 +221,7 @@ func TestWriteGatewayConfigToStdout_JSONFormat(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, "localhost:3000", "routed", &buf)
err := writeGatewayConfig(cfg, "localhost:3000", "routed", false, &buf)
require.NoError(t, err)

// Verify it's valid JSON
Expand Down Expand Up @@ -251,7 +251,7 @@ func TestWriteGatewayConfigToStdout_WithPipe(t *testing.T) {
// Write configuration to pipe in a goroutine
errCh := make(chan error, 1)
go func() {
writeErr := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", w)
writeErr := writeGatewayConfig(cfg, "127.0.0.1:3000", "unified", false, w)
w.Close() // Close writer to signal EOF
errCh <- writeErr
}()
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestWriteGatewayConfig_WriteError(t *testing.T) {
},
}

err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", errWriter{})
err := writeGatewayConfig(cfg, "127.0.0.1:8080", "routed", false, errWriter{})
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to encode configuration")
}
Expand All @@ -298,7 +298,7 @@ func TestWriteGatewayConfig_PortOnlyAddress(t *testing.T) {
}

var buf bytes.Buffer
err := writeGatewayConfig(cfg, ":8080", "unified", &buf)
err := writeGatewayConfig(cfg, ":8080", "unified", false, &buf)
require.NoError(t, err)

var result map[string]interface{}
Expand All @@ -317,3 +317,43 @@ func TestWriteGatewayConfig_PortOnlyAddress(t *testing.T) {
assert.Contains(t, url, DefaultListenIPv4, "Should use default IPv4 host when address has no host")
assert.Contains(t, url, "8080", "Should preserve the port")
}

// TestWriteGatewayConfig_TLSScheme verifies that https:// URLs are emitted when
// tlsEnabled=true and http:// URLs are emitted otherwise.
func TestWriteGatewayConfig_TLSScheme(t *testing.T) {
cfg := &config.Config{
Servers: map[string]*config.ServerConfig{
"github": {Command: "echo"},
},
}

tests := []struct {
name string
tlsEnabled bool
wantScheme string
}{
{"plain HTTP", false, "http://"},
{"HTTPS (TLS)", true, "https://"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := writeGatewayConfig(cfg, "127.0.0.1:3000", "routed", tt.tlsEnabled, &buf)
require.NoError(t, err)

var result map[string]interface{}
require.NoError(t, json.Unmarshal(buf.Bytes(), &result))

mcpServers := result["mcpServers"].(map[string]interface{})
serverConfig := mcpServers["github"].(map[string]interface{})
url := serverConfig["url"].(string)

assert.True(t, len(url) > 0, "URL should not be empty")
assert.True(t,
(tt.wantScheme == "https://" && url[:8] == "https://") ||
(tt.wantScheme == "http://" && url[:7] == "http://"),
"URL %q should start with %s", url, tt.wantScheme)
})
}
}
2 changes: 1 addition & 1 deletion internal/server/allowed_tools_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func TestAllowedTools_RoutedMode_ToolsListFiltered(t *testing.T) {
defer us.Close()

// Create routed HTTP server and verify the filtered server only exposes allowed tools.
httpSrv := CreateHTTPServerForRoutedMode("127.0.0.1:0", us, "") // no API key for test
httpSrv := CreateHTTPServerForRoutedMode("127.0.0.1:0", us, "", "") // no API key for test
ts := httptest.NewServer(httpSrv.Handler)
defer ts.Close()

Expand Down
59 changes: 59 additions & 0 deletions internal/server/gateway_tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package server

import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"

"github.com/github/gh-aw-mcpg/internal/logger"
)

var logGatewayTLS = logger.New("server:tls")

// LoadGatewayTLS loads a TLS configuration for the gateway HTTP server from PEM
// certificate and key files. When caPath is non-empty the returned config
// requires client certificates signed by that CA (mutual TLS / mTLS).
//
// Pass an empty caPath to use one-way TLS (server-only authentication).
//
// Example — one-way TLS (server cert only):
//
// tlsCfg, err := LoadGatewayTLS("/path/server.crt", "/path/server.key", "")
//
// Example — mutual TLS (client certs required):
//
// tlsCfg, err := LoadGatewayTLS("/path/server.crt", "/path/server.key", "/path/ca.crt")
func LoadGatewayTLS(certPath, keyPath, caPath string) (*tls.Config, error) {
logGatewayTLS.Printf("loading gateway TLS: cert=%s, key=%s, ca=%s", certPath, keyPath, caPath)

serverCert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("failed to load server TLS certificate/key: %w", err)
}
logGatewayTLS.Print("server TLS key pair loaded")

cfg := &tls.Config{
Certificates: []tls.Certificate{serverCert},
MinVersion: tls.VersionTLS12,
}

if caPath != "" {
caPEM, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
}

caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caPEM) {
return nil, fmt.Errorf("failed to parse CA certificate from %s", caPath)
}

// Require and verify client certificates signed by the provided CA.
cfg.ClientCAs = caPool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
logGatewayTLS.Printf("mTLS enabled: client certificates required, CA=%s", caPath)
}

return cfg, nil
}
Loading
Loading