@@ -15,6 +15,8 @@ import (
1515 "net"
1616 "net/netip"
1717 "os"
18+ "regexp"
19+ "strconv"
1820 "strings"
1921 "syscall"
2022 "time"
@@ -246,6 +248,9 @@ var (
246248 errServerTemporarilyMisbehaving = errors .New ("server misbehaving" )
247249 errCanceled = errors .New ("operation was canceled" )
248250 errTimeout = errors .New ("i/o timeout" )
251+ errNumericPort = errors .New ("port must be numeric" )
252+ errNoSuitableAddress = errors .New ("no suitable address found" )
253+ errMissingAddress = errors .New ("missing address" )
249254)
250255
251256func (net * Net ) LookupHost (host string ) (addrs []string , err error ) {
@@ -688,3 +693,116 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
688693 }
689694 return saddrs , nil
690695}
696+
697+ func partialDeadline (now , deadline time.Time , addrsRemaining int ) (time.Time , error ) {
698+ if deadline .IsZero () {
699+ return deadline , nil
700+ }
701+ timeRemaining := deadline .Sub (now )
702+ if timeRemaining <= 0 {
703+ return time.Time {}, errTimeout
704+ }
705+ timeout := timeRemaining / time .Duration (addrsRemaining )
706+ const saneMinimum = 2 * time .Second
707+ if timeout < saneMinimum {
708+ if timeRemaining < saneMinimum {
709+ timeout = timeRemaining
710+ } else {
711+ timeout = saneMinimum
712+ }
713+ }
714+ return now .Add (timeout ), nil
715+ }
716+
717+ var protoSplitter = regexp .MustCompile (`^(tcp)(4|6)?$` )
718+
719+ func (tnet * Net ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
720+ if ctx == nil {
721+ panic ("nil context" )
722+ }
723+ var acceptV4 , acceptV6 bool
724+ matches := protoSplitter .FindStringSubmatch (network )
725+ if matches == nil {
726+ return nil , & net.OpError {Op : "dial" , Err : net .UnknownNetworkError (network )}
727+ } else if len (matches [2 ]) == 0 {
728+ acceptV4 = true
729+ acceptV6 = true
730+ } else {
731+ acceptV4 = matches [2 ][0 ] == '4'
732+ acceptV6 = ! acceptV4
733+ }
734+ var host string
735+ var port int
736+ var sport string
737+ var err error
738+ host , sport , err = net .SplitHostPort (address )
739+ if err != nil {
740+ return nil , & net.OpError {Op : "dial" , Err : err }
741+ }
742+ port , err = strconv .Atoi (sport )
743+ if err != nil || port < 0 || port > 65535 {
744+ return nil , & net.OpError {Op : "dial" , Err : errNumericPort }
745+ }
746+ allAddr , err := tnet .LookupContextHost (ctx , host )
747+ if err != nil {
748+ return nil , & net.OpError {Op : "dial" , Err : err }
749+ }
750+ var addrs []netip.AddrPort
751+ for _ , addr := range allAddr {
752+ ip , err := netip .ParseAddr (addr )
753+ if err == nil && ((ip .Is4 () && acceptV4 ) || (ip .Is6 () && acceptV6 )) {
754+ addrs = append (addrs , netip .AddrPortFrom (ip , uint16 (port )))
755+ }
756+ }
757+ if len (addrs ) == 0 && len (allAddr ) != 0 {
758+ return nil , & net.OpError {Op : "dial" , Err : errNoSuitableAddress }
759+ }
760+
761+ var firstErr error
762+ for i , addr := range addrs {
763+ select {
764+ case <- ctx .Done ():
765+ err := ctx .Err ()
766+ if err == context .Canceled {
767+ err = errCanceled
768+ } else if err == context .DeadlineExceeded {
769+ err = errTimeout
770+ }
771+ return nil , & net.OpError {Op : "dial" , Err : err }
772+ default :
773+ }
774+
775+ dialCtx := ctx
776+ if deadline , hasDeadline := ctx .Deadline (); hasDeadline {
777+ partialDeadline , err := partialDeadline (time .Now (), deadline , len (addrs )- i )
778+ if err != nil {
779+ if firstErr == nil {
780+ firstErr = & net.OpError {Op : "dial" , Err : err }
781+ }
782+ break
783+ }
784+ if partialDeadline .Before (deadline ) {
785+ var cancel context.CancelFunc
786+ dialCtx , cancel = context .WithDeadline (ctx , partialDeadline )
787+ defer cancel ()
788+ }
789+ }
790+
791+ var c net.Conn
792+ c , err = tnet .DialContextTCPAddrPort (dialCtx , addr )
793+ if err == nil {
794+ return c , nil
795+ }
796+ if firstErr == nil {
797+ firstErr = err
798+ }
799+ }
800+ if firstErr == nil {
801+ firstErr = & net.OpError {Op : "dial" , Err : errMissingAddress }
802+ }
803+ return nil , firstErr
804+ }
805+
806+ func (tnet * Net ) Dial (network , address string ) (net.Conn , error ) {
807+ return tnet .DialContext (context .Background (), network , address )
808+ }
0 commit comments