Skip to content

Commit 9b70552

Browse files
committed
Fix low-cost cloudflared parity gaps
1 parent bf7e899 commit 9b70552

22 files changed

Lines changed: 565 additions & 51 deletions

option/cloudflared.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package option
33
import "github.com/sagernet/sing/common/json/badoption"
44

55
type CloudflaredInboundOptions struct {
6-
Token string `json:"token,omitempty"`
7-
HAConnections int `json:"ha_connections,omitempty"`
8-
Protocol string `json:"protocol,omitempty"`
9-
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
10-
TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"`
11-
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
12-
DatagramVersion string `json:"datagram_version,omitempty"`
13-
GracePeriod badoption.Duration `json:"grace_period,omitempty"`
14-
Region string `json:"region,omitempty"`
6+
Token string `json:"token,omitempty"`
7+
HAConnections int `json:"ha_connections,omitempty"`
8+
Protocol string `json:"protocol,omitempty"`
9+
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
10+
TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"`
11+
EdgeIPVersion int `json:"edge_ip_version,omitempty"`
12+
DatagramVersion string `json:"datagram_version,omitempty"`
13+
GracePeriod *badoption.Duration `json:"grace_period,omitempty"`
14+
Region string `json:"region,omitempty"`
1515
}

protocol/cloudflare/config_decode_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ package cloudflare
55
import (
66
"context"
77
"testing"
8+
"time"
89

910
"github.com/sagernet/sing-box/log"
1011
"github.com/sagernet/sing-box/option"
12+
"github.com/sagernet/sing/common/json"
1113
)
1214

1315
func TestNewInboundRequiresToken(t *testing.T) {
@@ -36,3 +38,35 @@ func TestNormalizeProtocolAutoUsesTokenStyleSentinel(t *testing.T) {
3638
t.Fatalf("expected auto protocol to normalize to token-style empty sentinel, got %q", protocol)
3739
}
3840
}
41+
42+
func TestResolveGracePeriodDefaultsToThirtySeconds(t *testing.T) {
43+
if got := resolveGracePeriod(nil); got != 30*time.Second {
44+
t.Fatalf("expected default grace period 30s, got %s", got)
45+
}
46+
}
47+
48+
func TestResolveGracePeriodPreservesExplicitZero(t *testing.T) {
49+
var options option.CloudflaredInboundOptions
50+
if err := json.Unmarshal([]byte(`{"grace_period":"0s"}`), &options); err != nil {
51+
t.Fatal(err)
52+
}
53+
if options.GracePeriod == nil {
54+
t.Fatal("expected explicit grace period to be set")
55+
}
56+
if got := resolveGracePeriod(options.GracePeriod); got != 0 {
57+
t.Fatalf("expected explicit zero grace period, got %s", got)
58+
}
59+
}
60+
61+
func TestResolveGracePeriodPreservesNonZeroValue(t *testing.T) {
62+
var options option.CloudflaredInboundOptions
63+
if err := json.Unmarshal([]byte(`{"grace_period":"45s"}`), &options); err != nil {
64+
t.Fatal(err)
65+
}
66+
if options.GracePeriod == nil {
67+
t.Fatal("expected explicit grace period to be set")
68+
}
69+
if got := resolveGracePeriod(options.GracePeriod); got != 45*time.Second {
70+
t.Fatalf("expected grace period 45s, got %s", got)
71+
}
72+
}

protocol/cloudflare/connection_drain_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,37 @@ func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) {
232232
t.Fatal("expected graceful shutdown to finish")
233233
}
234234
}
235+
236+
func TestQUICGracefulShutdownStopsWaitingWhenServeContextEnds(t *testing.T) {
237+
conn := newStubQUICConn()
238+
registrationClient := newMockRegistrationClient()
239+
serveCtx, cancelServe := context.WithCancel(context.Background())
240+
connection := &QUICConnection{
241+
conn: conn,
242+
gracePeriod: time.Second,
243+
registrationClient: registrationClient,
244+
registrationResult: &RegistrationResult{},
245+
serveCtx: serveCtx,
246+
serveCancel: func() {},
247+
}
248+
249+
done := make(chan struct{})
250+
go func() {
251+
connection.gracefulShutdown()
252+
close(done)
253+
}()
254+
255+
select {
256+
case <-registrationClient.unregisterCalled:
257+
case <-time.After(time.Second):
258+
t.Fatal("expected unregister call")
259+
}
260+
261+
cancelServe()
262+
263+
select {
264+
case <-done:
265+
case <-time.After(200 * time.Millisecond):
266+
t.Fatal("expected graceful shutdown to stop waiting once serve context ends")
267+
}
268+
}

protocol/cloudflare/connection_http2.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ func NewHTTP2Connection(
8383
return nil, E.Cause(err, "load Cloudflare root CAs")
8484
}
8585

86-
tlsConfig := &tls.Config{
87-
RootCAs: rootCAs,
88-
ServerName: h2EdgeSNI,
89-
}
86+
tlsConfig := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil)
9087

9188
tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port()))
9289
if err != nil {
@@ -283,7 +280,8 @@ func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.Resp
283280
err := json.NewDecoder(r.Body).Decode(&body)
284281
if err != nil {
285282
c.logger.Error("decode configuration update: ", err)
286-
w.WriteHeader(http.StatusBadRequest)
283+
w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
284+
w.WriteHeader(http.StatusBadGateway)
287285
return
288286
}
289287
result := c.inbound.ApplyConfig(body.Version, body.Config)

protocol/cloudflare/connection_http2_behavior_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package cloudflare
44

55
import (
6+
"bytes"
67
"io"
78
"net/http"
89
"testing"
@@ -165,3 +166,26 @@ func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) {
165166
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
166167
}
167168
}
169+
170+
func TestHandleConfigurationUpdateDecodeFailureReturnsBadGateway(t *testing.T) {
171+
writer := &captureHTTP2Writer{}
172+
connection := &HTTP2Connection{
173+
logger: log.NewNOPFactory().NewLogger("test"),
174+
}
175+
request, err := http.NewRequest(http.MethodPost, "https://example.com", bytes.NewBufferString("{"))
176+
if err != nil {
177+
t.Fatal(err)
178+
}
179+
180+
connection.handleConfigurationUpdate(request, writer)
181+
182+
if writer.statusCode != http.StatusBadGateway {
183+
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, writer.statusCode)
184+
}
185+
if meta := writer.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflared {
186+
t.Fatalf("unexpected response meta: %q", meta)
187+
}
188+
if len(writer.body) != 0 {
189+
t.Fatalf("expected empty response body, got %q", string(writer.body))
190+
}
191+
}

protocol/cloudflare/connection_quic.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package cloudflare
44

55
import (
66
"context"
7-
"crypto/tls"
87
"fmt"
98
"io"
109
"net"
@@ -54,6 +53,7 @@ type QUICConnection struct {
5453
registrationResult *RegistrationResult
5554
onConnected func()
5655

56+
serveCtx context.Context
5757
serveCancel context.CancelFunc
5858
registrationClose sync.Once
5959
shutdownOnce sync.Once
@@ -109,11 +109,7 @@ func NewQUICConnection(
109109
return nil, E.Cause(err, "load Cloudflare root CAs")
110110
}
111111

112-
tlsConfig := &tls.Config{
113-
RootCAs: rootCAs,
114-
ServerName: quicEdgeSNI,
115-
NextProtos: []string{quicEdgeALPN},
116-
}
112+
tlsConfig := newEdgeTLSConfig(rootCAs, quicEdgeSNI, []string{quicEdgeALPN})
117113

118114
quicConfig := &quic.Config{
119115
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
@@ -190,6 +186,7 @@ func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error
190186
" (connection ", q.registrationResult.ConnectionID, ")")
191187

192188
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
189+
q.serveCtx = serveCtx
193190
q.serveCancel = serveCancel
194191

195192
errChan := make(chan error, 2)
@@ -321,9 +318,16 @@ func (q *QUICConnection) gracefulShutdown() {
321318
}
322319
q.closeRegistrationClient()
323320
if q.gracePeriod > 0 {
321+
waitCtx := q.serveCtx
322+
if waitCtx == nil {
323+
waitCtx = context.Background()
324+
}
324325
timer := time.NewTimer(q.gracePeriod)
325-
<-timer.C
326-
timer.Stop()
326+
defer timer.Stop()
327+
select {
328+
case <-timer.C:
329+
case <-waitCtx.Done():
330+
}
327331
}
328332
q.closeNow("graceful shutdown")
329333
})

protocol/cloudflare/control.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package cloudflare
44

55
import (
66
"context"
7+
"errors"
78
"io"
89
"net"
910
"runtime"
@@ -43,6 +44,29 @@ type registrationRPCClient interface {
4344
Close() error
4445
}
4546

47+
type permanentRegistrationError struct {
48+
Err error
49+
}
50+
51+
func (e *permanentRegistrationError) Error() string {
52+
if e == nil || e.Err == nil {
53+
return "permanent registration error"
54+
}
55+
return e.Err.Error()
56+
}
57+
58+
func (e *permanentRegistrationError) Unwrap() error {
59+
if e == nil {
60+
return nil
61+
}
62+
return e.Err
63+
}
64+
65+
func isPermanentRegistrationError(err error) bool {
66+
var permanentErr *permanentRegistrationError
67+
return errors.As(err, &permanentErr)
68+
}
69+
4670
// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream.
4771
// The stream should be the first QUIC stream (control stream).
4872
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient {
@@ -118,7 +142,7 @@ func (c *RegistrationClient) RegisterConnection(
118142
Delay: time.Duration(resultError.RetryAfter()),
119143
}
120144
}
121-
return nil, registrationError
145+
return nil, &permanentRegistrationError{Err: registrationError}
122146

123147
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
124148
connDetails, err := result.ConnectionDetails()

protocol/cloudflare/datagram_rpc_test.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ import (
1717
)
1818

1919
func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
20+
return newRegisterUDPSessionCallWithDstIP(t, []byte{127, 0, 0, 1}, traceContext)
21+
}
22+
23+
func newRegisterUDPSessionCallWithDstIP(t *testing.T, dstIP []byte, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
2024
t.Helper()
2125

2226
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
@@ -31,7 +35,7 @@ func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.Ses
3135
if err := params.SetSessionId(sessionID[:]); err != nil {
3236
t.Fatal(err)
3337
}
34-
if err := params.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
38+
if err := params.SetDstIp(dstIP); err != nil {
3539
t.Fatal(err)
3640
}
3741
params.SetDstPort(53)
@@ -197,3 +201,31 @@ func TestV2RPCUnregisterUDPSessionPropagatesMessage(t *testing.T) {
197201
t.Fatalf("expected close reason propagated from edge, got %q", reason)
198202
}
199203
}
204+
205+
func TestV2RPCRegisterUDPSessionRejectsMissingDestinationIP(t *testing.T) {
206+
inboundInstance := newLimitedInbound(t, 0)
207+
inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()}
208+
server := &cloudflaredServer{
209+
inbound: inboundInstance,
210+
muxer: NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger),
211+
ctx: context.Background(),
212+
logger: inboundInstance.logger,
213+
}
214+
call, readResult := newRegisterUDPSessionCallWithDstIP(t, nil, "")
215+
216+
if err := server.RegisterUdpSession(call); err != nil {
217+
t.Fatal(err)
218+
}
219+
220+
result, err := readResult()
221+
if err != nil {
222+
t.Fatal(err)
223+
}
224+
resultErr, err := result.Err()
225+
if err != nil {
226+
t.Fatal(err)
227+
}
228+
if resultErr != "missing destination IP" {
229+
t.Fatalf("unexpected result error %q", resultErr)
230+
}
231+
}

protocol/cloudflare/datagram_rpc_v3.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/sagernet/sing-box/log"
1111
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
1212
E "github.com/sagernet/sing/common/exceptions"
13+
"zombiezen.com/go/capnproto2/server"
1314
)
1415

1516
var (
@@ -38,6 +39,7 @@ func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager
3839
}
3940

4041
func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
42+
server.Ack(call.Options)
4143
version := call.Params.Version()
4244
configData, _ := call.Params.Config()
4345
updateResult := s.inbound.ApplyConfig(version, configData)

protocol/cloudflare/datagram_v2.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,16 @@ func (m *DatagramV2Muxer) RegisterSession(
164164
destinationPort uint16,
165165
closeAfterIdle time.Duration,
166166
) error {
167+
if destinationIP == nil {
168+
return E.New("missing destination IP")
169+
}
167170
var destinationAddr netip.Addr
168171
if ip4 := destinationIP.To4(); ip4 != nil {
169172
destinationAddr = netip.AddrFrom4([4]byte(ip4))
173+
} else if ip16 := destinationIP.To16(); ip16 != nil {
174+
destinationAddr = netip.AddrFrom16([16]byte(ip16))
170175
} else {
171-
destinationAddr = netip.AddrFrom16([16]byte(destinationIP.To16()))
176+
return E.New("invalid destination IP")
172177
}
173178
destination := netip.AddrPortFrom(destinationAddr, destinationPort)
174179

@@ -482,7 +487,11 @@ func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_reg
482487
return traceErr
483488
}
484489

485-
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
490+
if len(destinationIP) == 0 {
491+
err = E.New("missing destination IP")
492+
} else {
493+
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
494+
}
486495

487496
result, allocErr := call.Results.NewResult()
488497
if allocErr != nil {

0 commit comments

Comments
 (0)