diff --git a/.golangci.yml b/.golangci.yml index 5182e3e9..fe8290af 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -143,6 +143,23 @@ linters: - godox path: .*.go text: replace with standard maps package + # Package names that intentionally shadow stdlib for API clarity + - linters: + - revive + path: ^pkg/errors/.*\.go + text: var-naming + - linters: + - revive + path: ^pkg/math/.*\.go + text: var-naming + - linters: + - revive + path: ^test/net/.*\.go + text: var-naming + - linters: + - revive + path: ^net/.*\.go + text: var-naming # Which file paths to exclude: they will be analyzed, but issues from them won't be reported. # "/" will be replaced by the current OS file path separator to properly work on Windows. # Default: [] diff --git a/message/codes/codes_test.go b/message/codes/codes_test.go index ce469fde..eec1f43f 100644 --- a/message/codes/codes_test.go +++ b/message/codes/codes_test.go @@ -3,6 +3,7 @@ package codes import ( "encoding/json" "strconv" + "strings" "testing" "github.com/stretchr/testify/require" @@ -16,14 +17,16 @@ func TestJSONUnmarshal(t *testing.T) { require.NoError(t, err) require.Equal(t, want, got) - inNumeric := "[" + var sb strings.Builder + sb.WriteString("[") for i, c := range want { if i > 0 { - inNumeric += "," + sb.WriteString(",") } - inNumeric += strconv.FormatUint(uint64(c), 10) + sb.WriteString(strconv.FormatUint(uint64(c), 10)) } - inNumeric += "]" + sb.WriteString("]") + inNumeric := sb.String() err = json.Unmarshal([]byte(inNumeric), &got) require.NoError(t, err) require.Equal(t, want, got) diff --git a/udp/client.go b/udp/client.go index 4a68b358..82cb1373 100644 --- a/udp/client.go +++ b/udp/client.go @@ -91,6 +91,7 @@ func Client(conn *net.UDPConn, opts ...Option) *client.Conn { context.Background(), l, addr, + nil, cfg.MaxMessageSize, cfg.MTU, cfg.CloseSocket, diff --git a/udp/server/server.go b/udp/server/server.go index 1ec8100a..d344e96c 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -159,7 +159,13 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { } } buf = buf[:n] - cc, err := s.getConn(l, raddr, true) + // Extract the original destination IP from the control message + // This is the IP that the client sent to (public IP on Fly.io) + var originalDstIP net.IP + if cm != nil && cm.Dst != nil { + originalDstIP = cm.Dst + } + cc, err := s.getConn(l, raddr, originalDstIP, true) if err != nil { s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) continue @@ -254,7 +260,7 @@ func getClose(cc *client.Conn) func() { return closeFn } -func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) { +func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, originalDstIP net.IP) (cc *client.Conn, created bool) { s.connsMutex.Lock() defer s.connsMutex.Unlock() key := raddr.String() @@ -295,6 +301,7 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) ( s.doneCtx, udpConn, raddr, + originalDstIP, s.cfg.MaxMessageSize, s.cfg.MTU, false, @@ -345,8 +352,8 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) ( return cc, true } -func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { - cc, created := s.getOrCreateConn(l, raddr) +func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, originalDstIP net.IP, firstTime bool) (*client.Conn, error) { + cc, created := s.getOrCreateConn(l, raddr, originalDstIP) if created { if s.cfg.OnNewConn != nil { s.cfg.OnNewConn(cc) @@ -367,7 +374,7 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) closeFn() } if firstTime { - return s.getConn(l, raddr, false) + return s.getConn(l, raddr, originalDstIP, false) } return nil, errors.New("connection is closed") } @@ -380,5 +387,6 @@ func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) { // server is not started/stopped return nil, errors.New("server is not running") } - return s.getConn(l, addr, true) + // NewConn is used for outbound connections, so we don't have an original destination IP + return s.getConn(l, addr, nil, true) } diff --git a/udp/server/session.go b/udp/server/session.go index 99d3bf95..a3cf5e9c 100644 --- a/udp/server/session.go +++ b/udp/server/session.go @@ -24,8 +24,9 @@ type Session struct { connection *coapNet.UDPConn doneCancel context.CancelFunc - cancel context.CancelFunc - raddr *net.UDPAddr + cancel context.CancelFunc + raddr *net.UDPAddr + originalDstIP net.IP // Stores the original destination IP from received packets (public IP on Fly.io) mutex sync.Mutex maxMessageSize uint32 @@ -39,6 +40,7 @@ func NewSession( doneCtx context.Context, connection *coapNet.UDPConn, raddr *net.UDPAddr, + originalDstIP net.IP, maxMessageSize uint32, mtu uint16, closeSocket bool, @@ -50,6 +52,7 @@ func NewSession( cancel: cancel, connection: connection, raddr: raddr, + originalDstIP: originalDstIP, maxMessageSize: maxMessageSize, mtu: mtu, closeSocket: closeSocket, @@ -109,7 +112,21 @@ func (s *Session) WriteMessage(req *pool.Message) error { if err != nil { return fmt.Errorf("cannot marshal: %w", err) } - return s.connection.WriteWithOptions(data, coapNet.WithContext(req.Context()), coapNet.WithRemoteAddr(s.raddr), coapNet.WithControlMessage(req.ControlMessage())) + + // Get or create the control message with the correct source address + cm := req.ControlMessage() + if cm == nil { + cm = &coapNet.ControlMessage{} + } + // Set the source address to the original destination IP (public IP that client sent to) + // This ensures responses are sent from the same IP the client sent to, which is critical + // for environments like Fly.io where packets arrive at a public IP but the socket is + // bound to an internal private address. + if s.originalDstIP != nil && cm.Src == nil { + cm.Src = s.originalDstIP + } + + return s.connection.WriteWithOptions(data, coapNet.WithContext(req.Context()), coapNet.WithRemoteAddr(s.raddr), coapNet.WithControlMessage(cm)) } // WriteMulticastMessage sends multicast to the remote multicast address.