Skip to content

Commit 18f1056

Browse files
committed
Skip kickWriteHandshake for server first protocols
1 parent e0c137e commit 18f1056

1 file changed

Lines changed: 33 additions & 25 deletions

File tree

route/conn.go

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
"github.com/sagernet/sing-box/adapter"
1515
"github.com/sagernet/sing-box/common/dialer"
16+
"github.com/sagernet/sing-box/common/sniff"
1617
"github.com/sagernet/sing-box/common/tlsfragment"
1718
"github.com/sagernet/sing-box/common/tlsspoof"
1819
C "github.com/sagernet/sing-box/constant"
@@ -140,11 +141,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
140141
}
141142
remoteConn = spoofConn
142143
}
144+
serverFirst := sniff.Skip(&metadata)
143145
var done atomic.Bool
144-
if m.kickWriteHandshake(ctx, conn, remoteConn, false, &done, onClose) {
146+
if m.kickWriteHandshake(ctx, conn, remoteConn, serverFirst, false, &done, onClose) {
145147
return
146148
}
147-
if m.kickWriteHandshake(ctx, remoteConn, conn, true, &done, onClose) {
149+
if m.kickWriteHandshake(ctx, remoteConn, conn, serverFirst, true, &done, onClose) {
148150
return
149151
}
150152
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
@@ -305,37 +307,43 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn,
305307
}
306308
}
307309

308-
func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool {
310+
func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, serverFirst bool, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool {
309311
if !N.NeedHandshakeForWrite(destination) {
310312
return false
311313
}
312314
var (
313-
cachedBuffer *buf.Buffer
315+
err error
314316
wrotePayload bool
315317
)
316-
sourceReader, readCounters := N.UnwrapCountReader(source, nil)
317-
destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil)
318-
if cachedReader, ok := sourceReader.(N.CachedReader); ok {
319-
cachedBuffer = cachedReader.ReadCached()
320-
}
321-
var err error
322-
if cachedBuffer != nil {
323-
wrotePayload = true
324-
dataLen := cachedBuffer.Len()
325-
_, err = destinationWriter.Write(cachedBuffer.Bytes())
326-
cachedBuffer.Release()
327-
if err == nil {
328-
for _, counter := range readCounters {
329-
counter(int64(dataLen))
330-
}
331-
for _, counter := range writeCounters {
332-
counter(int64(dataLen))
333-
}
334-
}
335-
} else {
318+
if serverFirst {
336319
_ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout))
337-
_, err = destinationWriter.Write(nil)
320+
_, err = destination.Write(nil)
338321
_ = destination.SetWriteDeadline(time.Time{})
322+
} else {
323+
var cachedBuffer *buf.Buffer
324+
sourceReader, readCounters := N.UnwrapCountReader(source, nil)
325+
destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil)
326+
if cachedReader, ok := sourceReader.(N.CachedReader); ok {
327+
cachedBuffer = cachedReader.ReadCached()
328+
}
329+
if cachedBuffer != nil {
330+
wrotePayload = true
331+
dataLen := cachedBuffer.Len()
332+
_, err = destinationWriter.Write(cachedBuffer.Bytes())
333+
cachedBuffer.Release()
334+
if err == nil {
335+
for _, counter := range readCounters {
336+
counter(int64(dataLen))
337+
}
338+
for _, counter := range writeCounters {
339+
counter(int64(dataLen))
340+
}
341+
}
342+
} else {
343+
_ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout))
344+
_, err = destinationWriter.Write(nil)
345+
_ = destination.SetWriteDeadline(time.Time{})
346+
}
339347
}
340348
if err == nil {
341349
return false

0 commit comments

Comments
 (0)