diff --git a/githubapp/config.go b/githubapp/config.go index 5e5e5eaa..c41d7400 100644 --- a/githubapp/config.go +++ b/githubapp/config.go @@ -17,6 +17,7 @@ package githubapp import ( "os" "strconv" + "time" ) type Config struct { @@ -24,6 +25,10 @@ type Config struct { V3APIURL string `yaml:"v3_api_url" json:"v3ApiUrl"` V4APIURL string `yaml:"v4_api_url" json:"v4ApiUrl"` + // RateLimitMaxWait is the maximum duration to wait before retrying a + // rate-limited (403) GitHub API request. Zero disables retry. + RateLimitMaxWait time.Duration `yaml:"rate_limit_max_wait" json:"rateLimitMaxWait"` + App struct { IntegrationID int64 `yaml:"integration_id" json:"integrationId"` WebhookSecret string `yaml:"webhook_secret" json:"webhookSecret"` @@ -43,6 +48,7 @@ func (c *Config) SetValuesFromEnv(prefix string) { setStringFromEnv("GITHUB_WEB_URL", prefix, &c.WebURL) setStringFromEnv("GITHUB_V3_API_URL", prefix, &c.V3APIURL) setStringFromEnv("GITHUB_V4_API_URL", prefix, &c.V4APIURL) + setDurationFromEnv("GITHUB_RATE_LIMIT_MAX_WAIT", prefix, &c.RateLimitMaxWait) setIntFromEnv("GITHUB_APP_INTEGRATION_ID", prefix, &c.App.IntegrationID) setStringFromEnv("GITHUB_APP_WEBHOOK_SECRET", prefix, &c.App.WebhookSecret) @@ -65,3 +71,11 @@ func setIntFromEnv(key, prefix string, value *int64) { } } } + +func setDurationFromEnv(key, prefix string, value *time.Duration) { + if v, ok := os.LookupEnv(prefix + key); ok { + if d, err := time.ParseDuration(v); err == nil { + *value = d + } + } +} diff --git a/githubapp/middleware_retry.go b/githubapp/middleware_retry.go new file mode 100644 index 00000000..acd67053 --- /dev/null +++ b/githubapp/middleware_retry.go @@ -0,0 +1,86 @@ +// Copyright 2024 Palantir Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package githubapp + +import ( + "net/http" + "strconv" + "time" +) + +const ( + headerRetryAfter = "Retry-After" +) + +// RateLimitRetry returns a ClientMiddleware that transparently retries a +// request exactly once when GitHub responds with 403 and a rate-limit wait +// header (Retry-After or X-RateLimit-Reset). If the required wait exceeds +// maxWait the original 403 response is returned without retrying. +func RateLimitRetry(maxWait time.Duration) ClientMiddleware { + return func(next http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(r *http.Request) (*http.Response, error) { + res, err := next.RoundTrip(r) + if err != nil || res == nil || res.StatusCode != http.StatusForbidden { + return res, err + } + + wait := rateLimitWait(res) + if wait <= 0 || wait > maxWait { + return res, err + } + + // drain and close the 403 body before retrying + closeBody(res.Body) + + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-r.Context().Done(): + return res, r.Context().Err() + case <-timer.C: + } + + return next.RoundTrip(r) + }) + } +} + +// WithRateLimitRetry adds rate-limit retry middleware to all clients created +// by the ClientCreator. See RateLimitRetry for retry semantics. +func WithRateLimitRetry(maxWait time.Duration) ClientOption { + return WithClientMiddleware(RateLimitRetry(maxWait)) +} + +// rateLimitWait returns the duration to wait based on GitHub rate-limit +// response headers. It checks Retry-After first, then X-RateLimit-Reset. +// Returns 0 if neither header is present or parseable. +func rateLimitWait(res *http.Response) time.Duration { + if v := res.Header.Get(headerRetryAfter); v != "" { + if secs, err := strconv.ParseInt(v, 10, 64); err == nil && secs > 0 { + return time.Duration(secs) * time.Second + } + } + + if v := res.Header.Get(httpHeaderRateReset); v != "" { + if epoch, err := strconv.ParseInt(v, 10, 64); err == nil { + wait := time.Until(time.Unix(epoch, 0)) + if wait > 0 { + return wait + } + } + } + + return 0 +} diff --git a/githubapp/middleware_retry_test.go b/githubapp/middleware_retry_test.go new file mode 100644 index 00000000..45f82f49 --- /dev/null +++ b/githubapp/middleware_retry_test.go @@ -0,0 +1,215 @@ +// Copyright 2024 Palantir Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package githubapp + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// mockRoundTripper records how many times it was called and returns the +// configured responses in order. The final response is repeated for any +// additional calls. +type mockRoundTripper struct { + responses []*http.Response + calls int +} + +func (m *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + idx := m.calls + if idx >= len(m.responses) { + idx = len(m.responses) - 1 + } + m.calls++ + return m.responses[idx], nil +} + +func makeResponse(status int, headers map[string]string) *http.Response { + h := make(http.Header) + for k, v := range headers { + h.Set(k, v) + } + return &http.Response{ + StatusCode: status, + Header: h, + Body: io.NopCloser(strings.NewReader("")), + } +} + +func TestRateLimitRetry_RetryAfterHeader(t *testing.T) { + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "1"}), + makeResponse(http.StatusOK, nil), + }, + } + + mw := RateLimitRetry(5 * time.Second) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + start := time.Now() + res, err := rt.RoundTrip(req) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusOK) + } + if mock.calls != 2 { + t.Errorf("RoundTrip calls: got %d, want 2", mock.calls) + } + if elapsed < time.Second { + t.Errorf("expected at least 1s sleep, got %v", elapsed) + } +} + +func TestRateLimitRetry_ResetHeader(t *testing.T) { + resetAt := time.Now().Add(1 * time.Second).Unix() + + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusForbidden, map[string]string{ + httpHeaderRateReset: fmt.Sprintf("%d", resetAt), + }), + makeResponse(http.StatusOK, nil), + }, + } + + mw := RateLimitRetry(5 * time.Second) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + start := time.Now() + res, err := rt.RoundTrip(req) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusOK { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusOK) + } + if mock.calls != 2 { + t.Errorf("RoundTrip calls: got %d, want 2", mock.calls) + } + if elapsed < 100*time.Millisecond { + t.Errorf("expected ~1s sleep, got %v", elapsed) + } +} + +func TestRateLimitRetry_WaitExceedsMaxWait(t *testing.T) { + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "60"}), + makeResponse(http.StatusOK, nil), + }, + } + + mw := RateLimitRetry(5 * time.Second) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + res, err := rt.RoundTrip(req) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusForbidden { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden) + } + if mock.calls != 1 { + t.Errorf("RoundTrip calls: got %d, want 1 (should not retry)", mock.calls) + } +} + +func TestRateLimitRetry_NoRateLimitHeaders(t *testing.T) { + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusForbidden, nil), + }, + } + + mw := RateLimitRetry(5 * time.Second) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + res, err := rt.RoundTrip(req) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusForbidden { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden) + } + if mock.calls != 1 { + t.Errorf("RoundTrip calls: got %d, want 1 (no rate limit header)", mock.calls) + } +} + +func TestRateLimitRetry_Non403PassThrough(t *testing.T) { + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusTooManyRequests, map[string]string{headerRetryAfter: "1"}), + }, + } + + mw := RateLimitRetry(5 * time.Second) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + res, err := rt.RoundTrip(req) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusTooManyRequests { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusTooManyRequests) + } + if mock.calls != 1 { + t.Errorf("RoundTrip calls: got %d, want 1 (non-403 not retried)", mock.calls) + } +} + +func TestRateLimitRetry_ZeroMaxWaitDisablesRetry(t *testing.T) { + mock := &mockRoundTripper{ + responses: []*http.Response{ + makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "1"}), + makeResponse(http.StatusOK, nil), + }, + } + + mw := RateLimitRetry(0) + rt := mw(mock) + + req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil) + res, err := rt.RoundTrip(req) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.StatusCode != http.StatusForbidden { + t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden) + } + if mock.calls != 1 { + t.Errorf("RoundTrip calls: got %d, want 1 (maxWait=0 disables retry)", mock.calls) + } +}