Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M.
if err != nil {
return nil, err
}
return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
Expand All @@ -97,7 +98,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil {
return nil, err
}
return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
}

func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
Expand Down
84 changes: 45 additions & 39 deletions client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,24 @@ func (c *clientConn) Upstream() any {
return c.Conn
}

var _ N.NetPacketConn = (*clientPacketConn)(nil)

type clientPacketConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}

func (c *clientPacketConn) NeedHandshake() bool {
return !c.requestWritten
}

func (c *clientPacketConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
Expand All @@ -125,14 +129,14 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) {
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(b) < int(length) {
return 0, io.ErrShortBuffer
}
return io.ReadFull(c.ExtendedConn, b[:length])
return io.ReadFull(c.conn, b[:length])
}

func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
Expand All @@ -156,7 +160,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
Expand All @@ -174,11 +178,11 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) {
return c.writeRequest(b)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
if err != nil {
return
}
return c.ExtendedConn.Write(b)
return c.conn.Write(b)
}

func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
Expand All @@ -190,11 +194,11 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
_, err = buffer.ReadFullFrom(c.conn, int(length))
return
}

Expand All @@ -211,7 +215,7 @@ func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
}
bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
return c.ExtendedConn.WriteBuffer(buffer)
return c.conn.WriteBuffer(buffer)
}

func (c *clientPacketConn) FrontHeadroom() int {
Expand All @@ -227,14 +231,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
c.responseRead = true
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadFull(c.ExtendedConn, p[:length])
n, err = io.ReadFull(c.conn, p[:length])
return
}

Expand All @@ -248,11 +252,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return c.writeRequest(p)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.ExtendedConn.Write(p)
return c.conn.Write(p)
}

func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
Expand All @@ -265,7 +269,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
}

func (c *clientPacketConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
return c.conn.LocalAddr()
}

func (c *clientPacketConn) RemoteAddr() net.Addr {
Expand All @@ -277,25 +281,27 @@ func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
}

func (c *clientPacketConn) Upstream() any {
return c.ExtendedConn
return c.conn
}

var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)

type clientPacketAddrConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
N.AbstractConn
conn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
readWaitOptions N.ReadWaitOptions
}

func (c *clientPacketAddrConn) NeedHandshake() bool {
return !c.requestWritten
}

func (c *clientPacketAddrConn) readResponse() error {
response, err := ReadStreamResponse(c.ExtendedConn)
response, err := ReadStreamResponse(c.conn)
if err != nil {
return err
}
Expand All @@ -313,7 +319,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
}
c.responseRead = true
}
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
Expand All @@ -323,14 +329,14 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
addr = destination.UDPAddr()
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
if cap(p) < int(length) {
return 0, nil, io.ErrShortBuffer
}
n, err = io.ReadFull(c.ExtendedConn, p[:length])
n, err = io.ReadFull(c.conn, p[:length])
return
}

Expand Down Expand Up @@ -360,7 +366,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
common.Error(buffer.Write(payload)),
)
}
_, err = c.ExtendedConn.Write(buffer.Bytes())
_, err = c.conn.Write(buffer.Bytes())
if err != nil {
return
}
Expand All @@ -378,15 +384,15 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro
return c.writeRequest(p, M.SocksaddrFromNet(addr))
}
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
if err != nil {
return
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
if err != nil {
return
}
return c.ExtendedConn.Write(p)
return c.conn.Write(p)
}

func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
Expand All @@ -397,16 +403,16 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
_, err = buffer.ReadFullFrom(c.conn, int(length))
return
}

Expand All @@ -428,11 +434,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
return err
}
common.Must(binary.Write(header, binary.BigEndian, uint16(bLen)))
return c.ExtendedConn.WriteBuffer(buffer)
return c.conn.WriteBuffer(buffer)
}

func (c *clientPacketAddrConn) LocalAddr() net.Addr {
return c.ExtendedConn.LocalAddr()
return c.conn.LocalAddr()
}

func (c *clientPacketAddrConn) FrontHeadroom() int {
Expand All @@ -444,5 +450,5 @@ func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
}

func (c *clientPacketAddrConn) Upstream() any {
return c.ExtendedConn
return c.conn
}
73 changes: 73 additions & 0 deletions client_conn_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package mux

import (
"encoding/binary"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ N.PacketReadWaiter = (*clientPacketConn)(nil)

func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}

var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil)

func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if !c.responseRead {
err = c.readResponse()
if err != nil {
return
}
c.responseRead = true
}
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
if err != nil {
return
}
var length uint16
err = binary.Read(c.conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
c.readWaitOptions.PostReturn(buffer)
return
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ go 1.18

require (
github.com/hashicorp/yamux v0.1.1
github.com/sagernet/sing v0.2.18
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7
golang.org/x/net v0.19.0
golang.org/x/sys v0.15.0
)
Expand Down
10 changes: 5 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.18 h1:2Ce4dl0pkWft+4914NGXPb8OiQpgA8UHQ9xFOmgvKuY=
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0=
github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759 h1:BZfmPnZ2n0zD0YZb7UnAAaZ0Ib5riPgKvl5Jasz3LA4=
github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ=
github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
Expand Down
3 changes: 1 addition & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
}
var group task.Group
group.Append0(func(_ context.Context) error {
var stream net.Conn
for {
stream, err = session.Accept()
stream, err := session.Accept()
if err != nil {
return err
}
Expand Down