Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions githubapp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@ package githubapp
import (
"os"
"strconv"
"time"
)

type Config struct {
WebURL string `yaml:"web_url" json:"webUrl"`
V3APIURL string `yaml:"v3_api_url" json:"v3ApiUrl"`
V4APIURL string `yaml:"v4_api_url" json:"v4ApiUrl"`

// RateLimitMaxWait is the maximum duration to wait before retrying a
// rate-limited (403) GitHub API request. Zero disables retry.
RateLimitMaxWait time.Duration `yaml:"rate_limit_max_wait" json:"rateLimitMaxWait"`

App struct {
IntegrationID int64 `yaml:"integration_id" json:"integrationId"`
WebhookSecret string `yaml:"webhook_secret" json:"webhookSecret"`
Expand All @@ -43,6 +48,7 @@ func (c *Config) SetValuesFromEnv(prefix string) {
setStringFromEnv("GITHUB_WEB_URL", prefix, &c.WebURL)
setStringFromEnv("GITHUB_V3_API_URL", prefix, &c.V3APIURL)
setStringFromEnv("GITHUB_V4_API_URL", prefix, &c.V4APIURL)
setDurationFromEnv("GITHUB_RATE_LIMIT_MAX_WAIT", prefix, &c.RateLimitMaxWait)

setIntFromEnv("GITHUB_APP_INTEGRATION_ID", prefix, &c.App.IntegrationID)
setStringFromEnv("GITHUB_APP_WEBHOOK_SECRET", prefix, &c.App.WebhookSecret)
Expand All @@ -65,3 +71,11 @@ func setIntFromEnv(key, prefix string, value *int64) {
}
}
}

func setDurationFromEnv(key, prefix string, value *time.Duration) {
if v, ok := os.LookupEnv(prefix + key); ok {
if d, err := time.ParseDuration(v); err == nil {
*value = d
}
}
}
86 changes: 86 additions & 0 deletions githubapp/middleware_retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package githubapp

import (
"net/http"
"strconv"
"time"
)

const (
headerRetryAfter = "Retry-After"
)

// RateLimitRetry returns a ClientMiddleware that transparently retries a
// request exactly once when GitHub responds with 403 and a rate-limit wait
// header (Retry-After or X-RateLimit-Reset). If the required wait exceeds
// maxWait the original 403 response is returned without retrying.
func RateLimitRetry(maxWait time.Duration) ClientMiddleware {
return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
res, err := next.RoundTrip(r)
if err != nil || res == nil || res.StatusCode != http.StatusForbidden {
return res, err
}

wait := rateLimitWait(res)
if wait <= 0 || wait > maxWait {
return res, err
}

// drain and close the 403 body before retrying
closeBody(res.Body)

timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-r.Context().Done():
return res, r.Context().Err()
case <-timer.C:
}

return next.RoundTrip(r)
})
}
}

// WithRateLimitRetry adds rate-limit retry middleware to all clients created
// by the ClientCreator. See RateLimitRetry for retry semantics.
func WithRateLimitRetry(maxWait time.Duration) ClientOption {
return WithClientMiddleware(RateLimitRetry(maxWait))
}

// rateLimitWait returns the duration to wait based on GitHub rate-limit
// response headers. It checks Retry-After first, then X-RateLimit-Reset.
// Returns 0 if neither header is present or parseable.
func rateLimitWait(res *http.Response) time.Duration {
if v := res.Header.Get(headerRetryAfter); v != "" {
if secs, err := strconv.ParseInt(v, 10, 64); err == nil && secs > 0 {
return time.Duration(secs) * time.Second
}
}

if v := res.Header.Get(httpHeaderRateReset); v != "" {
if epoch, err := strconv.ParseInt(v, 10, 64); err == nil {
wait := time.Until(time.Unix(epoch, 0))
if wait > 0 {
return wait
}
}
}

return 0
}
215 changes: 215 additions & 0 deletions githubapp/middleware_retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Copyright 2024 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package githubapp

import (
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
)

// mockRoundTripper records how many times it was called and returns the
// configured responses in order. The final response is repeated for any
// additional calls.
type mockRoundTripper struct {
responses []*http.Response
calls int
}

func (m *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
idx := m.calls
if idx >= len(m.responses) {
idx = len(m.responses) - 1
}
m.calls++
return m.responses[idx], nil
}

func makeResponse(status int, headers map[string]string) *http.Response {
h := make(http.Header)
for k, v := range headers {
h.Set(k, v)
}
return &http.Response{
StatusCode: status,
Header: h,
Body: io.NopCloser(strings.NewReader("")),
}
}

func TestRateLimitRetry_RetryAfterHeader(t *testing.T) {
mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "1"}),
makeResponse(http.StatusOK, nil),
},
}

mw := RateLimitRetry(5 * time.Second)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
start := time.Now()
res, err := rt.RoundTrip(req)
elapsed := time.Since(start)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusOK {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusOK)
}
if mock.calls != 2 {
t.Errorf("RoundTrip calls: got %d, want 2", mock.calls)
}
if elapsed < time.Second {
t.Errorf("expected at least 1s sleep, got %v", elapsed)
}
}

func TestRateLimitRetry_ResetHeader(t *testing.T) {
resetAt := time.Now().Add(1 * time.Second).Unix()

mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusForbidden, map[string]string{
httpHeaderRateReset: fmt.Sprintf("%d", resetAt),
}),
makeResponse(http.StatusOK, nil),
},
}

mw := RateLimitRetry(5 * time.Second)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
start := time.Now()
res, err := rt.RoundTrip(req)
elapsed := time.Since(start)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusOK {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusOK)
}
if mock.calls != 2 {
t.Errorf("RoundTrip calls: got %d, want 2", mock.calls)
}
if elapsed < 100*time.Millisecond {
t.Errorf("expected ~1s sleep, got %v", elapsed)
}
}

func TestRateLimitRetry_WaitExceedsMaxWait(t *testing.T) {
mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "60"}),
makeResponse(http.StatusOK, nil),
},
}

mw := RateLimitRetry(5 * time.Second)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
res, err := rt.RoundTrip(req)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusForbidden {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden)
}
if mock.calls != 1 {
t.Errorf("RoundTrip calls: got %d, want 1 (should not retry)", mock.calls)
}
}

func TestRateLimitRetry_NoRateLimitHeaders(t *testing.T) {
mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusForbidden, nil),
},
}

mw := RateLimitRetry(5 * time.Second)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
res, err := rt.RoundTrip(req)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusForbidden {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden)
}
if mock.calls != 1 {
t.Errorf("RoundTrip calls: got %d, want 1 (no rate limit header)", mock.calls)
}
}

func TestRateLimitRetry_Non403PassThrough(t *testing.T) {
mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusTooManyRequests, map[string]string{headerRetryAfter: "1"}),
},
}

mw := RateLimitRetry(5 * time.Second)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
res, err := rt.RoundTrip(req)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusTooManyRequests {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusTooManyRequests)
}
if mock.calls != 1 {
t.Errorf("RoundTrip calls: got %d, want 1 (non-403 not retried)", mock.calls)
}
}

func TestRateLimitRetry_ZeroMaxWaitDisablesRetry(t *testing.T) {
mock := &mockRoundTripper{
responses: []*http.Response{
makeResponse(http.StatusForbidden, map[string]string{headerRetryAfter: "1"}),
makeResponse(http.StatusOK, nil),
},
}

mw := RateLimitRetry(0)
rt := mw(mock)

req, _ := http.NewRequest(http.MethodGet, "https://api.github.com/repos", nil)
res, err := rt.RoundTrip(req)

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode != http.StatusForbidden {
t.Errorf("status: got %d, want %d", res.StatusCode, http.StatusForbidden)
}
if mock.calls != 1 {
t.Errorf("RoundTrip calls: got %d, want 1 (maxWait=0 disables retry)", mock.calls)
}
}