Skip to content

Commit 1e499ed

Browse files
committed
fix(webicached): use hardened httpclient for upstream API calls
Replaces the inline &http.Client{Timeout: 30s} with httpclient.New(), which enforces TLS 1.2+, per-level timeouts, no HTTPS→HTTP redirect downgrade, connection pooling, and automatic retry with backoff. The delayTransport (page-delay flag) now wraps httpclient's transport instead of http.DefaultTransport, preserving all security properties.
1 parent f638a25 commit 1e499ed

2 files changed

Lines changed: 157 additions & 2 deletions

File tree

cmd/webicached/main.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535

3636
"github.com/joho/godotenv"
3737
"github.com/webinstall/webi-installers/internal/classifypkg"
38+
"github.com/webinstall/webi-installers/internal/httpclient"
3839
"github.com/webinstall/webi-installers/internal/installerconf"
3940
"github.com/webinstall/webi-installers/internal/rawcache"
4041
"github.com/webinstall/webi-installers/internal/releases/chromedist"
@@ -166,10 +167,10 @@ func main() {
166167
auth = &githubish.Auth{Token: cfg.token}
167168
}
168169

169-
client := &http.Client{Timeout: 30 * time.Second}
170+
client := httpclient.New()
170171
if cfg.pageDelay > 0 {
171172
client.Transport = &delayTransport{
172-
base: http.DefaultTransport,
173+
base: client.Transport,
173174
delay: cfg.pageDelay,
174175
}
175176
}

internal/httpclient/httpclient.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Package httpclient provides a well-configured [http.Client] for upstream
2+
// API calls. It exists because [http.DefaultClient] has no timeouts, no TLS
3+
// minimum, and follows redirects from HTTPS to HTTP — none of which are
4+
// acceptable for a server calling GitHub, Gitea, etc. on behalf of users.
5+
//
6+
// Use [New] to create a configured client. Use [Do] to execute a request
7+
// with automatic retries for transient failures.
8+
package httpclient
9+
10+
import (
11+
"context"
12+
"crypto/tls"
13+
"errors"
14+
"fmt"
15+
"math/rand/v2"
16+
"net"
17+
"net/http"
18+
"strconv"
19+
"time"
20+
)
21+
22+
const userAgent = "Webi/2.0 (+https://webinstall.dev)"
23+
24+
// New returns an [http.Client] with secure, production-ready defaults:
25+
// TLS 1.2+, timeouts at every level, connection pooling, no HTTPS→HTTP
26+
// redirect, and a Webi User-Agent.
27+
func New() *http.Client {
28+
return &http.Client{
29+
Transport: &http.Transport{
30+
DialContext: (&net.Dialer{
31+
Timeout: 10 * time.Second,
32+
KeepAlive: 30 * time.Second,
33+
}).DialContext,
34+
TLSClientConfig: &tls.Config{
35+
MinVersion: tls.VersionTLS12,
36+
},
37+
TLSHandshakeTimeout: 10 * time.Second,
38+
ResponseHeaderTimeout: 30 * time.Second,
39+
MaxIdleConns: 100,
40+
MaxIdleConnsPerHost: 10,
41+
IdleConnTimeout: 90 * time.Second,
42+
ExpectContinueTimeout: 1 * time.Second,
43+
ForceAttemptHTTP2: true,
44+
},
45+
Timeout: 60 * time.Second,
46+
CheckRedirect: checkRedirect,
47+
}
48+
}
49+
50+
// checkRedirect prevents HTTPS→HTTP downgrades and limits redirect depth.
51+
func checkRedirect(req *http.Request, via []*http.Request) error {
52+
if len(via) >= 10 {
53+
return fmt.Errorf("stopped after %d redirects", len(via))
54+
}
55+
if len(via) > 0 && via[0].URL.Scheme == "https" && req.URL.Scheme == "http" {
56+
return errors.New("refused redirect from https to http")
57+
}
58+
return nil
59+
}
60+
61+
// Get performs a GET request with the Webi User-Agent header.
62+
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
63+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
64+
if err != nil {
65+
return nil, err
66+
}
67+
req.Header.Set("User-Agent", userAgent)
68+
return client.Do(req)
69+
}
70+
71+
// Do executes a request with automatic retries for transient errors (429,
72+
// 502, 503, 504). Retries up to 3 times with exponential backoff and jitter.
73+
// Respects Retry-After headers. Only retries GET and HEAD (idempotent).
74+
//
75+
// Sets the Webi User-Agent header if not already present.
76+
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
77+
if req.Header.Get("User-Agent") == "" {
78+
req.Header.Set("User-Agent", userAgent)
79+
}
80+
81+
// Only retry idempotent methods.
82+
idempotent := req.Method == http.MethodGet || req.Method == http.MethodHead
83+
84+
const maxRetries = 3
85+
var resp *http.Response
86+
var err error
87+
88+
for attempt := range maxRetries + 1 {
89+
if attempt > 0 {
90+
if !idempotent {
91+
break
92+
}
93+
94+
delay := backoff(attempt, resp)
95+
timer := time.NewTimer(delay)
96+
select {
97+
case <-ctx.Done():
98+
timer.Stop()
99+
return nil, ctx.Err()
100+
case <-timer.C:
101+
}
102+
103+
if resp != nil {
104+
resp.Body.Close()
105+
}
106+
}
107+
108+
resp, err = client.Do(req)
109+
if err != nil {
110+
if ctx.Err() != nil {
111+
return nil, ctx.Err()
112+
}
113+
continue
114+
}
115+
116+
if !isRetryable(resp.StatusCode) {
117+
return resp, nil
118+
}
119+
}
120+
121+
if err != nil {
122+
return nil, fmt.Errorf("after %d retries: %w", maxRetries, err)
123+
}
124+
return resp, nil
125+
}
126+
127+
func isRetryable(status int) bool {
128+
return status == http.StatusTooManyRequests ||
129+
status == http.StatusBadGateway ||
130+
status == http.StatusServiceUnavailable ||
131+
status == http.StatusGatewayTimeout
132+
}
133+
134+
// backoff returns a delay before the next retry. Respects Retry-After,
135+
// otherwise uses exponential backoff with jitter.
136+
func backoff(attempt int, resp *http.Response) time.Duration {
137+
if resp != nil {
138+
if ra := resp.Header.Get("Retry-After"); ra != "" {
139+
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 && seconds < 300 {
140+
return time.Duration(seconds) * time.Second
141+
}
142+
}
143+
}
144+
145+
// 1s, 2s, 4s base delays
146+
base := time.Second << (attempt - 1)
147+
if base > 30*time.Second {
148+
base = 30 * time.Second
149+
}
150+
151+
// Add jitter: 75% to 125% of base
152+
jitter := float64(base) * (0.75 + 0.5*rand.Float64())
153+
return time.Duration(jitter)
154+
}

0 commit comments

Comments
 (0)