Skip to content

Commit c89d4bf

Browse files
committed
Enhance Client and Manager for upstream connection management
- Added upstreamManager to Client for better callback handling. - Updated NewClient function to accept upstreamManager as a parameter. - Implemented ForceReconnect method in Manager to manage client reconnections. - Refactored OAuth flow to utilize upstreamManager for connection state management.
1 parent a76c5e4 commit c89d4bf

2 files changed

Lines changed: 49 additions & 23 deletions

File tree

internal/upstream/client.go

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ type Client struct {
9797
// Global configuration for accessing tracing settings
9898
globalConfig *config.Config
9999

100+
// Upstream manager for callbacks
101+
upstreamManager *Manager
102+
100103
// Connection state (protected by mutex)
101104
mu sync.RWMutex
102105
connected bool
@@ -144,12 +147,13 @@ type Tool struct {
144147
}
145148

146149
// NewClient creates a new MCP client for connecting to an upstream server
147-
func NewClient(id string, serverConfig *config.ServerConfig, logger *zap.Logger, logConfig *config.LogConfig, globalConfig *config.Config, storageManager *storage.Manager) (*Client, error) {
150+
func NewClient(id string, serverConfig *config.ServerConfig, logger *zap.Logger, logConfig *config.LogConfig, globalConfig *config.Config, storageManager *storage.Manager, upstreamManager *Manager) (*Client, error) {
148151
c := &Client{
149-
id: id,
150-
config: serverConfig,
151-
storageManager: storageManager,
152-
globalConfig: globalConfig,
152+
id: id,
153+
config: serverConfig,
154+
storageManager: storageManager,
155+
globalConfig: globalConfig,
156+
upstreamManager: upstreamManager,
153157
logger: logger.With(
154158
zap.String("upstream_id", id),
155159
zap.String("upstream_name", serverConfig.Name),
@@ -702,6 +706,10 @@ func (c *Client) createAutoOAuthConfig() *config.OAuthConfig {
702706

703707
// generateAuthenticationURL generates a direct authentication URL for local deployments
704708
func (c *Client) generateAuthenticationURL() string {
709+
return c.generateAuthenticationURLWithPKCE(nil)
710+
}
711+
712+
func (c *Client) generateAuthenticationURLWithPKCE(pkceParams *PKCEParams) string {
705713
if c.config.OAuth == nil {
706714
return ""
707715
}
@@ -796,25 +804,12 @@ func (c *Client) triggerOAuthFlowAsync(ctx context.Context) {
796804
}
797805

798806
c.logger.Info("🔐 OAUTH_TOKEN Automatic OAuth flow completed successfully")
807+
c.resetListToolsCircuitBreaker()
799808

800-
// Clear OAuth pending state before attempting to reconnect
801-
c.setOAuthPending(false, nil)
802-
803-
// Force reconnection after successful OAuth - bypass shouldAttemptConnection()
804-
// since we know OAuth just completed and we need to establish a fresh connection
809+
// After successful OAuth, force a reconnection to refresh the connection state
805810
c.logger.Info("🔐 OAUTH_TOKEN Forcing reconnection after successful OAuth")
806-
// Use a fresh context for reconnection to avoid any cancellation issues
807-
reconnectCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
808-
defer cancel()
809-
810-
if err := c.Connect(reconnectCtx); err != nil {
811-
c.logger.Error("🔐 OAUTH_TOKEN Failed to reconnect after automatic OAuth", zap.Error(err))
812-
c.mu.Lock()
813-
c.oauthError = err
814-
c.mu.Unlock()
815-
} else {
816-
c.logger.Info("🔐 OAUTH_TOKEN Successfully reconnected after OAuth flow")
817-
}
811+
c.upstreamManager.ForceReconnect(ctx, c.id) // Use the manager to force reconnect
812+
c.logger.Info("🔐 OAUTH_TOKEN Successfully reconnected after OAuth flow")
818813
}
819814

820815
// shouldUseLazyAuth returns true if lazy OAuth should be used
@@ -3632,3 +3627,25 @@ func (c *Client) getListToolsResult() ([]*config.ToolMetadata, error) {
36323627
defer c.mu.RUnlock()
36333628
return c.listToolsResult, c.listToolsError
36343629
}
3630+
3631+
func (c *Client) resetListToolsCircuitBreaker() {
3632+
c.mu.Lock()
3633+
defer c.mu.Unlock()
3634+
c.listToolsFailureCount = 0
3635+
c.listToolsCircuitOpen = false
3636+
c.listToolsLastFailure = time.Time{}
3637+
c.logger.Debug("ListTools circuit breaker has been reset")
3638+
}
3639+
3640+
func (c *Client) forceReconnect(ctx context.Context) {
3641+
c.logger.Info("Client force reconnect called")
3642+
_ = c.Disconnect()
3643+
go func() {
3644+
// Create a new context for the connection attempt
3645+
connectCtx, cancel := context.WithTimeout(context.Background(), c.getConnectionTimeout())
3646+
defer cancel()
3647+
if err := c.Connect(connectCtx); err != nil {
3648+
c.logger.Error("Failed to reconnect", zap.Error(err))
3649+
}
3650+
}()
3651+
}

internal/upstream/manager.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (m *Manager) AddServerConfig(id string, serverConfig *config.ServerConfig)
5858
}
5959

6060
// Create new client but don't connect yet
61-
client, err := NewClient(id, serverConfig, m.logger, m.logConfig, m.globalConfig, m.storageManager)
61+
client, err := NewClient(id, serverConfig, m.logger, m.logConfig, m.globalConfig, m.storageManager, m)
6262
if err != nil {
6363
return fmt.Errorf("failed to create client for server %s: %w", serverConfig.Name, err)
6464
}
@@ -94,6 +94,15 @@ func (m *Manager) AddServer(id string, serverConfig *config.ServerConfig) error
9494
return nil
9595
}
9696

97+
func (m *Manager) ForceReconnect(ctx context.Context, clientID string) {
98+
m.logger.Info("Manager forcing reconnect for client", zap.String("client_id", clientID))
99+
if client, exists := m.GetClient(clientID); exists {
100+
client.forceReconnect(ctx)
101+
} else {
102+
m.logger.Warn("Failed to force reconnect: client not found", zap.String("client_id", clientID))
103+
}
104+
}
105+
97106
// RemoveServer removes an upstream server
98107
func (m *Manager) RemoveServer(id string) {
99108
m.mu.Lock()

0 commit comments

Comments
 (0)