2424package dialers
2525
2626import (
27+ "context"
2728 "io"
2829 "net"
30+ "net/netip"
2931 "sync"
3032 "sync/atomic"
3133 "syscall"
@@ -44,15 +46,22 @@ func (zeroNetAddr) String() string { return "none" }
4446
// maxRetryCount is the limit canRetryLocked compares retryCount against when multidial is off.
4547const maxRetryCount = 3
4648
49+ // ippPins maintains a limited-time mapping from remote ip:port addresses (netip.AddrPort) to the ID of the dialer pinned to that address; presumably entries expire after desync_cache_ttl (confirm against core.NewSieve).
50+ // TODO: invalidate cache on network changes.
51+ // TODO: with context.TODO, expmap's reaper goroutine will leak. NOTE(review): the code constructs a core.Sieve, not an expmap; presumably Sieve is backed by expmap, confirm.
52+ var ippPins = core .NewSieve [netip.AddrPort , string ](context .TODO (), desync_cache_ttl )
53+
4754// retrier implements the DuplexConn interface and must
4855// be typecastable to *net.TCPConn (see: xdial.DialTCP)
4956// inheritance: go.dev/play/p/mMiQgXsPM7Y
5057type retrier struct {
51- dialers []protect.RDialer
52- dialerOpts settings.DialerOpts
53- multidial bool
54- raddr net.Addr
55- laddr net.Addr // laddr may be nil; TCPAddr.IP may be nil.
58+ dialers []protect.RDialer
59+ dialerOpts settings.DialerOpts
60+ nextDialerIdx int
61+ multidial bool
62+
63+ raddr net.Addr
64+ laddr net.Addr // laddr may be nil; TCPAddr.IP may be nil.
5665
5766 // Flags indicating whether the caller has called CloseRead and CloseWrite.
5867 readDone atomic.Bool
@@ -80,9 +89,8 @@ type retrier struct {
8089 // and is cleared when the first byte is received.
8190 tee []byte
8291 // retryErr is set to the error from the last retry, if any.
83- retryErr error
84- retryCount uint8
85- dialerCount int
92+ retryErr error
93+ retryCount uint8
8694 // Flag indicating when retry is finished or unnecessary.
8795 retryDoneCh chan struct {} // always unbuffered
8896}
@@ -116,7 +124,7 @@ func (r *retrier) retryCompleted() bool {
116124
117125func (r * retrier ) canRetryLocked () bool {
118126 if r .multidial {
119- return r .dialerCount < len (r .dialers )
127+ return r .nextDialerIdx < len (r .dialers )
120128 } else {
121129 return r .retryCount < maxRetryCount
122130 }
@@ -162,9 +170,27 @@ func dialerOptsForRace() settings.DialerOpts {
162170 }
163171}
164172
173+ func reprioritize (ds []protect.RDialer , ipp netip.AddrPort ) []protect.RDialer {
174+ // reprioritize the dialers based on the IP:port pair
175+ if ! ipp .IsValid () {
176+ return ds
177+ }
178+ id , ok := ippPins .Get (ipp )
179+ if ! ok || len (id ) <= 0 {
180+ return ds
181+ }
182+ for i , d := range ds {
183+ if d .ID () == id {
184+ ds [i ], ds [0 ] = ds [0 ], ds [i ]
185+ break
186+ }
187+ }
188+ return ds
189+ }
190+
165191func DialAny (ds []protect.RDialer , laddr , raddr net.Addr ) (* retrier , error ) {
166192 r := & retrier {
167- dialers : ds ,
193+ dialers : reprioritize ( ds , asAddrPort ( raddr )) ,
168194 dialerOpts : dialerOptsForRace (),
169195 multidial : true ,
170196 laddr : laddr , // may be nil
@@ -274,7 +300,6 @@ func (r *retrier) dialLocked() (c core.DuplexConn, err error) {
274300 begin := time .Now ()
275301 c , err = r .doDialLocked (strat )
276302 rtt := time .Since (begin )
277- r .dialerCount ++
278303
279304 r .conn = c // c may be nil
280305 r .timeout = calcTimeout (rtt )
@@ -290,7 +315,13 @@ func (r *retrier) dialLocked() (c core.DuplexConn, err error) {
290315func (r * retrier ) doDialLocked (dialStrat int32 ) (_ core.DuplexConn , err error ) {
291316 var conn * net.TCPConn
292317
293- di := r .dialerCount % len (r .dialers )
318+ di := r .nextDialerIdx
319+ if r .multidial {
320+ if di >= len (r .dialers ) {
321+ return nil , errNoDialer
322+ }
323+ }
324+ r .nextDialerIdx = di + 1
294325
295326 // r.raddr may be nil or laddr.IP may be nil.
296327 switch dialStrat {
@@ -395,7 +426,7 @@ func (r *retrier) Read(buf []byte) (n int, err error) {
395426 err = core .UniqErr (err , retryerr )
396427 }
397428 logeor (retryerr , log .I )("retrier: read# %d + (mult? %t / c: %d): [%s<=%s] %d; err? %v" ,
398- r .retryCount , r .multidial , r .dialerCount , laddr (c ), r .raddr , n , retryerr )
429+ r .retryCount , r .multidial , r .nextDialerIdx , laddr (c ), r .raddr , n , retryerr )
399430 }
400431 if c != nil && core .IsNotNil (c ) {
401432 _ = c .SetReadDeadline (r .readDeadline )
@@ -404,7 +435,8 @@ func (r *retrier) Read(buf []byte) (n int, err error) {
404435 r .tee = nil // discard teed data
405436 return
406437 }
407- logeor (err , note )("retrier: read: already retried! [%s<=%s] %d; err? %v" , laddr (c ), r .raddr , n , err )
438+ logeor (err , note )("retrier: read: already retried! [%s<=%s] %s; err? %v" ,
439+ laddr (c ), r .raddr , n , err )
408440 } // else: just one read is enough; no retry needed
409441 return
410442}
@@ -514,11 +546,24 @@ func (r *retrier) ReadFrom(reader io.Reader) (bytes int64, err error) {
514546 return bytes , io .ErrUnexpectedEOF
515547 }
516548
549+ pinned := false
550+ pinnedID := ""
551+ if r .multidial {
552+ if ipp := asAddrPort (r .raddr ); ipp .IsValid () {
553+ // cache the dialer ID for the IP:port pair
554+ di := max (0 , r .nextDialerIdx - 1 ) % len (r .dialers )
555+ pinnedID = r .dialers [di ].ID ()
556+ ippPins .Put (ipp , pinnedID )
557+ pinned = true
558+ }
559+ }
560+
517561 var b int64
518562 b , err = c .ReadFrom (reader )
519563 bytes += b
520564
521- logeif (err )("retrier: readfrom: done; sz: %d; err: %v" , bytes , err )
565+ logeif (err )("retrier: readfrom: done (id: %s, pinned? %t); sz: %d; err: %v" ,
566+ pinnedID , pinned , bytes , err )
522567 return
523568}
524569
0 commit comments