Skip to content

Commit ea61b05

Browse files
committed
Implement tracing support for MCP transport
- Added enableTracing flag to configuration for raw JSON-RPC message tracing. - Introduced TracingTransport to log outgoing requests and incoming responses. - Updated MCPProxyServer and Client to utilize tracing functionality. - Enhanced connection handling to respect OAuth state during connection attempts. - Improved logging for connection status and OAuth token management.
1 parent e4c267e commit ea61b05

11 files changed

Lines changed: 1365 additions & 217 deletions

File tree

cmd/mcpproxy/main.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ var (
3333
allowServerAdd bool
3434
allowServerRemove bool
3535
enablePrompts bool
36+
enableTracing bool
3637

3738
version = "v0.1.0" // This will be injected by -ldflags during build
3839
)
@@ -67,6 +68,7 @@ func main() {
6768
rootCmd.PersistentFlags().BoolVar(&allowServerAdd, "allow-server-add", true, "Allow adding new servers")
6869
rootCmd.PersistentFlags().BoolVar(&allowServerRemove, "allow-server-remove", true, "Allow removing existing servers")
6970
rootCmd.PersistentFlags().BoolVar(&enablePrompts, "enable-prompts", true, "Enable prompts for user input")
71+
rootCmd.PersistentFlags().BoolVar(&enableTracing, "enable-tracing", false, "Enable raw JSON-RPC message tracing")
7072

7173
if err := rootCmd.Execute(); err != nil {
7274
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
@@ -151,14 +153,21 @@ func runServer(cmd *cobra.Command, _ []string) error {
151153
cfg.AllowServerAdd = allowServerAdd
152154
cfg.AllowServerRemove = allowServerRemove
153155
cfg.EnablePrompts = enablePrompts
156+
cfg.EnableTracing = enableTracing
157+
158+
// Also check environment variable for tracing
159+
if os.Getenv("MCP_TRACE") != "" {
160+
cfg.EnableTracing = true
161+
}
154162

155163
logger.Info("Configuration loaded",
156164
zap.String("data_dir", cfg.DataDir),
157165
zap.Int("servers_count", len(cfg.Servers)),
158166
zap.Bool("tray_enabled", cfg.EnableTray),
159167
zap.Bool("read_only_mode", cfg.ReadOnlyMode),
160168
zap.Bool("disable_management", cfg.DisableManagement),
161-
zap.Bool("enable_prompts", cfg.EnablePrompts))
169+
zap.Bool("enable_prompts", cfg.EnablePrompts),
170+
zap.Bool("enable_tracing", cfg.EnableTracing))
162171

163172
// Create server
164173
srv, err := server.NewServer(cfg, logger)

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ type Config struct {
103103
// Prompts settings
104104
EnablePrompts bool `json:"enable_prompts" mapstructure:"enable-prompts"`
105105

106+
// Tracing settings
107+
EnableTracing bool `json:"enable_tracing" mapstructure:"enable-tracing"`
108+
106109
// Deployment configuration
107110
PublicURL string `json:"public_url,omitempty" mapstructure:"public-url"` // For remote deployments
108111
DeploymentType string `json:"deployment_type,omitempty" mapstructure:"deployment-type"` // "local", "remote", "headless", "auto"

internal/server/mcp.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,60 @@ func NewMCPProxyServer(
5555
debugSearch bool,
5656
config *config.Config,
5757
) *MCPProxyServer {
58-
// Create MCP server with capabilities
58+
// Create tracing hooks for raw JSON-RPC message logging
59+
hooks := &mcpserver.Hooks{}
60+
61+
// Add hook for incoming requests (before processing)
62+
hooks.AddBeforeAny(func(ctx context.Context, id any, method mcp.MCPMethod, message any) {
63+
if config.EnableTracing {
64+
messageBytes, err := json.Marshal(message)
65+
if err != nil {
66+
logger.Error("Failed to marshal incoming message for tracing",
67+
zap.Error(err),
68+
zap.String("method", string(method)),
69+
zap.Any("id", id))
70+
return
71+
}
72+
logger.Info("🔍 MCP Request",
73+
zap.String("method", string(method)),
74+
zap.Any("id", id),
75+
zap.String("message", string(messageBytes)))
76+
}
77+
})
78+
79+
// Add hook for successful responses (after processing)
80+
hooks.AddOnSuccess(func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
81+
if config.EnableTracing {
82+
resultBytes, err := json.Marshal(result)
83+
if err != nil {
84+
logger.Error("Failed to marshal response for tracing",
85+
zap.Error(err),
86+
zap.String("method", string(method)),
87+
zap.Any("id", id))
88+
return
89+
}
90+
logger.Info("✅ MCP Response",
91+
zap.String("method", string(method)),
92+
zap.Any("id", id),
93+
zap.String("result", string(resultBytes)))
94+
}
95+
})
96+
97+
// Add hook for errors
98+
hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
99+
if config.EnableTracing {
100+
logger.Error("❌ MCP Error",
101+
zap.String("method", string(method)),
102+
zap.Any("id", id),
103+
zap.Error(err))
104+
}
105+
})
106+
107+
// Create MCP server with capabilities and hooks
59108
capabilities := []mcpserver.ServerOption{
60109
mcpserver.WithToolCapabilities(true),
61110
mcpserver.WithRecovery(),
111+
mcpserver.WithHooks(hooks), // Add tracing hooks
62112
}
63113

64114
// Add prompts capability if enabled

internal/server/mcp_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ func TestHandleCallToolErrorRecovery(t *testing.T) {
545545
// This test verifies the core issue mentioned in the error logs
546546

547547
mockProxy := &MCPProxyServer{
548-
upstreamManager: upstream.NewManager(zap.NewNop(), config.DefaultConfig()),
548+
upstreamManager: upstream.NewManager(zap.NewNop(), config.DefaultConfig(), nil),
549549
logger: zap.NewNop(),
550550
}
551551

@@ -587,7 +587,7 @@ func TestHandleCallToolCompleteErrorHandling(t *testing.T) {
587587
// Test comprehensive error handling scenarios including self-referential calls
588588

589589
mockProxy := &MCPProxyServer{
590-
upstreamManager: upstream.NewManager(zap.NewNop(), config.DefaultConfig()),
590+
upstreamManager: upstream.NewManager(zap.NewNop(), config.DefaultConfig(), nil),
591591
logger: zap.NewNop(),
592592
config: &config.Config{}, // Add minimal config for testing
593593
}

internal/server/server.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func NewServer(cfg *config.Config, logger *zap.Logger) (*Server, error) {
7474
}
7575

7676
// Initialize upstream manager
77-
upstreamManager := upstream.NewManager(logger, cfg)
77+
upstreamManager := upstream.NewManager(logger, cfg, storageManager)
7878

7979
// Set logging configuration on upstream manager for per-server logging
8080
if cfg.Logging != nil {
@@ -261,6 +261,7 @@ func (s *Server) connectAllWithRetry(ctx context.Context) {
261261
stats := s.upstreamManager.GetStats()
262262
connectedCount := 0
263263
totalCount := 0
264+
oauthPendingCount := 0
264265

265266
if serverStats, ok := stats["servers"].(map[string]interface{}); ok {
266267
totalCount = len(serverStats)
@@ -269,37 +270,48 @@ func (s *Server) connectAllWithRetry(ctx context.Context) {
269270
if connected, ok := stat["connected"].(bool); ok && connected {
270271
connectedCount++
271272
}
273+
// Count OAuth pending servers separately - they shouldn't trigger retries
274+
if oauthPending, ok := stat["oauth_pending"].(bool); ok && oauthPending {
275+
oauthPendingCount++
276+
}
272277
}
273278
}
274279
}
275280

276281
s.logger.Debug("Connection status",
277282
zap.Int("connected_count", connectedCount),
283+
zap.Int("oauth_pending_count", oauthPendingCount),
278284
zap.Int("total_count", totalCount))
279285

280-
if connectedCount < totalCount {
286+
// Only attempt connections if there are servers that can actually connect
287+
// Don't retry if all remaining servers are OAuth pending
288+
serversNeedingConnection := totalCount - connectedCount - oauthPendingCount
289+
290+
if serversNeedingConnection > 0 {
281291
// Only update status to "Connecting" if server is not running
282292
// If server is running, don't override the "Running" status
283293
s.mu.RLock()
284294
isRunning := s.running
285295
s.mu.RUnlock()
286296

287297
if !isRunning {
288-
s.updateStatus("Connecting", fmt.Sprintf("Connected to %d/%d servers, retrying...", connectedCount, totalCount))
298+
s.updateStatus("Connecting", fmt.Sprintf("Connected to %d/%d servers (%d OAuth pending), retrying...",
299+
connectedCount, totalCount, oauthPendingCount))
289300
}
290301

291-
// ISSUE: This hardcoded 10s timeout overrides custom server timeouts!
292-
// Try to connect with timeout
293-
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
302+
// Use a reasonable timeout but don't override custom server timeouts completely
303+
// The individual clients will still respect their own timeouts
304+
connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
294305
defer cancel()
295306

296-
s.logger.Debug("Calling upstreamManager.ConnectAll with hardcoded 10s timeout - this overrides custom server timeouts!")
307+
s.logger.Debug("Calling upstreamManager.ConnectAll",
308+
zap.Int("servers_needing_connection", serversNeedingConnection))
297309

298310
if err := s.upstreamManager.ConnectAll(connectCtx); err != nil {
299311
s.logger.Warn("Some upstream servers failed to connect", zap.Error(err))
300312
}
301313
} else {
302-
s.logger.Debug("All servers already connected, skipping connection attempt")
314+
s.logger.Debug("All servers already connected or OAuth pending, skipping connection attempt")
303315
}
304316
}
305317

internal/server/upstream_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestUpstreamServersHandlerPerformance(t *testing.T) {
4545
defer indexManager.Close()
4646

4747
// Create upstream manager
48-
upstreamManager := upstream.NewManager(zap.NewNop(), cfg)
48+
upstreamManager := upstream.NewManager(zap.NewNop(), cfg, storageManager)
4949

5050
// Create cache manager
5151
cacheManager, err := cache.NewManager(storageManager.GetDB(), zap.NewNop())
@@ -144,7 +144,7 @@ func TestUpstreamServersListOperation(t *testing.T) {
144144
defer indexManager.Close()
145145

146146
// Create upstream manager
147-
upstreamManager := upstream.NewManager(zap.NewNop(), cfg)
147+
upstreamManager := upstream.NewManager(zap.NewNop(), cfg, storageManager)
148148

149149
// Create cache manager
150150
cacheManager, err := cache.NewManager(storageManager.GetDB(), zap.NewNop())

internal/storage/manager.go

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,73 @@ func (m *Manager) SaveUpstreamServer(serverConfig *config.ServerConfig) error {
7676
Updated: time.Now(),
7777
}
7878

79+
// Include OAuth tokens if present
80+
if serverConfig.OAuth != nil && serverConfig.OAuth.TokenStorage != nil {
81+
record.OAuthTokens = &OAuthTokenRecord{
82+
AccessToken: serverConfig.OAuth.TokenStorage.AccessToken,
83+
RefreshToken: serverConfig.OAuth.TokenStorage.RefreshToken,
84+
ExpiresAt: serverConfig.OAuth.TokenStorage.ExpiresAt,
85+
TokenType: serverConfig.OAuth.TokenStorage.TokenType,
86+
Scope: serverConfig.OAuth.TokenStorage.Scope,
87+
Updated: time.Now(),
88+
}
89+
}
90+
7991
return m.db.SaveUpstream(record)
8092
}
8193

94+
// SaveOAuthTokens saves OAuth tokens for a specific upstream server
95+
func (m *Manager) SaveOAuthTokens(serverName string, tokens *config.TokenStorage) error {
96+
m.mu.Lock()
97+
defer m.mu.Unlock()
98+
99+
// Get existing upstream record
100+
record, err := m.db.GetUpstream(serverName)
101+
if err != nil {
102+
return fmt.Errorf("failed to get upstream server: %w", err)
103+
}
104+
105+
// Update OAuth tokens
106+
if tokens != nil {
107+
record.OAuthTokens = &OAuthTokenRecord{
108+
AccessToken: tokens.AccessToken,
109+
RefreshToken: tokens.RefreshToken,
110+
ExpiresAt: tokens.ExpiresAt,
111+
TokenType: tokens.TokenType,
112+
Scope: tokens.Scope,
113+
Updated: time.Now(),
114+
}
115+
} else {
116+
record.OAuthTokens = nil
117+
}
118+
119+
record.Updated = time.Now()
120+
return m.db.SaveUpstream(record)
121+
}
122+
123+
// LoadOAuthTokens loads OAuth tokens for a specific upstream server
124+
func (m *Manager) LoadOAuthTokens(serverName string) (*config.TokenStorage, error) {
125+
m.mu.RLock()
126+
defer m.mu.RUnlock()
127+
128+
record, err := m.db.GetUpstream(serverName)
129+
if err != nil {
130+
return nil, fmt.Errorf("failed to get upstream server: %w", err)
131+
}
132+
133+
if record.OAuthTokens == nil {
134+
return nil, nil // No tokens stored
135+
}
136+
137+
return &config.TokenStorage{
138+
AccessToken: record.OAuthTokens.AccessToken,
139+
RefreshToken: record.OAuthTokens.RefreshToken,
140+
ExpiresAt: record.OAuthTokens.ExpiresAt,
141+
TokenType: record.OAuthTokens.TokenType,
142+
Scope: record.OAuthTokens.Scope,
143+
}, nil
144+
}
145+
82146
// GetUpstreamServer retrieves an upstream server by name
83147
func (m *Manager) GetUpstreamServer(name string) (*config.ServerConfig, error) {
84148
m.mu.RLock()
@@ -89,7 +153,7 @@ func (m *Manager) GetUpstreamServer(name string) (*config.ServerConfig, error) {
89153
return nil, err
90154
}
91155

92-
return &config.ServerConfig{
156+
serverConfig := &config.ServerConfig{
93157
Name: record.Name,
94158
URL: record.URL,
95159
Protocol: record.Protocol,
@@ -101,7 +165,23 @@ func (m *Manager) GetUpstreamServer(name string) (*config.ServerConfig, error) {
101165
Quarantined: record.Quarantined,
102166
Created: record.Created,
103167
Updated: record.Updated,
104-
}, nil
168+
}
169+
170+
// Load OAuth tokens if present
171+
if record.OAuthTokens != nil {
172+
if serverConfig.OAuth == nil {
173+
serverConfig.OAuth = &config.OAuthConfig{}
174+
}
175+
serverConfig.OAuth.TokenStorage = &config.TokenStorage{
176+
AccessToken: record.OAuthTokens.AccessToken,
177+
RefreshToken: record.OAuthTokens.RefreshToken,
178+
ExpiresAt: record.OAuthTokens.ExpiresAt,
179+
TokenType: record.OAuthTokens.TokenType,
180+
Scope: record.OAuthTokens.Scope,
181+
}
182+
}
183+
184+
return serverConfig, nil
105185
}
106186

107187
// ListUpstreamServers returns all upstream servers

internal/storage/models.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ type UpstreamRecord struct {
3737
Quarantined bool `json:"quarantined"` // Security quarantine status
3838
Created time.Time `json:"created"`
3939
Updated time.Time `json:"updated"`
40+
41+
// OAuth token storage for persistence across restarts
42+
OAuthTokens *OAuthTokenRecord `json:"oauth_tokens,omitempty"`
43+
}
44+
45+
// OAuthTokenRecord represents stored OAuth tokens for an upstream server
46+
type OAuthTokenRecord struct {
47+
AccessToken string `json:"access_token,omitempty"`
48+
RefreshToken string `json:"refresh_token,omitempty"`
49+
ExpiresAt time.Time `json:"expires_at,omitempty"`
50+
TokenType string `json:"token_type,omitempty"`
51+
Scope string `json:"scope,omitempty"`
52+
Updated time.Time `json:"updated"`
4053
}
4154

4255
// ToolStatRecord represents tool usage statistics

0 commit comments

Comments
 (0)