@@ -30,22 +30,62 @@ import (
3030 "errors"
3131 "net"
3232 "net/netip"
33+ "sync"
3334 "time"
3435
3536 x "github.com/celzero/firestack/intra/backend"
36- "github.com/celzero/firestack/intra/dnsx"
37- "github.com/celzero/firestack/intra/log"
38- "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
39-
4037 "github.com/celzero/firestack/intra/core"
38+ "github.com/celzero/firestack/intra/dnsx"
4139 "github.com/celzero/firestack/intra/ipn"
40+ "github.com/celzero/firestack/intra/log"
4241 "github.com/celzero/firestack/intra/netstack"
4342 "github.com/celzero/firestack/intra/protect"
4443 "github.com/celzero/firestack/intra/settings"
44+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
4545)
4646
4747type tcpHandler struct {
4848 * baseHandler
49+ nat * tcpNat
50+ }
51+
52+ type tcpNat struct {
53+ sync.Mutex
54+ m map [string ]map [netip.AddrPort ]netip.AddrPort // proxyID => src => ext
55+ }
56+
57+ func newTCPNat () * tcpNat {
58+ return & tcpNat {m : make (map [string ]map [netip.AddrPort ]netip.AddrPort )}
59+ }
60+
61+ func (t * tcpNat ) assoc (pid string , src , ext netip.AddrPort ) {
62+ if t == nil || len (pid ) == 0 || ! sameFamily (src .Addr (), ext .Addr ()) {
63+ return
64+ }
65+ t .Lock ()
66+ defer t .Unlock ()
67+
68+ m := t .m [pid ]
69+ if m == nil {
70+ m = make (map [netip.AddrPort ]netip.AddrPort )
71+ t .m [pid ] = m
72+ }
73+ m [src ] = ext
74+ }
75+
76+ func (t * tcpNat ) lookup (pid string , src netip.AddrPort ) (zz netip.AddrPort , ok bool ) {
77+ if t == nil || len (pid ) == 0 || ! src .IsValid () {
78+ return
79+ }
80+ t .Lock ()
81+ defer t .Unlock ()
82+
83+ if m := t .m [pid ]; m != nil {
84+ if ext , ok := m [src ]; ok {
85+ return ext , true
86+ }
87+ }
88+ return
4989}
5090
5191type ioinfo struct {
@@ -77,6 +117,7 @@ func NewTCPHandler(pctx context.Context, resolver dnsx.Resolver, prox ipn.ProxyP
77117
78118 h := & tcpHandler {
79119 baseHandler : newBaseHandler (pctx , dnsx .NetTypeTCP , resolver , prox , listener ),
120+ nat : newTCPNat (),
80121 }
81122
82123 go h .processSummaries ()
@@ -164,6 +205,38 @@ func (h *tcpHandler) handshakeIfNeededOrClose(gconn *netstack.GTCPConn, smm *Soc
164205 return allow , nil
165206}
166207
208+ func (h * tcpHandler ) natAssoc (pid string , src netip.AddrPort , addr net.Addr ) {
209+ if ext := netAddrPort (addr ); ext .IsValid () {
210+ h .nat .assoc (pid , src , ext )
211+ }
212+ }
213+
214+ func (h * tcpHandler ) natLookup (pid string , src , target netip.AddrPort ) (zz netip.AddrPort ) {
215+ if ext , ok := h .nat .lookup (pid , src ); ok && sameFamily (ext .Addr (), target .Addr ()) {
216+ return ext
217+ }
218+ return
219+ }
220+
221+ func netAddrPort (addr net.Addr ) (zz netip.AddrPort ) {
222+ if addr == nil {
223+ return
224+ }
225+ if v , ok := addr .(* net.TCPAddr ); ok {
226+ return v .AddrPort ()
227+ } else if ap , err := netip .ParseAddrPort (addr .String ()); err == nil {
228+ return ap
229+ }
230+ return
231+ }
232+
233+ func sameFamily (a , b netip.Addr ) bool {
234+ if ! a .IsValid () || ! b .IsValid () {
235+ return false
236+ }
237+ return a .Is4 () == b .Is4 ()
238+ }
239+
167240// Proxy implements netstack.GTCPConnHandler
168241// It must be called from a goroutine.
169242func (h * tcpHandler ) Proxy (gconn * netstack.GTCPConn , src , target netip.AddrPort ) (open bool ) {
@@ -250,7 +323,6 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
250323 }
251324
252325 cont := true
253- boundSrc := makeAnyAddrPort (src )
254326 // pick all realips to connect to
255327 for i , dstipp := range actualTargets {
256328 // dstipp may be v4 or v6 regardless of target addr
@@ -270,7 +342,7 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
270342 continue
271343 }
272344
273- if cont , err = h .handle (px , gconn , boundSrc , dstipp , delayForHappyEyeballs , smm ); err == nil {
345+ if cont , err = h .handle (px , gconn , src , dstipp , delayForHappyEyeballs , smm ); err == nil {
274346 return allow // smm instead queued by handle() => forward()
275347 } else {
276348 end := time .Since (smm .start )
@@ -292,15 +364,26 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
292364}
293365
294366// handle connects to the target via the proxy, and pipes data between the src, target; thread-safe.
295- func (h * tcpHandler ) handle (px ipn.Proxy , gconn * netstack.GTCPConn , boundSrc , target netip.AddrPort , errOnNoRoute bool , smm * SocketSummary ) (cont bool , err error ) {
367+ func (h * tcpHandler ) handle (px ipn.Proxy , gconn * netstack.GTCPConn , src , target netip.AddrPort , errOnNoRoute bool , smm * SocketSummary ) (cont bool , err error ) {
296368 cont = true
297369 stop := ! cont
298370 targetstr := target .String ()
299371
300372 if errOnNoRoute {
301373 if canroute := px .Router ().Contains (x .StrOf (targetstr )); ! canroute {
302374 // make sure to not delay in HappyEyeballs scenario?
303- return cont , log .WE ("proxy(%s) has no route to %s" , pidstr (px ), targetstr )
375+ return cont , log .WE ("proxy(%s) has no route to %s (<= %s)" , pidstr (px ), targetstr , src )
376+ }
377+ }
378+
379+ bindAddr := makeAnyAddrPort (src )
380+ eim := settings .EndpointIndependentMapping .Load ()
381+ portfwd := settings .PortForward .Load ()
382+ maybeDialBind := eim || portfwd
383+
384+ if eim {
385+ if nataddr := h .natLookup (pidstr (px ), src , target ); nataddr .IsValid () {
386+ bindAddr = nataddr
304387 }
305388 }
306389
@@ -310,16 +393,20 @@ func (h *tcpHandler) handle(px ipn.Proxy, gconn *netstack.GTCPConn, boundSrc, ta
310393 start := time .Now ()
311394
312395 if settings .Debug {
313- log .VV ("tcp: %s dial %s: attempt: %s [%s] => %s for %s" ,
314- smm .ID , pidstr (px ), gconn .LocalAddr (), boundSrc , targetstr , smm .UID )
396+ log .VV ("tcp: %s dial %s: attempt: %s [%s [%s] ] => %s for %s" ,
397+ smm .ID , pidstr (px ), src , gconn .LocalAddr (), bindAddr , targetstr , smm .UID )
315398 }
316399
317400 // github.com/google/gvisor/blob/5ba35f516b5c2/test/benchmarks/tcp/tcp_proxy.go#L359
318401 // ref: stackoverflow.com/questions/63656117
319402 // ref: stackoverflow.com/questions/40328025
320- if settings .PortForward .Load () {
321- pc , err = px .Dialer ().DialBind ("tcp" , boundSrc .String (), targetstr )
322- } else {
403+ if maybeDialBind {
404+ pc , err = px .Dialer ().DialBind ("tcp" , bindAddr .String (), targetstr )
405+ maybeDialBind = err == nil
406+ logwif (! maybeDialBind )("tcp: %s dialbind ok? %t (%s [%s] => %s via %s); err? %v" ,
407+ smm .ID , maybeDialBind , src , bindAddr , targetstr , pidstr (px ), err )
408+ }
409+ if ! maybeDialBind {
323410 pc , err = px .Dialer ().Dial ("tcp" , targetstr )
324411 }
325412 if err == nil {
@@ -349,8 +436,8 @@ func (h *tcpHandler) handle(px ipn.Proxy, gconn *netstack.GTCPConn, boundSrc, ta
349436
350437 if err != nil {
351438 clos (pc )
352- log .W ("tcp: err dialing %s proxy(%s) to dst(%v ) for %s: %v" ,
353- smm .ID , smm .PID , smm .Target , smm .UID , err )
439+ log .W ("tcp: err dialing %s proxy(%s) %v [%v] => %v (bind? %t ) for %s: %v" ,
440+ smm .ID , smm .PID , src , bindAddr , smm .Target , maybeDialBind , smm .UID , err )
354441 return cont , err
355442 }
356443
@@ -362,6 +449,10 @@ func (h *tcpHandler) handle(px ipn.Proxy, gconn *netstack.GTCPConn, boundSrc, ta
362449 core .Go ("tcp.forward." + smm .ID , func () {
363450 h .listener .PostFlow (smm .postMark ())
364451 h .forward (gconn , rwext {dst , tcptimeout }, smm ) // src always *gonet.TCPConn
452+ // TODO assoc if forward was successful
453+ if eim {
454+ h .natAssoc (smm .PID , src , dst .LocalAddr ())
455+ }
365456 })
366457 return cont , nil // handled; takes ownership of src
367458}
0 commit comments