Skip to content

Commit c2ac58d

Browse files
authored
XHTTP Client: Race Dialer
1 parent ca9a902 commit c2ac58d

2 files changed

Lines changed: 561 additions & 10 deletions

File tree

transport/internet/splithttp/dialer.go

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/http/httptrace"
1010
"net/url"
11+
"slices"
1112
"strconv"
1213
"sync"
1314
"sync/atomic"
@@ -93,6 +94,9 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
9394
return "1.1"
9495
}
9596
if len(tlsConfig.NextProtocol) != 1 {
97+
if slices.Contains(tlsConfig.NextProtocol, "h3") && slices.Contains(tlsConfig.NextProtocol, "h2") {
98+
return "3+2"
99+
}
96100
return "2"
97101
}
98102
if tlsConfig.NextProtocol[0] == "http/1.1" {
@@ -101,6 +105,7 @@ func decideHTTPVersion(tlsConfig *tls.Config, realityConfig *reality.Config) str
101105
if tlsConfig.NextProtocol[0] == "h3" {
102106
return "3"
103107
}
108+
104109
return "2"
105110
}
106111

@@ -109,14 +114,27 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
109114
realityConfig := reality.ConfigFromStreamSettings(streamSettings)
110115

111116
httpVersion := decideHTTPVersion(tlsConfig, realityConfig)
112-
if httpVersion == "3" {
113-
dest.Network = net.Network_UDP // better to keep this line
114-
}
115117

116118
var gotlsConfig *gotls.Config
119+
var h3gotlsConfig *gotls.Config
117120

118121
if tlsConfig != nil {
119122
gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
123+
h3gotlsConfig = gotlsConfig
124+
125+
if httpVersion == "3+2" {
126+
h3gotlsConfig = &gotls.Config{}
127+
*h3gotlsConfig = *gotlsConfig
128+
129+
// Make QUIC ALPN only contains h3, and remove h3 from TCP TLS ALPN
130+
h3gotlsConfig.NextProtos = []string{"h3"}
131+
h3idx := slices.Index(h3gotlsConfig.NextProtos, "h3")
132+
// Don't modify original tlsConfig.NextProtocol
133+
nextProtos := gotlsConfig.NextProtos
134+
gotlsConfig.NextProtos = make([]string, 0, len(nextProtos)-1)
135+
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[:h3idx]...)
136+
gotlsConfig.NextProtos = append(gotlsConfig.NextProtos, nextProtos[h3idx+1:]...)
137+
}
120138
}
121139

122140
transportConfig := streamSettings.ProtocolSettings.(*Config)
@@ -152,7 +170,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
152170

153171
var transport http.RoundTripper
154172

155-
if httpVersion == "3" {
173+
makeH3Transport := func() *http3.Transport {
156174
if keepAlivePeriod == 0 {
157175
keepAlivePeriod = quicgoH3KeepAlivePeriod
158176
}
@@ -168,9 +186,11 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
168186
MaxIncomingStreams: -1,
169187
KeepAlivePeriod: keepAlivePeriod,
170188
}
171-
transport = &http3.RoundTripper{
189+
dest := dest
190+
dest.Network = net.Network_UDP
191+
return &http3.Transport{
172192
QUICConfig: quicConfig,
173-
TLSClientConfig: gotlsConfig,
193+
TLSClientConfig: h3gotlsConfig,
174194
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
175195
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
176196
if err != nil {
@@ -208,26 +228,30 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
208228
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
209229
},
210230
}
211-
} else if httpVersion == "2" {
231+
}
232+
233+
makeH2Transport := func() *http2.Transport {
212234
if keepAlivePeriod == 0 {
213235
keepAlivePeriod = chromeH2KeepAlivePeriod
214236
}
215237
if keepAlivePeriod < 0 {
216238
keepAlivePeriod = 0
217239
}
218-
transport = &http2.Transport{
240+
return &http2.Transport{
219241
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
220242
return dialContext(ctxInner)
221243
},
222244
IdleConnTimeout: connIdleTimeout,
223245
ReadIdleTimeout: keepAlivePeriod,
224246
}
225-
} else {
247+
}
248+
249+
makeTransport := func() *http.Transport {
226250
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
227251
return dialContext(ctxInner)
228252
}
229253

230-
transport = &http.Transport{
254+
return &http.Transport{
231255
DialTLSContext: httpDialContext,
232256
DialContext: httpDialContext,
233257
IdleConnTimeout: connIdleTimeout,
@@ -237,6 +261,22 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea
237261
}
238262
}
239263

264+
switch httpVersion {
265+
case "3":
266+
transport = makeH3Transport()
267+
case "2":
268+
transport = makeH2Transport()
269+
case "3+2":
270+
raceTransport := &raceTransport{
271+
h3: makeH3Transport(),
272+
h2: makeH2Transport(),
273+
dest: dest,
274+
}
275+
transport = raceTransport.setup()
276+
default:
277+
transport = makeTransport()
278+
}
279+
240280
client := &DefaultDialerClient{
241281
transportConfig: transportConfig,
242282
client: &http.Client{

0 commit comments

Comments
 (0)