Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,23 @@ linters:
- godox
path: .*.go
text: replace with standard maps package
# Package names that intentionally shadow stdlib for API clarity
- linters:
- revive
path: ^pkg/errors/.*\.go
text: var-naming
- linters:
- revive
path: ^pkg/math/.*\.go
text: var-naming
- linters:
- revive
path: ^test/net/.*\.go
text: var-naming
- linters:
- revive
path: ^net/.*\.go
text: var-naming
# Which file paths to exclude: they will be analyzed, but issues from them won't be reported.
# "/" will be replaced by the current OS file path separator to properly work on Windows.
# Default: []
Expand Down
11 changes: 7 additions & 4 deletions message/codes/codes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package codes
import (
"encoding/json"
"strconv"
"strings"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -16,14 +17,16 @@ func TestJSONUnmarshal(t *testing.T) {
require.NoError(t, err)
require.Equal(t, want, got)

inNumeric := "["
var sb strings.Builder
sb.WriteString("[")
for i, c := range want {
if i > 0 {
inNumeric += ","
sb.WriteString(",")
}
inNumeric += strconv.FormatUint(uint64(c), 10)
sb.WriteString(strconv.FormatUint(uint64(c), 10))
}
inNumeric += "]"
sb.WriteString("]")
inNumeric := sb.String()
err = json.Unmarshal([]byte(inNumeric), &got)
require.NoError(t, err)
require.Equal(t, want, got)
Expand Down
1 change: 1 addition & 0 deletions udp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func Client(conn *net.UDPConn, opts ...Option) *client.Conn {
context.Background(),
l,
addr,
nil,
cfg.MaxMessageSize,
cfg.MTU,
cfg.CloseSocket,
Expand Down
20 changes: 14 additions & 6 deletions udp/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,13 @@ func (s *Server) Serve(l *coapNet.UDPConn) error {
}
}
buf = buf[:n]
cc, err := s.getConn(l, raddr, true)
// Extract the original destination IP from the control message
// This is the IP that the client sent to (public IP on Fly.io)
var originalDstIP net.IP
if cm != nil && cm.Dst != nil {
originalDstIP = cm.Dst
}
cc, err := s.getConn(l, raddr, originalDstIP, true)
if err != nil {
s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err))
continue
Expand Down Expand Up @@ -254,7 +260,7 @@ func getClose(cc *client.Conn) func() {
return closeFn
}

func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) {
func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, originalDstIP net.IP) (cc *client.Conn, created bool) {
s.connsMutex.Lock()
defer s.connsMutex.Unlock()
key := raddr.String()
Expand Down Expand Up @@ -295,6 +301,7 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (
s.doneCtx,
udpConn,
raddr,
originalDstIP,
s.cfg.MaxMessageSize,
s.cfg.MTU,
false,
Expand Down Expand Up @@ -345,8 +352,8 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (
return cc, true
}

func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
cc, created := s.getOrCreateConn(l, raddr)
func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, originalDstIP net.IP, firstTime bool) (*client.Conn, error) {
cc, created := s.getOrCreateConn(l, raddr, originalDstIP)
if created {
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
Expand All @@ -367,7 +374,7 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool)
closeFn()
}
if firstTime {
return s.getConn(l, raddr, false)
return s.getConn(l, raddr, originalDstIP, false)
}
return nil, errors.New("connection is closed")
}
Expand All @@ -380,5 +387,6 @@ func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) {
// server is not started/stopped
return nil, errors.New("server is not running")
}
return s.getConn(l, addr, true)
// NewConn is used for outbound connections, so we don't have an original destination IP
return s.getConn(l, addr, nil, true)
}
23 changes: 20 additions & 3 deletions udp/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ type Session struct {
connection *coapNet.UDPConn
doneCancel context.CancelFunc

cancel context.CancelFunc
raddr *net.UDPAddr
cancel context.CancelFunc
raddr *net.UDPAddr
originalDstIP net.IP // Stores the original destination IP from received packets (public IP on Fly.io)

mutex sync.Mutex
maxMessageSize uint32
Expand All @@ -39,6 +40,7 @@ func NewSession(
doneCtx context.Context,
connection *coapNet.UDPConn,
raddr *net.UDPAddr,
originalDstIP net.IP,
maxMessageSize uint32,
mtu uint16,
closeSocket bool,
Expand All @@ -50,6 +52,7 @@ func NewSession(
cancel: cancel,
connection: connection,
raddr: raddr,
originalDstIP: originalDstIP,
maxMessageSize: maxMessageSize,
mtu: mtu,
closeSocket: closeSocket,
Expand Down Expand Up @@ -109,7 +112,21 @@ func (s *Session) WriteMessage(req *pool.Message) error {
if err != nil {
return fmt.Errorf("cannot marshal: %w", err)
}
return s.connection.WriteWithOptions(data, coapNet.WithContext(req.Context()), coapNet.WithRemoteAddr(s.raddr), coapNet.WithControlMessage(req.ControlMessage()))

// Get or create the control message with the correct source address
cm := req.ControlMessage()
if cm == nil {
cm = &coapNet.ControlMessage{}
}
// Set the source address to the original destination IP (public IP that client sent to)
// This ensures responses are sent from the same IP the client sent to, which is critical
// for environments like Fly.io where packets arrive at a public IP but the socket is
// bound to an internal private address.
if s.originalDstIP != nil && cm.Src == nil {
cm.Src = s.originalDstIP
}

return s.connection.WriteWithOptions(data, coapNet.WithContext(req.Context()), coapNet.WithRemoteAddr(s.raddr), coapNet.WithControlMessage(cm))
}

// WriteMulticastMessage sends multicast to the remote multicast address.
Expand Down
Loading