Skip to content

Commit 594a9cd

Browse files
committed
Scope HTTP/2 fallback and HTTP/3 broken state per authority
1 parent 031a147 commit 594a9cd

6 files changed

Lines changed: 295 additions & 40 deletions

File tree

common/httpclient/helpers.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
E "github.com/sagernet/sing/common/exceptions"
1313
M "github.com/sagernet/sing/common/metadata"
1414
N "github.com/sagernet/sing/common/network"
15+
16+
"golang.org/x/net/idna"
1517
)
1618

1719
func dialTLS(ctx context.Context, rawDialer N.Dialer, baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string, expectProto string) (net.Conn, error) {
@@ -73,6 +75,34 @@ func mustGetBody(request *http.Request) io.ReadCloser {
7375
return body
7476
}
7577

78+
func requestAuthority(request *http.Request) string {
79+
if request == nil || request.URL == nil || request.URL.Host == "" {
80+
return ""
81+
}
82+
host, port, err := net.SplitHostPort(request.URL.Host)
83+
if err != nil {
84+
host = request.URL.Host
85+
port = ""
86+
}
87+
if port == "" {
88+
if request.URL.Scheme == "http" {
89+
port = "80"
90+
} else {
91+
port = "443"
92+
}
93+
}
94+
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
95+
return host + ":" + port
96+
}
97+
ascii, idnaErr := idna.Lookup.ToASCII(host)
98+
if idnaErr == nil {
99+
host = ascii
100+
} else {
101+
host = strings.ToLower(host)
102+
}
103+
return net.JoinHostPort(host, port)
104+
}
105+
76106
func buildSTDTLSConfig(baseTLSConfig tls.Config, destination M.Socksaddr, nextProtos []string) (*stdTLS.Config, error) {
77107
if baseTLSConfig == nil {
78108
return nil, nil

common/httpclient/helpers_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package httpclient
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"testing"
7+
)
8+
9+
func TestRequestAuthority(t *testing.T) {
10+
testCases := []struct {
11+
name string
12+
url string
13+
expect string
14+
}{
15+
{name: "https default port", url: "https://example.com/foo", expect: "example.com:443"},
16+
{name: "http default port", url: "http://example.com/foo", expect: "example.com:80"},
17+
{name: "https explicit port", url: "https://example.com:8443/foo", expect: "example.com:8443"},
18+
{name: "https uppercase host", url: "https://EXAMPLE.COM/foo", expect: "example.com:443"},
19+
{name: "https ipv6 default port", url: "https://[2001:db8::1]/foo", expect: "[2001:db8::1]:443"},
20+
{name: "https ipv6 explicit port", url: "https://[2001:db8::1]:8443/foo", expect: "[2001:db8::1]:8443"},
21+
{name: "https ipv4", url: "https://192.0.2.1/foo", expect: "192.0.2.1:443"},
22+
}
23+
for _, testCase := range testCases {
24+
t.Run(testCase.name, func(t *testing.T) {
25+
parsed, err := url.Parse(testCase.url)
26+
if err != nil {
27+
t.Fatalf("parse url: %v", err)
28+
}
29+
got := requestAuthority(&http.Request{URL: parsed})
30+
if got != testCase.expect {
31+
t.Fatalf("got %q, want %q", got, testCase.expect)
32+
}
33+
})
34+
}
35+
36+
t.Run("nil request", func(t *testing.T) {
37+
if got := requestAuthority(nil); got != "" {
38+
t.Fatalf("got %q, want empty", got)
39+
}
40+
})
41+
t.Run("nil URL", func(t *testing.T) {
42+
if got := requestAuthority(&http.Request{}); got != "" {
43+
t.Fatalf("got %q, want empty", got)
44+
}
45+
})
46+
t.Run("empty host", func(t *testing.T) {
47+
if got := requestAuthority(&http.Request{URL: &url.URL{Scheme: "https"}}); got != "" {
48+
t.Fatalf("got %q, want empty", got)
49+
}
50+
})
51+
}

common/httpclient/http2_fallback_transport.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"errors"
77
"net"
88
"net/http"
9-
"sync/atomic"
9+
"sync"
1010

1111
"github.com/sagernet/sing-box/common/tls"
1212
"github.com/sagernet/sing-box/option"
@@ -20,35 +20,47 @@ import (
2020
var errHTTP2Fallback = E.New("fallback to HTTP/1.1")
2121

2222
type http2FallbackTransport struct {
23-
h2Transport *http2.Transport
24-
h1Transport *http1Transport
25-
h2Fallback *atomic.Bool
23+
h2Transport *http2.Transport
24+
h1Transport *http1Transport
25+
fallbackAccess sync.RWMutex
26+
fallbackAuthority map[string]struct{}
2627
}
2728

2829
func newHTTP2FallbackTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2FallbackTransport, error) {
2930
h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
30-
var fallback atomic.Bool
3131
h2Transport, err := ConfigureHTTP2Transport(options)
3232
if err != nil {
3333
return nil, err
3434
}
3535
h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
36-
conn, dialErr := dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
37-
if dialErr != nil {
38-
if errors.Is(dialErr, errHTTP2Fallback) {
39-
fallback.Store(true)
40-
}
41-
return nil, dialErr
42-
}
43-
return conn, nil
36+
return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
4437
}
4538
return &http2FallbackTransport{
46-
h2Transport: h2Transport,
47-
h1Transport: h1,
48-
h2Fallback: &fallback,
39+
h2Transport: h2Transport,
40+
h1Transport: h1,
41+
fallbackAuthority: make(map[string]struct{}),
4942
}, nil
5043
}
5144

45+
func (t *http2FallbackTransport) isH2Fallback(authority string) bool {
46+
if authority == "" {
47+
return false
48+
}
49+
t.fallbackAccess.RLock()
50+
_, found := t.fallbackAuthority[authority]
51+
t.fallbackAccess.RUnlock()
52+
return found
53+
}
54+
55+
func (t *http2FallbackTransport) markH2Fallback(authority string) {
56+
if authority == "" {
57+
return
58+
}
59+
t.fallbackAccess.Lock()
60+
t.fallbackAuthority[authority] = struct{}{}
61+
t.fallbackAccess.Unlock()
62+
}
63+
5264
func (t *http2FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
5365
return t.roundTrip(request, true)
5466
}
@@ -57,7 +69,8 @@ func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fall
5769
if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
5870
return t.h1Transport.RoundTrip(request)
5971
}
60-
if t.h2Fallback.Load() {
72+
authority := requestAuthority(request)
73+
if t.isH2Fallback(authority) {
6174
if !allowHTTP1Fallback {
6275
return nil, errHTTP2Fallback
6376
}
@@ -70,6 +83,7 @@ func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fall
7083
if !errors.Is(err, errHTTP2Fallback) || !allowHTTP1Fallback {
7184
return nil, err
7285
}
86+
t.markH2Fallback(authority)
7387
return t.h1Transport.RoundTrip(cloneRequestForRetry(request))
7488
}
7589

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package httpclient
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestHTTP2FallbackAuthorityIsolation(t *testing.T) {
8+
transport := &http2FallbackTransport{fallbackAuthority: make(map[string]struct{})}
9+
10+
transport.markH2Fallback("a.example:443")
11+
if !transport.isH2Fallback("a.example:443") {
12+
t.Fatal("a.example:443 should be marked")
13+
}
14+
if transport.isH2Fallback("b.example:443") {
15+
t.Fatal("b.example:443 must remain unmarked after marking a.example")
16+
}
17+
18+
transport.markH2Fallback("b.example:443")
19+
if !transport.isH2Fallback("b.example:443") {
20+
t.Fatal("b.example:443 should be marked after explicit mark")
21+
}
22+
if !transport.isH2Fallback("a.example:443") {
23+
t.Fatal("a.example:443 mark must survive marking another authority")
24+
}
25+
}
26+
27+
func TestHTTP2FallbackEmptyAuthorityNoOp(t *testing.T) {
28+
transport := &http2FallbackTransport{fallbackAuthority: make(map[string]struct{})}
29+
30+
transport.markH2Fallback("")
31+
if len(transport.fallbackAuthority) != 0 {
32+
t.Fatalf("empty authority must not be stored, got %d entries", len(transport.fallbackAuthority))
33+
}
34+
if transport.isH2Fallback("") {
35+
t.Fatal("isH2Fallback must be false for empty authority")
36+
}
37+
}

common/httpclient/http3_transport.go

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,17 @@ type http3Transport struct {
2424
h3Transport *http3.Transport
2525
}
2626

27+
type http3BrokenEntry struct {
28+
until time.Time
29+
backoff time.Duration
30+
}
31+
2732
type http3FallbackTransport struct {
2833
h3Transport *http3.Transport
2934
h2Fallback innerTransport
3035
fallbackDelay time.Duration
3136
brokenAccess sync.Mutex
32-
brokenUntil time.Time
33-
brokenBackoff time.Duration
37+
broken map[string]http3BrokenEntry
3438
}
3539

3640
func newHTTP3RoundTripper(
@@ -114,6 +118,7 @@ func newHTTP3FallbackTransport(
114118
h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
115119
h2Fallback: h2Fallback,
116120
fallbackDelay: fallbackDelay,
121+
broken: make(map[string]http3BrokenEntry),
117122
}, nil
118123
}
119124

@@ -138,31 +143,32 @@ func (t *http3FallbackTransport) RoundTrip(request *http.Request) (*http.Respons
138143
}
139144

140145
func (t *http3FallbackTransport) roundTripHTTP3(request *http.Request) (*http.Response, error) {
141-
if t.h3Broken() {
146+
authority := requestAuthority(request)
147+
if t.h3Broken(authority) {
142148
return t.h2FallbackRoundTrip(request)
143149
}
144150
response, err := t.h3Transport.RoundTripOpt(request, http3.RoundTripOpt{OnlyCachedConn: true})
145151
if err == nil {
146-
t.clearH3Broken()
152+
t.clearH3Broken(authority)
147153
return response, nil
148154
}
149155
if !errors.Is(err, http3.ErrNoCachedConn) {
150-
t.markH3Broken()
156+
t.markH3Broken(authority)
151157
return t.h2FallbackRoundTrip(cloneRequestForRetry(request))
152158
}
153159
if !requestReplayable(request) {
154160
response, err = t.h3Transport.RoundTrip(request)
155161
if err == nil {
156-
t.clearH3Broken()
162+
t.clearH3Broken(authority)
157163
return response, nil
158164
}
159-
t.markH3Broken()
165+
t.markH3Broken(authority)
160166
return nil, err
161167
}
162-
return t.roundTripHTTP3Race(request)
168+
return t.roundTripHTTP3Race(request, authority)
163169
}
164170

165-
func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request) (*http.Response, error) {
171+
func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request, authority string) (*http.Response, error) {
166172
ctx, cancel := context.WithCancel(request.Context())
167173
defer cancel()
168174
type result struct {
@@ -215,13 +221,13 @@ func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request) (*htt
215221
received++
216222
if raceResult.err == nil {
217223
if raceResult.h3 {
218-
t.clearH3Broken()
224+
t.clearH3Broken(authority)
219225
}
220226
drainRemaining()
221227
return raceResult.response, nil
222228
}
223229
if raceResult.h3 {
224-
t.markH3Broken()
230+
t.markH3Broken(authority)
225231
h3Err = raceResult.err
226232
if goroutines == 1 {
227233
goroutines++
@@ -269,29 +275,47 @@ func (t *http3FallbackTransport) Close() error {
269275
return t.h3Transport.Close()
270276
}
271277

272-
func (t *http3FallbackTransport) h3Broken() bool {
278+
func (t *http3FallbackTransport) h3Broken(authority string) bool {
279+
if authority == "" {
280+
return false
281+
}
273282
t.brokenAccess.Lock()
274283
defer t.brokenAccess.Unlock()
275-
return !t.brokenUntil.IsZero() && time.Now().Before(t.brokenUntil)
284+
entry, found := t.broken[authority]
285+
if !found {
286+
return false
287+
}
288+
if entry.until.IsZero() || !time.Now().Before(entry.until) {
289+
delete(t.broken, authority)
290+
return false
291+
}
292+
return true
276293
}
277294

278-
func (t *http3FallbackTransport) clearH3Broken() {
295+
func (t *http3FallbackTransport) clearH3Broken(authority string) {
296+
if authority == "" {
297+
return
298+
}
279299
t.brokenAccess.Lock()
280-
t.brokenUntil = time.Time{}
281-
t.brokenBackoff = 0
300+
delete(t.broken, authority)
282301
t.brokenAccess.Unlock()
283302
}
284303

285-
func (t *http3FallbackTransport) markH3Broken() {
304+
func (t *http3FallbackTransport) markH3Broken(authority string) {
305+
if authority == "" {
306+
return
307+
}
286308
t.brokenAccess.Lock()
287309
defer t.brokenAccess.Unlock()
288-
if t.brokenBackoff == 0 {
289-
t.brokenBackoff = 5 * time.Minute
310+
entry := t.broken[authority]
311+
if entry.backoff == 0 {
312+
entry.backoff = 5 * time.Minute
290313
} else {
291-
t.brokenBackoff *= 2
292-
if t.brokenBackoff > 48*time.Hour {
293-
t.brokenBackoff = 48 * time.Hour
314+
entry.backoff *= 2
315+
if entry.backoff > 48*time.Hour {
316+
entry.backoff = 48 * time.Hour
294317
}
295318
}
296-
t.brokenUntil = time.Now().Add(t.brokenBackoff)
319+
entry.until = time.Now().Add(entry.backoff)
320+
t.broken[authority] = entry
297321
}

0 commit comments

Comments
 (0)