|
| 1 | +/* |
| 2 | + * Copyright 2024 The CNAI Authors |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package retrypolicy |
| 18 | + |
| 19 | +import ( |
| 20 | + "context" |
| 21 | + "errors" |
| 22 | + "fmt" |
| 23 | + "math" |
| 24 | + "regexp" |
| 25 | + "strings" |
| 26 | + "time" |
| 27 | + |
| 28 | + retry "github.com/avast/retry-go/v4" |
| 29 | + log "github.com/sirupsen/logrus" |
| 30 | +) |
| 31 | + |
// Tuning constants for the size-aware retry policy.
const (
	// oneGB is 1 GiB in bytes; files at or below this size get the minimum
	// retry budget.
	oneGB = 1 << 30
	// tenGB is 10 GiB in bytes; files at or above this size get the maximum
	// retry budget.
	tenGB = 10 * oneGB
	// nineGB is the width of the interpolation range (tenGB - oneGB).
	nineGB = 9 * oneGB

	// minMaxRetryTime is the total retry budget for the smallest files.
	minMaxRetryTime = 10 * time.Minute
	// maxMaxRetryTime is the total retry budget for the largest files.
	maxMaxRetryTime = time.Hour

	// minMaxBackoff is the per-attempt delay cap for the smallest files.
	minMaxBackoff = time.Minute
	// maxMaxBackoff is the per-attempt delay cap for the largest files.
	maxMaxBackoff = 10 * time.Minute

	// initialDelay is the base delay the exponential backoff starts from.
	initialDelay = 5 * time.Second
	// maxJitter is the upper bound handed to the retry library's jitter option.
	maxJitter = 5 * time.Second

	// absoluteMaxBackoff caps the per-attempt delay when it is derived from a
	// user-specified MaxRetryTime.
	absoluteMaxBackoff = 10 * time.Minute
)
| 49 | + |
// Config carries the retry settings a user can supply through CLI flags.
type Config struct {
	// MaxRetryTime is the total time budget for retrying. A value of zero
	// (or less) means the budget is derived from the file size instead.
	MaxRetryTime time.Duration
	// NoRetry disables retrying entirely: the operation is run exactly once.
	NoRetry bool
}
| 55 | + |
| 56 | +// DoOpts configures a single Do call. |
| 57 | +type DoOpts struct { |
| 58 | + FileSize int64 // for dynamic parameter calculation |
| 59 | + FileName string // for logging |
| 60 | + Config *Config |
| 61 | + OnRetry func(attempt uint, reason string, backoff time.Duration) |
| 62 | +} |
| 63 | + |
| 64 | +// Do executes fn with retry. It computes dynamic retry parameters from fileSize, |
| 65 | +// creates an internal deadline context (and defers its cancel to prevent leak), |
| 66 | +// sets up retry logging, and calls retry.Do. |
| 67 | +// The parent ctx is only used for user-initiated cancellation. |
| 68 | +func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) error { |
| 69 | + cfg := opts.Config |
| 70 | + if cfg == nil { |
| 71 | + cfg = &Config{} |
| 72 | + } |
| 73 | + |
| 74 | + // NoRetry: call fn once with the parent context, return the result. |
| 75 | + if cfg.NoRetry { |
| 76 | + return fn(ctx) |
| 77 | + } |
| 78 | + |
| 79 | + maxRetryTime, maxBackoff := computeDynamicParams(opts.FileSize) |
| 80 | + |
| 81 | + // Override with user-specified MaxRetryTime if set. |
| 82 | + if cfg.MaxRetryTime > 0 { |
| 83 | + maxRetryTime = cfg.MaxRetryTime |
| 84 | + maxBackoff = cfg.MaxRetryTime / 6 |
| 85 | + if maxBackoff > absoluteMaxBackoff { |
| 86 | + maxBackoff = absoluteMaxBackoff |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + startTime := time.Now() |
| 91 | + deadlineCtx, deadlineCancel := context.WithDeadline(ctx, startTime.Add(maxRetryTime)) |
| 92 | + defer deadlineCancel() |
| 93 | + |
| 94 | + sizeStr := humanizeBytes(opts.FileSize) |
| 95 | + |
| 96 | + return retry.Do( |
| 97 | + func() error { |
| 98 | + return fn(deadlineCtx) |
| 99 | + }, |
| 100 | + retry.Attempts(0), |
| 101 | + retry.Context(deadlineCtx), |
| 102 | + retry.DelayType(retry.BackOffDelay), |
| 103 | + retry.Delay(initialDelay), |
| 104 | + retry.MaxDelay(maxBackoff), |
| 105 | + retry.MaxJitter(maxJitter), |
| 106 | + retry.LastErrorOnly(true), |
| 107 | + retry.WrapContextErrorWithLastError(true), |
| 108 | + retry.RetryIf(func(err error) bool { |
| 109 | + retryable := IsRetryable(err) |
| 110 | + if !retryable { |
| 111 | + log.WithFields(log.Fields{ |
| 112 | + "file": opts.FileName, |
| 113 | + "size": sizeStr, |
| 114 | + "error": err.Error(), |
| 115 | + }).Error("[RETRY] non-retryable error, not retrying") |
| 116 | + } |
| 117 | + return retryable |
| 118 | + }), |
| 119 | + retry.OnRetry(func(n uint, err error) { |
| 120 | + backoff := computeBackoff(n+1, initialDelay, maxBackoff) |
| 121 | + elapsed := time.Since(startTime) |
| 122 | + |
| 123 | + log.WithFields(log.Fields{ |
| 124 | + "file": opts.FileName, |
| 125 | + "size": sizeStr, |
| 126 | + "error": err.Error(), |
| 127 | + "max_retry_time": maxRetryTime.String(), |
| 128 | + "max_backoff": maxBackoff.String(), |
| 129 | + "next_retry_in": backoff.Truncate(time.Second).String(), |
| 130 | + "elapsed": fmt.Sprintf("%s / %s", elapsed.Truncate(time.Second), maxRetryTime), |
| 131 | + }).Warnf("[RETRY] attempt %d for %q (%s)", n+1, opts.FileName, sizeStr) |
| 132 | + |
| 133 | + if opts.OnRetry != nil { |
| 134 | + reason := ShortReason(err) |
| 135 | + opts.OnRetry(n+1, reason, backoff) |
| 136 | + } |
| 137 | + }), |
| 138 | + ) |
| 139 | +} |
| 140 | + |
| 141 | +// computeDynamicParams calculates maxRetryTime and maxBackoff based on file size. |
| 142 | +// |
| 143 | +// For files <= 1 GB: maxRetryTime=10min, maxBackoff=1min |
| 144 | +// For files >= 10 GB: maxRetryTime=60min, maxBackoff=10min |
| 145 | +// Linear interpolation between. |
| 146 | +func computeDynamicParams(fileSize int64) (time.Duration, time.Duration) { |
| 147 | + ratio := float64(fileSize-oneGB) / float64(nineGB) |
| 148 | + if ratio < 0 { |
| 149 | + ratio = 0 |
| 150 | + } |
| 151 | + if ratio > 1 { |
| 152 | + ratio = 1 |
| 153 | + } |
| 154 | + |
| 155 | + maxRetryTime := minMaxRetryTime + time.Duration(ratio*float64(maxMaxRetryTime-minMaxRetryTime)) |
| 156 | + maxBackoff := minMaxBackoff + time.Duration(ratio*float64(maxMaxBackoff-minMaxBackoff)) |
| 157 | + |
| 158 | + return maxRetryTime, maxBackoff |
| 159 | +} |
| 160 | + |
// computeBackoff estimates the backoff duration for display purposes.
// It mirrors exponential backoff without jitter: initial * 2^(attempt-1),
// capped at maxDelay. An attempt of 0 returns initial unchanged.
//
// The doubling is done in integer arithmetic with an explicit overflow guard.
// Converting an out-of-range float64 product to time.Duration is
// implementation-dependent and can produce a negative value that slips past
// the maxDelay cap, so large attempt numbers must never reach that conversion.
// A non-positive initial is not scaled; it is returned as-is, subject to the
// maxDelay cap.
func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration {
	if attempt == 0 {
		return initial
	}

	backoff := initial
	// Stop as soon as the cap is reached, so the loop runs at most 63 times
	// regardless of how large attempt is.
	for i := uint(1); i < attempt && backoff > 0 && backoff < maxDelay; i++ {
		if backoff > math.MaxInt64/2 {
			// Doubling would overflow int64; the true value necessarily
			// exceeds maxDelay.
			return maxDelay
		}
		backoff *= 2
	}
	if backoff > maxDelay {
		backoff = maxDelay
	}
	return backoff
}
| 173 | + |
// httpStatusPattern captures the three-digit HTTP status code embedded in
// ORAS-style error messages ("response status code NNN").
var httpStatusPattern = regexp.MustCompile(`response status code (\d{3})`)

// IsRetryable reports whether err looks transient and is worth retrying.
//
// A nil error and anything wrapping context.Canceled are never retryable.
// When the message embeds an HTTP status code, that code alone decides:
// 5xx, 408 and 429 are retryable, every other code is not. All remaining
// errors, including ones that match no known pattern, are retryable.
func IsRetryable(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, context.Canceled):
		// Cancellation comes from the user or the system; never retry it.
		return false
	}

	msg := err.Error()

	// An embedded HTTP status code takes precedence over every other signal.
	if m := httpStatusPattern.FindStringSubmatch(msg); len(m) == 2 {
		status := m[1]
		// Server errors, Request Timeout and Too Many Requests are transient;
		// any other status (401, 403, 404, ...) is final.
		return status[0] == '5' || status == "408" || status == "429"
	}

	// Well-known network-level transient failures.
	transientMarkers := [...]string{
		"i/o timeout",
		"connection reset by peer",
		"connection refused",
		"broken pipe",
		"EOF",
	}
	for _, marker := range transientMarkers {
		if strings.Contains(msg, marker) {
			return true
		}
	}

	// Anything unrecognised is also treated as retryable.
	return true
}
| 225 | + |
| 226 | +// ShortReason extracts a brief human-readable label from an error for progress bar display. |
| 227 | +func ShortReason(err error) string { |
| 228 | + if err == nil { |
| 229 | + return "" |
| 230 | + } |
| 231 | + |
| 232 | + errMsg := err.Error() |
| 233 | + |
| 234 | + // Check for HTTP status codes. |
| 235 | + if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { |
| 236 | + return "HTTP " + matches[1] |
| 237 | + } |
| 238 | + |
| 239 | + if strings.Contains(errMsg, "i/o timeout") { |
| 240 | + return "i/o timeout" |
| 241 | + } |
| 242 | + if strings.Contains(errMsg, "connection reset by peer") { |
| 243 | + return "conn reset" |
| 244 | + } |
| 245 | + if strings.Contains(errMsg, "connection refused") { |
| 246 | + return "conn refused" |
| 247 | + } |
| 248 | + if strings.Contains(errMsg, "broken pipe") { |
| 249 | + return "broken pipe" |
| 250 | + } |
| 251 | + if strings.Contains(errMsg, "EOF") { |
| 252 | + return "EOF" |
| 253 | + } |
| 254 | + |
| 255 | + return "unknown error" |
| 256 | +} |
| 257 | + |
// humanizeBytes renders a byte count as a short human-readable string.
//
// Values below 1024 are printed as whole bytes ("512 B"); larger values are
// divided by the largest fitting power of 1024 and printed with one decimal
// place and a KB, MB or GB suffix.
func humanizeBytes(b int64) string {
	const (
		kib = 1 << 10
		mib = 1 << 20
		gib = 1 << 30
	)

	if b < kib {
		return fmt.Sprintf("%d B", b)
	}
	if b < mib {
		return fmt.Sprintf("%.1f KB", float64(b)/float64(kib))
	}
	if b < gib {
		return fmt.Sprintf("%.1f MB", float64(b)/float64(mib))
	}
	return fmt.Sprintf("%.1f GB", float64(b)/float64(gib))
}
0 commit comments