Skip to content

Commit 1a1890b

Browse files
security: Reject cross-host redirects to prevent Authorization leak (#4171)
1 parent 971b607 commit 1a1890b

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

github/github.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,12 @@ func (c *Client) bareDoUntilFound(ctx context.Context, req *http.Request, maxRed
10901090
return nil, nil, errInvalidLocation
10911091
}
10921092
newURL := c.BaseURL.ResolveReference(rerr.Location)
1093+
// Refuse to follow a permanent redirect to a different host:
1094+
// req.Clone preserves Authorization headers added by the auth
1095+
// transport, so a cross-host target would leak credentials.
1096+
if newURL.Host != c.BaseURL.Host {
1097+
return nil, response, fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, newURL.Host)
1098+
}
10931099
newRequest := req.Clone(ctx)
10941100
newRequest.URL = newURL
10951101
return c.bareDoUntilFound(ctx, newRequest, maxRedirects-1)
@@ -1846,11 +1852,35 @@ func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u stri
18461852
if maxRedirects > 0 && resp.StatusCode == http.StatusMovedPermanently {
18471853
_ = resp.Body.Close()
18481854
u = resp.Header.Get("Location")
1855+
if err := c.checkRedirectHost(u); err != nil {
1856+
return nil, err
1857+
}
18491858
resp, err = c.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects-1, opts...)
18501859
}
18511860
return resp, err
18521861
}
18531862

1863+
// checkRedirectHost returns an error if the redirect target is on a different
1864+
// host than the client's configured BaseURL. This prevents credentials attached
1865+
// by the auth transport from being sent to an attacker-controlled host when a
1866+
// compromised or malicious API response returns a cross-origin Location header.
1867+
// An empty Location is also rejected.
1868+
func (c *Client) checkRedirectHost(location string) error {
1869+
if location == "" {
1870+
return errInvalidLocation
1871+
}
1872+
target, err := url.Parse(location)
1873+
if err != nil {
1874+
return fmt.Errorf("invalid redirect location %q: %w", location, err)
1875+
}
1876+
// Resolve relative locations against BaseURL so relative paths are allowed.
1877+
target = c.BaseURL.ResolveReference(target)
1878+
if target.Host != c.BaseURL.Host {
1879+
return fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, target.Host)
1880+
}
1881+
return nil
1882+
}
1883+
18541884
// Ptr is a helper routine that allocates a new T value
18551885
// to store v and returns a pointer to it.
18561886
func Ptr[T any](v T) *T {

github/github_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,6 +2256,80 @@ func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) {
22562256
}
22572257
}
22582258

2259+
// TestBareDoUntilFound_RejectsCrossHostRedirect verifies that bareDoUntilFound
2260+
// refuses to follow a 301 redirect whose Location points to a different host,
2261+
// which would otherwise leak the Authorization header (added by the auth
2262+
// transport) to an attacker-controlled server.
2263+
func TestBareDoUntilFound_RejectsCrossHostRedirect(t *testing.T) {
2264+
t.Parallel()
2265+
client, mux, _ := setup(t)
2266+
2267+
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
2268+
w.Header().Set("Location", "https://evil.example.com/steal")
2269+
w.WriteHeader(http.StatusMovedPermanently)
2270+
})
2271+
2272+
req, _ := client.NewRequest("GET", ".", nil)
2273+
_, _, err := client.bareDoUntilFound(t.Context(), req, 1)
2274+
if err == nil {
2275+
t.Fatal("Expected cross-host redirect to be rejected, got nil error.")
2276+
}
2277+
if !strings.Contains(err.Error(), "cross-host redirect") {
2278+
t.Errorf("Expected cross-host redirect error, got: %v", err)
2279+
}
2280+
}
2281+
2282+
// TestRoundTripWithOptionalFollowRedirect_RejectsCrossHostRedirect verifies
2283+
// that roundTripWithOptionalFollowRedirect refuses to follow a 301 redirect to
2284+
// a different host, preventing Authorization-header leakage to attacker-
2285+
// controlled servers via a malicious or compromised API response.
2286+
func TestRoundTripWithOptionalFollowRedirect_RejectsCrossHostRedirect(t *testing.T) {
2287+
t.Parallel()
2288+
client, mux, _ := setup(t)
2289+
2290+
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
2291+
w.Header().Set("Location", "https://evil.example.com/steal")
2292+
w.WriteHeader(http.StatusMovedPermanently)
2293+
})
2294+
2295+
_, err := client.roundTripWithOptionalFollowRedirect(t.Context(), ".", 1)
2296+
if err == nil {
2297+
t.Fatal("Expected cross-host redirect to be rejected, got nil error.")
2298+
}
2299+
if !strings.Contains(err.Error(), "cross-host redirect") {
2300+
t.Errorf("Expected cross-host redirect error, got: %v", err)
2301+
}
2302+
}
2303+
2304+
// TestRoundTripWithOptionalFollowRedirect_AllowsSameHostRedirect ensures the
2305+
// cross-host check does not break legitimate same-host 301 follow behavior
2306+
// (the path that rate-limit redirection relies on).
2307+
func TestRoundTripWithOptionalFollowRedirect_AllowsSameHostRedirect(t *testing.T) {
2308+
t.Parallel()
2309+
client, mux, _ := setup(t)
2310+
2311+
var followed atomic.Bool
2312+
mux.HandleFunc("/archive", func(w http.ResponseWriter, _ *http.Request) {
2313+
w.Header().Set("Location", baseURLPath+"/archive-target")
2314+
w.WriteHeader(http.StatusMovedPermanently)
2315+
})
2316+
mux.HandleFunc("/archive-target", func(w http.ResponseWriter, _ *http.Request) {
2317+
followed.Store(true)
2318+
w.WriteHeader(http.StatusOK)
2319+
})
2320+
2321+
resp, err := client.roundTripWithOptionalFollowRedirect(t.Context(), "archive", 2)
2322+
if err != nil {
2323+
t.Fatalf("Unexpected error on same-host redirect: %v", err)
2324+
}
2325+
if resp != nil && resp.Body != nil {
2326+
resp.Body.Close()
2327+
}
2328+
if !followed.Load() {
2329+
t.Error("Expected same-host redirect to be followed.")
2330+
}
2331+
}
2332+
22592333
func TestSanitizeURL(t *testing.T) {
22602334
t.Parallel()
22612335
tests := []struct {

0 commit comments

Comments
 (0)