@@ -13,6 +13,7 @@ import (
1313
1414 "github.com/sagernet/sing-box/adapter"
1515 "github.com/sagernet/sing-box/common/dialer"
16+ "github.com/sagernet/sing-box/common/sniff"
1617 "github.com/sagernet/sing-box/common/tlsfragment"
1718 "github.com/sagernet/sing-box/common/tlsspoof"
1819 C "github.com/sagernet/sing-box/constant"
@@ -140,11 +141,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
140141 }
141142 remoteConn = spoofConn
142143 }
144+ serverFirst := sniff .Skip (& metadata )
143145 var done atomic.Bool
144- if m .kickWriteHandshake (ctx , conn , remoteConn , false , & done , onClose ) {
146+ if m .kickWriteHandshake (ctx , conn , remoteConn , serverFirst , false , & done , onClose ) {
145147 return
146148 }
147- if m .kickWriteHandshake (ctx , remoteConn , conn , true , & done , onClose ) {
149+ if m .kickWriteHandshake (ctx , remoteConn , conn , serverFirst , true , & done , onClose ) {
148150 return
149151 }
150152 go m .connectionCopy (ctx , conn , remoteConn , false , & done , onClose )
@@ -305,37 +307,43 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn,
305307 }
306308}
307309
308- func (m * ConnectionManager ) kickWriteHandshake (ctx context.Context , source net.Conn , destination net.Conn , direction bool , done * atomic.Bool , onClose N.CloseHandlerFunc ) bool {
310+ func (m * ConnectionManager ) kickWriteHandshake (ctx context.Context , source net.Conn , destination net.Conn , serverFirst bool , direction bool , done * atomic.Bool , onClose N.CloseHandlerFunc ) bool {
309311 if ! N .NeedHandshakeForWrite (destination ) {
310312 return false
311313 }
312314 var (
313- cachedBuffer * buf. Buffer
315+ err error
314316 wrotePayload bool
315317 )
316- sourceReader , readCounters := N .UnwrapCountReader (source , nil )
317- destinationWriter , writeCounters := N .UnwrapCountWriter (destination , nil )
318- if cachedReader , ok := sourceReader .(N.CachedReader ); ok {
319- cachedBuffer = cachedReader .ReadCached ()
320- }
321- var err error
322- if cachedBuffer != nil {
323- wrotePayload = true
324- dataLen := cachedBuffer .Len ()
325- _ , err = destinationWriter .Write (cachedBuffer .Bytes ())
326- cachedBuffer .Release ()
327- if err == nil {
328- for _ , counter := range readCounters {
329- counter (int64 (dataLen ))
330- }
331- for _ , counter := range writeCounters {
332- counter (int64 (dataLen ))
333- }
334- }
335- } else {
318+ if serverFirst {
336319 _ = destination .SetWriteDeadline (time .Now ().Add (C .ReadPayloadTimeout ))
337- _ , err = destinationWriter .Write (nil )
320+ _ , err = destination .Write (nil )
338321 _ = destination .SetWriteDeadline (time.Time {})
322+ } else {
323+ var cachedBuffer * buf.Buffer
324+ sourceReader , readCounters := N .UnwrapCountReader (source , nil )
325+ destinationWriter , writeCounters := N .UnwrapCountWriter (destination , nil )
326+ if cachedReader , ok := sourceReader .(N.CachedReader ); ok {
327+ cachedBuffer = cachedReader .ReadCached ()
328+ }
329+ if cachedBuffer != nil {
330+ wrotePayload = true
331+ dataLen := cachedBuffer .Len ()
332+ _ , err = destinationWriter .Write (cachedBuffer .Bytes ())
333+ cachedBuffer .Release ()
334+ if err == nil {
335+ for _ , counter := range readCounters {
336+ counter (int64 (dataLen ))
337+ }
338+ for _ , counter := range writeCounters {
339+ counter (int64 (dataLen ))
340+ }
341+ }
342+ } else {
343+ _ = destination .SetWriteDeadline (time .Now ().Add (C .ReadPayloadTimeout ))
344+ _ , err = destinationWriter .Write (nil )
345+ _ = destination .SetWriteDeadline (time.Time {})
346+ }
339347 }
340348 if err == nil {
341349 return false
0 commit comments