Skip to content

Commit 228eb2d

Browse files
committed
dns: Fix deadline
1 parent 31252a7 commit 228eb2d

3 files changed

Lines changed: 28 additions & 7 deletions

File tree

dns/transport/tcp.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"encoding/binary"
66
"io"
7+
"net"
8+
"time"
79

810
"github.com/sagernet/sing-box/adapter"
911
"github.com/sagernet/sing-box/common/dialer"
@@ -13,6 +15,7 @@ import (
1315
"github.com/sagernet/sing-box/option"
1416
"github.com/sagernet/sing/common"
1517
"github.com/sagernet/sing/common/buf"
18+
"github.com/sagernet/sing/common/bufio/deadline"
1619
E "github.com/sagernet/sing/common/exceptions"
1720
M "github.com/sagernet/sing/common/metadata"
1821
N "github.com/sagernet/sing/common/network"
@@ -71,6 +74,7 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
7174
return nil, E.Cause(err, "dial TCP connection")
7275
}
7376
defer conn.Close()
77+
defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))()
7478
err = WriteMessage(conn, 0, message)
7579
if err != nil {
7680
return nil, E.Cause(err, "write request")
@@ -82,6 +86,20 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
8286
return response, nil
8387
}
8488

89+
func setConnDeadline(ctx context.Context, conn net.Conn, needClose bool) func() {
90+
if needClose {
91+
stop := context.AfterFunc(ctx, func() {
92+
conn.Close()
93+
})
94+
return func() { stop() }
95+
}
96+
if d, ok := ctx.Deadline(); ok {
97+
conn.SetDeadline(d)
98+
return func() { conn.SetDeadline(time.Time{}) }
99+
}
100+
return func() {}
101+
}
102+
85103
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
86104
var responseLen uint16
87105
err := binary.Read(reader, binary.BigEndian, &responseLen)

dns/transport/tls.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package transport
22

33
import (
44
"context"
5-
"time"
65

76
"github.com/sagernet/sing-box/adapter"
87
"github.com/sagernet/sing-box/common/dialer"
@@ -12,6 +11,7 @@ import (
1211
"github.com/sagernet/sing-box/log"
1312
"github.com/sagernet/sing-box/option"
1413
"github.com/sagernet/sing/common"
14+
"github.com/sagernet/sing/common/bufio/deadline"
1515
E "github.com/sagernet/sing/common/exceptions"
1616
"github.com/sagernet/sing/common/logger"
1717
M "github.com/sagernet/sing/common/metadata"
@@ -38,7 +38,8 @@ type TLSTransport struct {
3838

3939
type tlsDNSConn struct {
4040
tls.Conn
41-
queryId uint16
41+
queryId uint16
42+
needDeadlineClose bool
4243
}
4344

4445
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
@@ -104,7 +105,10 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
104105
if err != nil {
105106
return nil, E.Cause(err, "dial TLS connection")
106107
}
107-
return &tlsDNSConn{Conn: tlsConn}, nil
108+
return &tlsDNSConn{
109+
Conn: tlsConn,
110+
needDeadlineClose: deadline.NeedAdditionalReadDeadline(tlsConn.NetConn()),
111+
}, nil
108112
})
109113
if err != nil {
110114
return nil, err
@@ -125,9 +129,7 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
125129
}
126130

127131
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
128-
if deadline, ok := ctx.Deadline(); ok {
129-
conn.SetDeadline(deadline)
130-
}
132+
defer setConnDeadline(ctx, conn, conn.needDeadlineClose)()
131133
conn.queryId++
132134
err := WriteMessage(conn, conn.queryId, message)
133135
if err != nil {
@@ -137,6 +139,5 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl
137139
if err != nil {
138140
return nil, E.Cause(err, "read response")
139141
}
140-
conn.SetDeadline(time.Time{})
141142
return response, nil
142143
}

dns/transport/udp.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/sagernet/sing-box/log"
1414
"github.com/sagernet/sing-box/option"
1515
"github.com/sagernet/sing/common/buf"
16+
"github.com/sagernet/sing/common/bufio/deadline"
1617
E "github.com/sagernet/sing/common/exceptions"
1718
"github.com/sagernet/sing/common/logger"
1819
M "github.com/sagernet/sing/common/metadata"
@@ -130,6 +131,7 @@ func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDN
130131
return nil, E.Cause(err, "dial TCP connection")
131132
}
132133
defer conn.Close()
134+
defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))()
133135
err = WriteMessage(conn, message.Id, message)
134136
if err != nil {
135137
return nil, E.Cause(err, "write request")

0 commit comments

Comments
 (0)