Skip to content

Commit 4f71ec6

Browse files
committed
Merge PR #39: refactor(bridge): replace goto Retry with for-loop in MaintainGRPCConnection
2 parents e7b9d6e + 19bf2b4 commit 4f71ec6

3 files changed

Lines changed: 320 additions & 95 deletions

File tree

nexus-bridge/bridge.go

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ type Bridge struct {
3737
messageSizeLimit int64
3838
writeTimeout time.Duration
3939
pingInterval time.Duration
40-
randSource *rand.Rand
4140
}
4241

4342
// New creates a new Bridge with optional configurations.
@@ -57,7 +56,6 @@ func New(oauthClient OAuthClient, opts ...Option) *Bridge {
5756
messageSizeLimit: 65536, // 64KB
5857
writeTimeout: 10 * time.Second,
5958
pingInterval: 30 * time.Second,
60-
randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
6159
}
6260

6361
// Apply all the functional options provided by the user
@@ -85,9 +83,7 @@ func NewStandard(oauthClient OAuthClient, agentLabels map[string]string, opts ..
8583
// MaintainWebSocket is the main entry point. It runs a loop that attempts
8684
// to establish and manage a connection, with a backoff policy for retries.
8785
func (b *Bridge) MaintainWebSocket(ctx context.Context, connectionID string, endpointURL string, handler Handler) error {
88-
attempt := 0
8986
for {
90-
start := time.Now()
9187
err := b.manageConnection(ctx, connectionID, endpointURL, handler)
9288
if err != nil {
9389
var permanentErr *PermanentError
@@ -99,100 +95,96 @@ func (b *Bridge) MaintainWebSocket(ctx context.Context, connectionID string, end
9995
b.logger.Error(err, "Connection manager exited with recoverable error", "connectionID", connectionID)
10096
}
10197

102-
// Reset attempt counter if the connection was stable for a while (e.g., 1 minute)
103-
if time.Since(start) > 1*time.Minute {
104-
attempt = 0
105-
}
106-
10798
select {
10899
case <-ctx.Done():
109100
b.logger.Info("Context cancelled; shutting down bridge", "connectionID", connectionID)
110101
b.metrics.SetConnectionStatus(0)
111102
return ctx.Err()
112103
default:
113104
// Connection dropped for a recoverable reason, wait and retry.
114-
backoff := b.calculateBackoff(attempt)
115-
attempt++
116-
b.logger.Info("Reconnecting", "connectionID", connectionID, "after", backoff, "attempt", attempt)
105+
backoff := b.calculateBackoff()
106+
b.logger.Info("Reconnecting", "connectionID", connectionID, "after", backoff)
117107
time.Sleep(backoff)
118108
}
119109
}
120110
}
121111

122-
// MaintainGRPCConnection manages a persistent gRPC connection.
123-
// It handles authentication, dialing, and reconnection.
124-
// The 'run' function is called with the established ClientConn.
112+
// MaintainGRPCConnection manages a persistent gRPC connection with exponential
113+
// backoff and context-aware retry. The run callback receives each established
114+
// connection; its return value determines whether to retry, stop, or exit cleanly.
115+
//
116+
// Terminal conditions (no retry):
117+
// - run returns nil (clean exit)
118+
// - run returns ErrInteractionRequired (user must re-authenticate)
119+
// - run returns a *PermanentError
120+
// - context is cancelled
125121
func (b *Bridge) MaintainGRPCConnection(
126122
ctx context.Context,
127123
connectionID string,
128124
target string,
129125
run func(ctx context.Context, conn *grpc.ClientConn) error,
130126
opts ...grpc.DialOption,
131127
) error {
128+
backoff := b.retryPolicy.MinBackoff
132129
attempt := 0
130+
133131
for {
134-
start := time.Now()
135-
// 1. Prepare Credentials
136-
// We use our custom PerRPCCredentials implementation
137-
creds := NewBridgeCredentials(b.oauthClient, connectionID, b.refreshBuffer, b.logger)
132+
if attempt > 0 {
133+
wait := b.applyJitter(backoff)
134+
b.logger.Info("Reconnecting gRPC", "target", target, "attempt", attempt, "after", wait)
135+
select {
136+
case <-ctx.Done():
137+
b.logger.Info("Context cancelled during backoff; stopping gRPC bridge", "connectionID", connectionID)
138+
return ctx.Err()
139+
case <-time.After(wait):
140+
}
141+
}
142+
attempt++
138143

139-
// 2. Dial Options
144+
creds := NewBridgeCredentials(b.oauthClient, connectionID, b.refreshBuffer, b.logger)
140145
dialOpts := append(opts, grpc.WithPerRPCCredentials(creds))
141146

142-
// 3. Dial
143-
b.logger.Info("Dialing gRPC target", "target", target)
147+
b.logger.Info("Dialing gRPC target", "target", target, "attempt", attempt)
144148
conn, err := grpc.NewClient(target, dialOpts...)
145149
if err != nil {
146-
b.logger.Error(err, "Failed to dial gRPC target", "target", target)
147-
goto Retry
150+
b.logger.Error(err, "Failed to dial gRPC target", "target", target, "attempt", attempt)
151+
backoff = b.growBackoff(backoff)
152+
continue
148153
}
149154

150155
b.metrics.IncConnections()
151156
b.metrics.SetConnectionStatus(1)
152157
b.logger.Info("gRPC connection established", "target", target)
153158

154-
// 4. Run User Logic
155159
err = run(ctx, conn)
156-
157-
// Cleanup
160+
158161
conn.Close()
159162
b.metrics.SetConnectionStatus(0)
160163
b.metrics.IncDisconnects()
161164

162-
// 5. Handle Error
163-
if err != nil {
164-
// Check if permanent
165-
var permanentErr *PermanentError
166-
if errors.As(err, &permanentErr) {
167-
b.logger.Error(err, "Permanent error in gRPC run loop; stopping", "connectionID", connectionID)
168-
return err
169-
}
170-
// Check if Context Done
171-
if errors.Is(err, ctx.Err()) {
172-
b.logger.Info("Context cancelled; shutting down gRPC bridge")
173-
return err
174-
}
175-
176-
b.logger.Error(err, "gRPC run loop exited with error", "connectionID", connectionID)
177-
} else {
165+
if err == nil {
178166
b.logger.Info("gRPC run loop exited cleanly", "connectionID", connectionID)
167+
return nil
179168
}
180169

181-
// Reset attempt counter if the connection was stable for a while
182-
if time.Since(start) > 1*time.Minute {
183-
attempt = 0
170+
if errors.Is(err, ErrInteractionRequired) {
171+
b.logger.Error(err, "Interaction required; stopping gRPC retry", "connectionID", connectionID)
172+
return err
184173
}
185174

186-
Retry:
187-
select {
188-
case <-ctx.Done():
175+
var permanentErr *PermanentError
176+
if errors.As(err, &permanentErr) {
177+
b.logger.Error(err, "Permanent error in gRPC run loop; stopping", "connectionID", connectionID)
178+
return err
179+
}
180+
181+
if ctx.Err() != nil {
182+
b.logger.Info("Context cancelled; shutting down gRPC bridge", "connectionID", connectionID)
189183
return ctx.Err()
190-
default:
191-
backoff := b.calculateBackoff(attempt)
192-
attempt++
193-
b.logger.Info("Reconnecting gRPC", "after", backoff, "attempt", attempt)
194-
time.Sleep(backoff)
195184
}
185+
186+
b.logger.Error(err, "gRPC run loop exited with error; will retry", "connectionID", connectionID, "attempt", attempt)
187+
backoff = b.growBackoff(backoff)
196188
}
197189
}
198190

@@ -381,19 +373,29 @@ func (b *Bridge) manageConnection(ctx context.Context, connectionID string, endp
381373
}
382374
}
383375

384-
// NEW: Helper function for calculating backoff with jitter.
385-
func (b *Bridge) calculateBackoff(attempt int) time.Duration {
386-
if attempt < 0 {
387-
attempt = 0
376+
// growBackoff doubles the current backoff, capping at MaxBackoff.
377+
func (b *Bridge) growBackoff(current time.Duration) time.Duration {
378+
next := current * 2
379+
if next > b.retryPolicy.MaxBackoff || next <= 0 {
380+
return b.retryPolicy.MaxBackoff
388381
}
389-
if attempt > 10 {
390-
attempt = 10
382+
return next
383+
}
384+
385+
// applyJitter adds random jitter to a duration to prevent thundering herd
386+
// when multiple agents reconnect simultaneously after a gateway restart.
387+
func (b *Bridge) applyJitter(d time.Duration) time.Duration {
388+
if b.retryPolicy.Jitter <= 0 {
389+
return d
391390
}
392-
factor := 1 << uint(attempt)
393-
base := float64(b.retryPolicy.MinBackoff) * float64(factor)
394-
if base > float64(b.retryPolicy.MaxBackoff) {
395-
base = float64(b.retryPolicy.MaxBackoff)
391+
return d + time.Duration(rand.Int63n(int64(b.retryPolicy.Jitter)))
392+
}
393+
394+
// calculateBackoff returns a flat backoff with jitter (used by MaintainWebSocket).
395+
func (b *Bridge) calculateBackoff() time.Duration {
396+
backoff := b.retryPolicy.MinBackoff + time.Duration(rand.Int63n(int64(b.retryPolicy.Jitter)))
397+
if backoff > b.retryPolicy.MaxBackoff {
398+
return b.retryPolicy.MaxBackoff
396399
}
397-
jitter := 0.2 + b.randSource.Float64()*0.6 // 0.2..0.8
398-
return time.Duration(base * jitter)
400+
return backoff
399401
}

0 commit comments

Comments
 (0)