Skip to content

Commit aa458ed

Browse files
committed
Implement read waiter for UDP
1 parent 6be79e9 commit aa458ed

5 files changed

Lines changed: 127 additions & 42 deletions

File tree

client.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M.
8686
if err != nil {
8787
return nil, err
8888
}
89-
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
89+
extendedConn := bufio.NewExtendedConn(stream)
90+
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
9091
default:
9192
return nil, E.Extend(N.ErrUnknownNetwork, network)
9293
}
@@ -97,7 +98,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
9798
if err != nil {
9899
return nil, err
99100
}
100-
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
101+
extendedConn := bufio.NewExtendedConn(stream)
102+
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
101103
}
102104

103105
func (c *Client) openStream(ctx context.Context) (net.Conn, error) {

client_conn.go

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,24 @@ func (c *clientConn) Upstream() any {
9393
return c.Conn
9494
}
9595

96+
var _ N.NetPacketConn = (*clientPacketConn)(nil)
97+
9698
type clientPacketConn struct {
97-
N.ExtendedConn
98-
access sync.Mutex
99-
destination M.Socksaddr
100-
requestWritten bool
101-
responseRead bool
99+
N.AbstractConn
100+
conn N.ExtendedConn
101+
access sync.Mutex
102+
destination M.Socksaddr
103+
requestWritten bool
104+
responseRead bool
105+
readWaitOptions N.ReadWaitOptions
102106
}
103107

104108
func (c *clientPacketConn) NeedHandshake() bool {
105109
return !c.requestWritten
106110
}
107111

108112
func (c *clientPacketConn) readResponse() error {
109-
response, err := ReadStreamResponse(c.ExtendedConn)
113+
response, err := ReadStreamResponse(c.conn)
110114
if err != nil {
111115
return err
112116
}
@@ -125,14 +129,14 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) {
125129
c.responseRead = true
126130
}
127131
var length uint16
128-
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
132+
err = binary.Read(c.conn, binary.BigEndian, &length)
129133
if err != nil {
130134
return
131135
}
132136
if cap(b) < int(length) {
133137
return 0, io.ErrShortBuffer
134138
}
135-
return io.ReadFull(c.ExtendedConn, b[:length])
139+
return io.ReadFull(c.conn, b[:length])
136140
}
137141

138142
func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
@@ -156,7 +160,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
156160
common.Error(buffer.Write(payload)),
157161
)
158162
}
159-
_, err = c.ExtendedConn.Write(buffer.Bytes())
163+
_, err = c.conn.Write(buffer.Bytes())
160164
if err != nil {
161165
return
162166
}
@@ -174,11 +178,11 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) {
174178
return c.writeRequest(b)
175179
}
176180
}
177-
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
181+
err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
178182
if err != nil {
179183
return
180184
}
181-
return c.ExtendedConn.Write(b)
185+
return c.conn.Write(b)
182186
}
183187

184188
func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
@@ -190,11 +194,11 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
190194
c.responseRead = true
191195
}
192196
var length uint16
193-
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
197+
err = binary.Read(c.conn, binary.BigEndian, &length)
194198
if err != nil {
195199
return
196200
}
197-
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
201+
_, err = buffer.ReadFullFrom(c.conn, int(length))
198202
return
199203
}
200204

@@ -211,7 +215,7 @@ func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
211215
}
212216
bLen := buffer.Len()
213217
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
214-
return c.ExtendedConn.WriteBuffer(buffer)
218+
return c.conn.WriteBuffer(buffer)
215219
}
216220

217221
func (c *clientPacketConn) FrontHeadroom() int {
@@ -227,14 +231,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
227231
c.responseRead = true
228232
}
229233
var length uint16
230-
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
234+
err = binary.Read(c.conn, binary.BigEndian, &length)
231235
if err != nil {
232236
return
233237
}
234238
if cap(p) < int(length) {
235239
return 0, nil, io.ErrShortBuffer
236240
}
237-
n, err = io.ReadFull(c.ExtendedConn, p[:length])
241+
n, err = io.ReadFull(c.conn, p[:length])
238242
return
239243
}
240244

@@ -248,11 +252,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
248252
return c.writeRequest(p)
249253
}
250254
}
251-
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
255+
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
252256
if err != nil {
253257
return
254258
}
255-
return c.ExtendedConn.Write(p)
259+
return c.conn.Write(p)
256260
}
257261

258262
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
@@ -265,7 +269,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
265269
}
266270

267271
func (c *clientPacketConn) LocalAddr() net.Addr {
268-
return c.ExtendedConn.LocalAddr()
272+
return c.conn.LocalAddr()
269273
}
270274

271275
func (c *clientPacketConn) RemoteAddr() net.Addr {
@@ -277,25 +281,27 @@ func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
277281
}
278282

279283
func (c *clientPacketConn) Upstream() any {
280-
return c.ExtendedConn
284+
return c.conn
281285
}
282286

283287
var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
284288

285289
type clientPacketAddrConn struct {
286-
N.ExtendedConn
287-
access sync.Mutex
288-
destination M.Socksaddr
289-
requestWritten bool
290-
responseRead bool
290+
N.AbstractConn
291+
conn N.ExtendedConn
292+
access sync.Mutex
293+
destination M.Socksaddr
294+
requestWritten bool
295+
responseRead bool
296+
readWaitOptions N.ReadWaitOptions
291297
}
292298

293299
func (c *clientPacketAddrConn) NeedHandshake() bool {
294300
return !c.requestWritten
295301
}
296302

297303
func (c *clientPacketAddrConn) readResponse() error {
298-
response, err := ReadStreamResponse(c.ExtendedConn)
304+
response, err := ReadStreamResponse(c.conn)
299305
if err != nil {
300306
return err
301307
}
@@ -313,7 +319,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
313319
}
314320
c.responseRead = true
315321
}
316-
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
322+
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
317323
if err != nil {
318324
return
319325
}
@@ -323,14 +329,14 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
323329
addr = destination.UDPAddr()
324330
}
325331
var length uint16
326-
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
332+
err = binary.Read(c.conn, binary.BigEndian, &length)
327333
if err != nil {
328334
return
329335
}
330336
if cap(p) < int(length) {
331337
return 0, nil, io.ErrShortBuffer
332338
}
333-
n, err = io.ReadFull(c.ExtendedConn, p[:length])
339+
n, err = io.ReadFull(c.conn, p[:length])
334340
return
335341
}
336342

@@ -360,7 +366,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
360366
common.Error(buffer.Write(payload)),
361367
)
362368
}
363-
_, err = c.ExtendedConn.Write(buffer.Bytes())
369+
_, err = c.conn.Write(buffer.Bytes())
364370
if err != nil {
365371
return
366372
}
@@ -378,15 +384,15 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
378384
return c.writeRequest(p, M.SocksaddrFromNet(addr))
379385
}
380386
}
381-
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
387+
err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
382388
if err != nil {
383389
return
384390
}
385-
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
391+
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
386392
if err != nil {
387393
return
388394
}
389-
return c.ExtendedConn.Write(p)
395+
return c.conn.Write(p)
390396
}
391397

392398
func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
@@ -397,16 +403,16 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
397403
}
398404
c.responseRead = true
399405
}
400-
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
406+
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
401407
if err != nil {
402408
return
403409
}
404410
var length uint16
405-
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
411+
err = binary.Read(c.conn, binary.BigEndian, &length)
406412
if err != nil {
407413
return
408414
}
409-
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
415+
_, err = buffer.ReadFullFrom(c.conn, int(length))
410416
return
411417
}
412418

@@ -428,11 +434,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
428434
return err
429435
}
430436
common.Must(binary.Write(header, binary.BigEndian, uint16(bLen)))
431-
return c.ExtendedConn.WriteBuffer(buffer)
437+
return c.conn.WriteBuffer(buffer)
432438
}
433439

434440
func (c *clientPacketAddrConn) LocalAddr() net.Addr {
435-
return c.ExtendedConn.LocalAddr()
441+
return c.conn.LocalAddr()
436442
}
437443

438444
func (c *clientPacketAddrConn) FrontHeadroom() int {
@@ -444,5 +450,5 @@ func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
444450
}
445451

446452
func (c *clientPacketAddrConn) Upstream() any {
447-
return c.ExtendedConn
453+
return c.conn
448454
}

client_conn_wait.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package mux
2+
3+
import (
4+
"encoding/binary"
5+
6+
"github.com/sagernet/sing/common/buf"
7+
M "github.com/sagernet/sing/common/metadata"
8+
N "github.com/sagernet/sing/common/network"
9+
)
10+
11+
var _ N.PacketReadWaiter = (*clientPacketConn)(nil)
12+
13+
func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
14+
c.readWaitOptions = options
15+
return false
16+
}
17+
18+
func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
19+
if !c.responseRead {
20+
err = c.readResponse()
21+
if err != nil {
22+
return
23+
}
24+
c.responseRead = true
25+
}
26+
var length uint16
27+
err = binary.Read(c.conn, binary.BigEndian, &length)
28+
if err != nil {
29+
return
30+
}
31+
buffer = c.readWaitOptions.NewPacketBuffer()
32+
_, err = buffer.ReadFullFrom(c.conn, int(length))
33+
if err != nil {
34+
buffer.Release()
35+
return nil, M.Socksaddr{}, err
36+
}
37+
c.readWaitOptions.PostReturn(buffer)
38+
return
39+
}
40+
41+
var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)
42+
43+
func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
44+
c.readWaitOptions = options
45+
return false
46+
}
47+
48+
func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
49+
if !c.responseRead {
50+
err = c.readResponse()
51+
if err != nil {
52+
return
53+
}
54+
c.responseRead = true
55+
}
56+
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
57+
if err != nil {
58+
return
59+
}
60+
var length uint16
61+
err = binary.Read(c.conn, binary.BigEndian, &length)
62+
if err != nil {
63+
return
64+
}
65+
buffer = c.readWaitOptions.NewPacketBuffer()
66+
_, err = buffer.ReadFullFrom(c.conn, int(length))
67+
if err != nil {
68+
buffer.Release()
69+
return nil, M.Socksaddr{}, err
70+
}
71+
c.readWaitOptions.PostReturn(buffer)
72+
return
73+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.18
44

55
require (
66
github.com/hashicorp/yamux v0.1.1
7-
github.com/sagernet/sing v0.2.20
7+
github.com/sagernet/sing v0.3.0-rc.2
88
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
99
golang.org/x/net v0.19.0
1010
golang.org/x/sys v0.15.0

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbg
33
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
44
github.com/sagernet/sing v0.2.20 h1:ckcCB/5xu8G8wElNeH74IF6Soac5xWN+eQUXRuonjPQ=
55
github.com/sagernet/sing v0.2.20/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
6+
github.com/sagernet/sing v0.3.0-rc.1 h1:XcdCC9CcLNfMSlObIQPjxyzenGQT2R1sGLHvdwDmQFU=
7+
github.com/sagernet/sing v0.3.0-rc.1/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
8+
github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug=
9+
github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
610
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
711
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
812
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=

0 commit comments

Comments
 (0)