Skip to content

Commit bf7e899

Browse files
committed
Fix cloudflared parity gaps
1 parent 7dceb2e commit bf7e899

23 files changed

Lines changed: 1378 additions & 141 deletions

protocol/cloudflare/connection_drain_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ import (
1010
"testing"
1111
"time"
1212

13-
"github.com/google/uuid"
1413
"github.com/sagernet/quic-go"
14+
15+
"github.com/google/uuid"
1516
)
1617

1718
type stubNetConn struct {
@@ -43,6 +44,7 @@ func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.N
4344
func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) {
4445
return nil, errors.New("unused")
4546
}
47+
4648
func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) {
4749
return nil, errors.New("unused")
4850
}

protocol/cloudflare/connection_http2.go

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"math"
1010
"net"
1111
"net/http"
12+
"runtime/debug"
1213
"strconv"
1314
"strings"
1415
"sync"
@@ -27,8 +28,17 @@ const (
2728
h2EdgeSNI = "h2.cftunnel.com"
2829
h2ResponseMetaCloudflared = `{"src":"cloudflared"}`
2930
h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`
31+
contentTypeHeader = "content-type"
32+
contentLengthHeader = "content-length"
33+
transferEncodingHeader = "transfer-encoding"
34+
chunkTransferEncoding = "chunked"
35+
sseContentType = "text/event-stream"
36+
grpcContentType = "application/grpc"
37+
ndjsonContentType = "application/x-ndjson"
3038
)
3139

40+
var flushableContentTypes = []string{sseContentType, grpcContentType, ndjsonContentType}
41+
3242
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
3343
// Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server.
3444
type HTTP2Connection struct {
@@ -191,7 +201,7 @@ func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Reque
191201
return
192202
}
193203
c.registrationResult = result
194-
c.inbound.notifyConnected(c.connIndex)
204+
c.inbound.notifyConnected(c.connIndex, "http2")
195205

196206
c.logger.Info("connected to ", result.Location,
197207
" (connection ", result.ConnectionID, ")")
@@ -246,14 +256,18 @@ func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Reques
246256
}
247257
}
248258

259+
flushState := &http2FlushState{shouldFlush: connectionType != ConnectionTypeHTTP}
249260
stream := &http2DataStream{
250261
reader: r.Body,
251262
writer: w,
252263
flusher: flusher,
264+
state: flushState,
265+
logger: c.logger,
253266
}
254267
respWriter := &http2ResponseWriter{
255-
writer: w,
256-
flusher: flusher,
268+
writer: w,
269+
flusher: flusher,
270+
flushState: flushState,
257271
}
258272

259273
c.inbound.dispatchRequest(ctx, stream, respWriter, request)
@@ -386,15 +400,26 @@ type http2DataStream struct {
386400
reader io.ReadCloser
387401
writer http.ResponseWriter
388402
flusher http.Flusher
403+
state *http2FlushState
404+
logger log.ContextLogger
389405
}
390406

391407
func (s *http2DataStream) Read(p []byte) (int, error) {
392408
return s.reader.Read(p)
393409
}
394410

395-
func (s *http2DataStream) Write(p []byte) (int, error) {
396-
n, err := s.writer.Write(p)
397-
if err == nil {
411+
func (s *http2DataStream) Write(p []byte) (n int, err error) {
412+
defer func() {
413+
if recovered := recover(); recovered != nil {
414+
if s.logger != nil {
415+
s.logger.Debug("recovered from HTTP/2 data stream panic: ", recovered, "\n", string(debug.Stack()))
416+
}
417+
n = 0
418+
err = io.ErrClosedPipe
419+
}
420+
}()
421+
n, err = s.writer.Write(p)
422+
if err == nil && s.state != nil && s.state.shouldFlush {
398423
s.flusher.Flush()
399424
}
400425
return n, err
@@ -409,6 +434,7 @@ type http2ResponseWriter struct {
409434
writer http.ResponseWriter
410435
flusher http.Flusher
411436
headersSent bool
437+
flushState *http2FlushState
412438
}
413439

414440
func (w *http2ResponseWriter) AddTrailer(name, value string) {
@@ -462,12 +488,37 @@ func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Meta
462488

463489
w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders))
464490
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin)
491+
if w.flushState != nil && shouldFlushHTTPHeaders(userHeaders) {
492+
w.flushState.shouldFlush = true
493+
}
465494

466495
if statusCode == http.StatusSwitchingProtocols {
467496
statusCode = http.StatusOK
468497
}
469498

470499
w.writer.WriteHeader(statusCode)
471-
w.flusher.Flush()
500+
if w.flushState != nil && w.flushState.shouldFlush {
501+
w.flusher.Flush()
502+
}
472503
return nil
473504
}
505+
506+
type http2FlushState struct {
507+
shouldFlush bool
508+
}
509+
510+
func shouldFlushHTTPHeaders(headers http.Header) bool {
511+
if headers.Get(contentLengthHeader) == "" {
512+
return true
513+
}
514+
if transferEncoding := strings.ToLower(headers.Get(transferEncodingHeader)); transferEncoding != "" && strings.Contains(transferEncoding, chunkTransferEncoding) {
515+
return true
516+
}
517+
contentType := strings.ToLower(headers.Get(contentTypeHeader))
518+
for _, flushable := range flushableContentTypes {
519+
if strings.HasPrefix(contentType, flushable) {
520+
return true
521+
}
522+
}
523+
return false
524+
}
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
//go:build with_cloudflared
2+
3+
package cloudflare
4+
5+
import (
6+
"io"
7+
"net/http"
8+
"testing"
9+
10+
"github.com/sagernet/sing-box/log"
11+
)
12+
13+
type captureHTTP2Writer struct {
14+
header http.Header
15+
flushCount int
16+
statusCode int
17+
body []byte
18+
panicWrite bool
19+
}
20+
21+
func (w *captureHTTP2Writer) Header() http.Header {
22+
if w.header == nil {
23+
w.header = make(http.Header)
24+
}
25+
return w.header
26+
}
27+
28+
func (w *captureHTTP2Writer) WriteHeader(statusCode int) {
29+
w.statusCode = statusCode
30+
}
31+
32+
func (w *captureHTTP2Writer) Write(p []byte) (int, error) {
33+
if w.panicWrite {
34+
panic("write after close")
35+
}
36+
w.body = append(w.body, p...)
37+
return len(p), nil
38+
}
39+
40+
func (w *captureHTTP2Writer) Flush() {
41+
w.flushCount++
42+
}
43+
44+
func TestHTTP2NonStreamingResponseDoesNotFlush(t *testing.T) {
45+
writer := &captureHTTP2Writer{}
46+
flushState := &http2FlushState{}
47+
respWriter := &http2ResponseWriter{
48+
writer: writer,
49+
flusher: writer,
50+
flushState: flushState,
51+
}
52+
53+
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, http.Header{
54+
"Content-Type": []string{"application/json"},
55+
"Content-Length": []string{"2"},
56+
}))
57+
if err != nil {
58+
t.Fatal(err)
59+
}
60+
if writer.flushCount != 0 {
61+
t.Fatalf("expected no header flush for non-streaming response, got %d", writer.flushCount)
62+
}
63+
64+
stream := &http2DataStream{
65+
writer: writer,
66+
flusher: writer,
67+
state: flushState,
68+
logger: log.NewNOPFactory().NewLogger("test"),
69+
}
70+
if _, err := stream.Write([]byte("ok")); err != nil {
71+
t.Fatal(err)
72+
}
73+
if writer.flushCount != 0 {
74+
t.Fatalf("expected no body flush for non-streaming response, got %d", writer.flushCount)
75+
}
76+
}
77+
78+
func TestHTTP2StreamingResponsesFlush(t *testing.T) {
79+
testCases := []struct {
80+
name string
81+
header http.Header
82+
}{
83+
{
84+
name: "sse",
85+
header: http.Header{
86+
"Content-Type": []string{"text/event-stream"},
87+
"Content-Length": []string{"1"},
88+
},
89+
},
90+
{
91+
name: "grpc",
92+
header: http.Header{
93+
"Content-Type": []string{"application/grpc"},
94+
"Content-Length": []string{"1"},
95+
},
96+
},
97+
{
98+
name: "ndjson",
99+
header: http.Header{
100+
"Content-Type": []string{"application/x-ndjson"},
101+
"Content-Length": []string{"1"},
102+
},
103+
},
104+
{
105+
name: "chunked",
106+
header: http.Header{
107+
"Content-Type": []string{"application/json"},
108+
"Content-Length": []string{"-1"},
109+
"Transfer-Encoding": []string{"chunked"},
110+
},
111+
},
112+
{
113+
name: "no-content-length",
114+
header: http.Header{
115+
"Content-Type": []string{"application/json"},
116+
},
117+
},
118+
}
119+
120+
for _, testCase := range testCases {
121+
t.Run(testCase.name, func(t *testing.T) {
122+
writer := &captureHTTP2Writer{}
123+
flushState := &http2FlushState{}
124+
respWriter := &http2ResponseWriter{
125+
writer: writer,
126+
flusher: writer,
127+
flushState: flushState,
128+
}
129+
130+
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, testCase.header))
131+
if err != nil {
132+
t.Fatal(err)
133+
}
134+
if writer.flushCount == 0 {
135+
t.Fatal("expected header flush for streaming response")
136+
}
137+
138+
stream := &http2DataStream{
139+
writer: writer,
140+
flusher: writer,
141+
state: flushState,
142+
logger: log.NewNOPFactory().NewLogger("test"),
143+
}
144+
if _, err := stream.Write([]byte("chunk")); err != nil {
145+
t.Fatal(err)
146+
}
147+
if writer.flushCount < 2 {
148+
t.Fatalf("expected body flush for streaming response, got %d flushes", writer.flushCount)
149+
}
150+
})
151+
}
152+
}
153+
154+
func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) {
155+
writer := &captureHTTP2Writer{panicWrite: true}
156+
stream := &http2DataStream{
157+
writer: writer,
158+
flusher: writer,
159+
state: &http2FlushState{shouldFlush: true},
160+
logger: log.NewNOPFactory().NewLogger("test"),
161+
}
162+
163+
_, err := stream.Write([]byte("panic"))
164+
if err != io.ErrClosedPipe {
165+
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
166+
}
167+
}

protocol/cloudflare/connection_quic.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ type QUICConnection struct {
6060
closeOnce sync.Once
6161
}
6262

63+
type quicStreamHandle interface {
64+
io.Reader
65+
io.Writer
66+
io.Closer
67+
CancelRead(code quic.StreamErrorCode)
68+
CancelWrite(code quic.StreamErrorCode)
69+
SetWriteDeadline(t time.Time) error
70+
}
71+
6372
type quicConnection interface {
6473
OpenStream() (*quic.Stream, error)
6574
AcceptStream(ctx context.Context) (*quic.Stream, error)
@@ -80,7 +89,6 @@ func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reaso
8089
return err
8190
}
8291

83-
8492
// NewQUICConnection dials the edge and establishes a QUIC connection.
8593
func NewQUICConnection(
8694
ctx context.Context,
@@ -240,13 +248,14 @@ func (q *QUICConnection) acceptStreams(ctx context.Context, handler StreamHandle
240248
}
241249
}
242250

243-
func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream, handler StreamHandler) {
251+
func (q *QUICConnection) handleStream(ctx context.Context, stream quicStreamHandle, handler StreamHandler) {
244252
rwc := newStreamReadWriteCloser(stream)
245253
defer rwc.Close()
246254

247255
streamType, err := ReadStreamSignature(rwc)
248256
if err != nil {
249257
q.logger.Debug("failed to read stream signature: ", err)
258+
stream.CancelWrite(0)
250259
return
251260
}
252261

@@ -255,6 +264,7 @@ func (q *QUICConnection) handleStream(ctx context.Context, stream *quic.Stream,
255264
request, err := ReadConnectRequest(rwc)
256265
if err != nil {
257266
q.logger.Debug("failed to read connect request: ", err)
267+
stream.CancelWrite(0)
258268
return
259269
}
260270
handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex)
@@ -365,11 +375,11 @@ type DatagramSender interface {
365375
// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser
366376
// with mutex-protected writes and safe close semantics.
367377
type streamReadWriteCloser struct {
368-
stream *quic.Stream
378+
stream quicStreamHandle
369379
writeAccess sync.Mutex
370380
}
371381

372-
func newStreamReadWriteCloser(stream *quic.Stream) *streamReadWriteCloser {
382+
func newStreamReadWriteCloser(stream quicStreamHandle) *streamReadWriteCloser {
373383
return &streamReadWriteCloser{stream: stream}
374384
}
375385

0 commit comments

Comments
 (0)