@@ -37,7 +37,6 @@ type Bridge struct {
3737 messageSizeLimit int64
3838 writeTimeout time.Duration
3939 pingInterval time.Duration
40- randSource * rand.Rand
4140}
4241
4342// New creates a new Bridge with optional configurations.
@@ -57,7 +56,6 @@ func New(oauthClient OAuthClient, opts ...Option) *Bridge {
5756 messageSizeLimit : 65536 , // 64KB
5857 writeTimeout : 10 * time .Second ,
5958 pingInterval : 30 * time .Second ,
60- randSource : rand .New (rand .NewSource (time .Now ().UnixNano ())),
6159 }
6260
6361 // Apply all the functional options provided by the user
@@ -85,9 +83,7 @@ func NewStandard(oauthClient OAuthClient, agentLabels map[string]string, opts ..
8583// MaintainWebSocket is the main entry point. It runs a loop that attempts
8684// to establish and manage a connection, with a backoff policy for retries.
8785func (b * Bridge ) MaintainWebSocket (ctx context.Context , connectionID string , endpointURL string , handler Handler ) error {
88- attempt := 0
8986 for {
90- start := time .Now ()
9187 err := b .manageConnection (ctx , connectionID , endpointURL , handler )
9288 if err != nil {
9389 var permanentErr * PermanentError
@@ -99,100 +95,96 @@ func (b *Bridge) MaintainWebSocket(ctx context.Context, connectionID string, end
9995 b .logger .Error (err , "Connection manager exited with recoverable error" , "connectionID" , connectionID )
10096 }
10197
102- // Reset attempt counter if the connection was stable for a while (e.g., 1 minute)
103- if time .Since (start ) > 1 * time .Minute {
104- attempt = 0
105- }
106-
10798 select {
10899 case <- ctx .Done ():
109100 b .logger .Info ("Context cancelled; shutting down bridge" , "connectionID" , connectionID )
110101 b .metrics .SetConnectionStatus (0 )
111102 return ctx .Err ()
112103 default :
113104 // Connection dropped for a recoverable reason, wait and retry.
114- backoff := b .calculateBackoff (attempt )
115- attempt ++
116- b .logger .Info ("Reconnecting" , "connectionID" , connectionID , "after" , backoff , "attempt" , attempt )
105+ backoff := b .calculateBackoff ()
106+ b .logger .Info ("Reconnecting" , "connectionID" , connectionID , "after" , backoff )
117107 time .Sleep (backoff )
118108 }
119109 }
120110}
121111
122- // MaintainGRPCConnection manages a persistent gRPC connection.
123- // It handles authentication, dialing, and reconnection.
124- // The 'run' function is called with the established ClientConn.
112+ // MaintainGRPCConnection manages a persistent gRPC connection with exponential
113+ // backoff and context-aware retry. The run callback receives each established
114+ // connection; its return value determines whether to retry, stop, or exit cleanly.
115+ //
116+ // Terminal conditions (no retry):
117+ // - run returns nil (clean exit)
118+ // - run returns ErrInteractionRequired (user must re-authenticate)
119+ // - run returns a *PermanentError
120+ // - context is cancelled
125121func (b * Bridge ) MaintainGRPCConnection (
126122 ctx context.Context ,
127123 connectionID string ,
128124 target string ,
129125 run func (ctx context.Context , conn * grpc.ClientConn ) error ,
130126 opts ... grpc.DialOption ,
131127) error {
128+ backoff := b .retryPolicy .MinBackoff
132129 attempt := 0
130+
133131 for {
134- start := time .Now ()
135- // 1. Prepare Credentials
136- // We use our custom PerRPCCredentials implementation
137- creds := NewBridgeCredentials (b .oauthClient , connectionID , b .refreshBuffer , b .logger )
132+ if attempt > 0 {
133+ wait := b .applyJitter (backoff )
134+ b .logger .Info ("Reconnecting gRPC" , "target" , target , "attempt" , attempt , "after" , wait )
135+ select {
136+ case <- ctx .Done ():
137+ b .logger .Info ("Context cancelled during backoff; stopping gRPC bridge" , "connectionID" , connectionID )
138+ return ctx .Err ()
139+ case <- time .After (wait ):
140+ }
141+ }
142+ attempt ++
138143
139- // 2. Dial Options
144+ creds := NewBridgeCredentials ( b . oauthClient , connectionID , b . refreshBuffer , b . logger )
140145 dialOpts := append (opts , grpc .WithPerRPCCredentials (creds ))
141146
142- // 3. Dial
143- b .logger .Info ("Dialing gRPC target" , "target" , target )
147+ b .logger .Info ("Dialing gRPC target" , "target" , target , "attempt" , attempt )
144148 conn , err := grpc .NewClient (target , dialOpts ... )
145149 if err != nil {
146- b .logger .Error (err , "Failed to dial gRPC target" , "target" , target )
147- goto Retry
150+ b .logger .Error (err , "Failed to dial gRPC target" , "target" , target , "attempt" , attempt )
151+ backoff = b .growBackoff (backoff )
152+ continue
148153 }
149154
150155 b .metrics .IncConnections ()
151156 b .metrics .SetConnectionStatus (1 )
152157 b .logger .Info ("gRPC connection established" , "target" , target )
153158
154- // 4. Run User Logic
155159 err = run (ctx , conn )
156-
157- // Cleanup
160+
158161 conn .Close ()
159162 b .metrics .SetConnectionStatus (0 )
160163 b .metrics .IncDisconnects ()
161164
162- // 5. Handle Error
163- if err != nil {
164- // Check if permanent
165- var permanentErr * PermanentError
166- if errors .As (err , & permanentErr ) {
167- b .logger .Error (err , "Permanent error in gRPC run loop; stopping" , "connectionID" , connectionID )
168- return err
169- }
170- // Check if Context Done
171- if errors .Is (err , ctx .Err ()) {
172- b .logger .Info ("Context cancelled; shutting down gRPC bridge" )
173- return err
174- }
175-
176- b .logger .Error (err , "gRPC run loop exited with error" , "connectionID" , connectionID )
177- } else {
165+ if err == nil {
178166 b .logger .Info ("gRPC run loop exited cleanly" , "connectionID" , connectionID )
167+ return nil
179168 }
180169
181- // Reset attempt counter if the connection was stable for a while
182- if time . Since ( start ) > 1 * time . Minute {
183- attempt = 0
170+ if errors . Is ( err , ErrInteractionRequired ) {
171+ b . logger . Error ( err , "Interaction required; stopping gRPC retry" , "connectionID" , connectionID )
172+ return err
184173 }
185174
186- Retry:
187- select {
188- case <- ctx .Done ():
175+ var permanentErr * PermanentError
176+ if errors .As (err , & permanentErr ) {
177+ b .logger .Error (err , "Permanent error in gRPC run loop; stopping" , "connectionID" , connectionID )
178+ return err
179+ }
180+
181+ if ctx .Err () != nil {
182+ b .logger .Info ("Context cancelled; shutting down gRPC bridge" , "connectionID" , connectionID )
189183 return ctx .Err ()
190- default :
191- backoff := b .calculateBackoff (attempt )
192- attempt ++
193- b .logger .Info ("Reconnecting gRPC" , "after" , backoff , "attempt" , attempt )
194- time .Sleep (backoff )
195184 }
185+
186+ b .logger .Error (err , "gRPC run loop exited with error; will retry" , "connectionID" , connectionID , "attempt" , attempt )
187+ backoff = b .growBackoff (backoff )
196188 }
197189}
198190
@@ -381,19 +373,29 @@ func (b *Bridge) manageConnection(ctx context.Context, connectionID string, endp
381373 }
382374}
383375
384- // NEW: Helper function for calculating backoff with jitter.
385- func (b * Bridge ) calculateBackoff (attempt int ) time.Duration {
386- if attempt < 0 {
387- attempt = 0
376+ // growBackoff doubles the current backoff, capping at MaxBackoff.
377+ func (b * Bridge ) growBackoff (current time.Duration ) time.Duration {
378+ next := current * 2
379+ if next > b .retryPolicy .MaxBackoff || next <= 0 {
380+ return b .retryPolicy .MaxBackoff
388381 }
389- if attempt > 10 {
390- attempt = 10
382+ return next
383+ }
384+
385+ // applyJitter adds random jitter to a duration to prevent thundering herd
386+ // when multiple agents reconnect simultaneously after a gateway restart.
387+ func (b * Bridge ) applyJitter (d time.Duration ) time.Duration {
388+ if b .retryPolicy .Jitter <= 0 {
389+ return d
391390 }
392- factor := 1 << uint (attempt )
393- base := float64 (b .retryPolicy .MinBackoff ) * float64 (factor )
394- if base > float64 (b .retryPolicy .MaxBackoff ) {
395- base = float64 (b .retryPolicy .MaxBackoff )
391+ return d + time .Duration (rand .Int63n (int64 (b .retryPolicy .Jitter )))
392+ }
393+
394+ // calculateBackoff returns a flat backoff with jitter (used by MaintainWebSocket).
395+ func (b * Bridge ) calculateBackoff () time.Duration {
396+ backoff := b .retryPolicy .MinBackoff + time .Duration (rand .Int63n (int64 (b .retryPolicy .Jitter )))
397+ if backoff > b .retryPolicy .MaxBackoff {
398+ return b .retryPolicy .MaxBackoff
396399 }
397- jitter := 0.2 + b .randSource .Float64 ()* 0.6 // 0.2..0.8
398- return time .Duration (base * jitter )
400+ return backoff
399401}
0 commit comments