Skip to content

Commit 2d71bd2

Browse files
TheoyorTheodor Rauchjkralik
authored
FIX: response source address on [::]-bound servers are wrong (#629)
* fix: on binding the server to [::] the response happened to come from different addresses than sent to. * fix: add .idea to gitignore * feat: store source addr in conn * fix: did not swap * fix: remove vscode folder from gitignore * fix: ignore mcast adresses * fix: modify signatures in NewConn in tests * fix: data race condition * fix: don't use dummy adresses in test * fix: missed nil check * fix: don't upsert if nothing to do * fix: prevent against race conditions and dangling pointers * fix: comment * fix for golangci-lint * fix(udp): avoid multicast conn key leaks and stale control state * fix the two gocritic warnings --------- Co-authored-by: Theodor Rauch <theodor.rauch@ml-pa.com> Co-authored-by: Jozef Kralik <jozef.kralik@plgd.dev>
1 parent cb7cf24 commit 2d71bd2

7 files changed

Lines changed: 195 additions & 20 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ client
1111
!client/
1212
vendor/
1313
v3/
14-
14+
.idea/
1515
# Test binary, build with `go test -c`
1616
*.test
1717

.vscode/launch.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "0.2.0",
3+
"configurations": [
4+
{
5+
"name": "Run simple server",
6+
"type": "go",
7+
"request": "launch",
8+
"mode": "auto",
9+
"program": "${workspaceFolder}/examples/observe/server"
10+
}
11+
]
12+
}

udp/client/conn.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ type Conn struct {
206206
msgID atomic.Uint32
207207
blockwiseSZX blockwise.SZX
208208

209+
localAddr atomic.Pointer[net.IP]
210+
interfaceIndex atomic.Int64
211+
209212
/*
210213
An outstanding interaction is either a CON for which an ACK has not
211214
yet been received but is still expected (message layer) or a request
@@ -550,6 +553,8 @@ func (cc *Conn) writeMessageAsync(req *pool.Message) error {
550553
func (cc *Conn) writeMessage(req *pool.Message) error {
551554
req.UpsertType(message.Confirmable)
552555
req.UpsertMessageID(cc.GetMessageID())
556+
cc.upsertControlInformation(req)
557+
553558
if req.Type() != message.Confirmable {
554559
return cc.writeMessageAsync(req)
555560
}
@@ -755,6 +760,7 @@ func (cc *Conn) processResponse(reqType message.Type, reqMessageID int32, w *res
755760
w.Message().SetType(message.Acknowledgement)
756761
w.Message().SetMessageID(reqMessageID)
757762
w.Message().SetToken(nil)
763+
758764
err := cc.addResponseToCache(w.Message())
759765
if err != nil {
760766
return fmt.Errorf("cannot cache response: %w", err)
@@ -824,7 +830,8 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
824830
}()
825831
resp := cc.AcquireMessage(cc.Context())
826832
resp.SetToken(req.Token())
827-
ifIndex := req.ControlMessage().GetIfIndex()
833+
cc.setControlInformation(req.ControlMessage())
834+
828835
w := responsewriter.New(resp, cc, req.Options()...)
829836
defer func() {
830837
cc.ReleaseMessage(w.Message())
@@ -839,7 +846,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
839846
// nothing to send
840847
return
841848
}
842-
upsertInterfaceToMessage(w.Message(), ifIndex)
849+
cc.upsertControlInformation(w.Message())
843850
errW := cc.writeMessageAsync(w.Message())
844851
if errW != nil {
845852
cc.closeConnection()
@@ -851,13 +858,37 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess
851858
cc.sendPong(w, r)
852859
}
853860

854-
func upsertInterfaceToMessage(m *pool.Message, ifIndex int) {
855-
if ifIndex >= 1 {
856-
cm := coapNet.ControlMessage{
857-
IfIndex: ifIndex,
858-
}
859-
m.UpsertControlMessage(&cm)
861+
func (cc *Conn) setControlInformation(cm *coapNet.ControlMessage) {
862+
if cm == nil {
863+
cc.interfaceIndex.Store(0)
864+
cc.localAddr.Store(nil)
865+
return
866+
}
867+
868+
cc.interfaceIndex.Store(int64(cm.GetIfIndex()))
869+
if len(cm.Dst) == 0 || cm.Dst.IsMulticast() {
870+
cc.localAddr.Store(nil)
871+
return
860872
}
873+
874+
dst := make(net.IP, len(cm.Dst))
875+
copy(dst, cm.Dst)
876+
cc.localAddr.Store(&dst)
877+
}
878+
879+
func (cc *Conn) upsertControlInformation(msg *pool.Message) {
880+
ifIndex := int(cc.interfaceIndex.Load())
881+
localAddrPtr := cc.localAddr.Load()
882+
if ifIndex < 1 && localAddrPtr == nil {
883+
return
884+
}
885+
886+
var localAddr net.IP
887+
if localAddrPtr != nil {
888+
localAddr = *localAddrPtr
889+
}
890+
891+
msg.UpsertControlMessage(&coapNet.ControlMessage{IfIndex: ifIndex, Src: localAddr})
861892
}
862893

863894
func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
@@ -872,7 +903,6 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
872903
elem.ReleaseMessage(cc)
873904
resp := cc.AcquireMessage(cc.Context())
874905
resp.SetToken(r.Token())
875-
upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex())
876906
w := responsewriter.New(resp, cc, r.Options()...)
877907
defer func() {
878908
cc.ReleaseMessage(w.Message())

udp/client/controlmessage_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package client
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
coapNet "github.com/plgd-dev/go-coap/v3/net"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestSetControlInformationNil(t *testing.T) {
12+
var cc Conn
13+
addr := net.ParseIP("192.0.2.1").To4()
14+
cc.localAddr.Store(&addr)
15+
cc.interfaceIndex.Store(7)
16+
17+
cc.setControlInformation(nil)
18+
19+
require.Equal(t, int64(0), cc.interfaceIndex.Load())
20+
require.Nil(t, cc.localAddr.Load())
21+
}
22+
23+
func TestSetControlInformationUnicastCopiesAddress(t *testing.T) {
24+
var cc Conn
25+
dst := net.ParseIP("2001:db8::1")
26+
cm := &coapNet.ControlMessage{Dst: append(net.IP(nil), dst...), IfIndex: 9}
27+
28+
cc.setControlInformation(cm)
29+
30+
require.Equal(t, int64(9), cc.interfaceIndex.Load())
31+
stored := cc.localAddr.Load()
32+
require.NotNil(t, stored)
33+
expected := append(net.IP(nil), dst...)
34+
require.True(t, stored.Equal(expected))
35+
36+
cm.Dst[0] ^= 0xff
37+
require.True(t, stored.Equal(expected))
38+
}
39+
40+
func TestSetControlInformationMulticastClearsAddress(t *testing.T) {
41+
var cc Conn
42+
addr := net.ParseIP("192.0.2.1").To4()
43+
cc.localAddr.Store(&addr)
44+
45+
cc.setControlInformation(&coapNet.ControlMessage{Dst: net.ParseIP("ff02::1"), IfIndex: 11})
46+
47+
require.Equal(t, int64(11), cc.interfaceIndex.Load())
48+
require.Nil(t, cc.localAddr.Load())
49+
}

udp/server/server.go

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,19 @@ func (s *Server) Serve(l *coapNet.UDPConn) error {
159159
}
160160
}
161161
buf = buf[:n]
162-
cc, err := s.getConn(l, raddr, true)
162+
163+
// UDPConn.LocalAddr() only takes into account the address it is bound to.
164+
// In the case of a wildcard address, the actual destination address is in the control message.
165+
// On server-initiated exchanges, listener's LocalAddr can be used as the client has no assumptions of the source.
166+
laddr, err := s.getListenerLocalAddr(l)
167+
if err != nil {
168+
return err
169+
}
170+
if cm != nil && len(cm.Dst) > 0 && !cm.Dst.IsMulticast() {
171+
laddr.IP = cm.Dst
172+
}
173+
174+
cc, err := s.getConn(l, raddr, laddr, true)
163175
if err != nil {
164176
s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err))
165177
continue
@@ -178,6 +190,15 @@ func (s *Server) getListener() *coapNet.UDPConn {
178190
return s.listen
179191
}
180192

193+
func (s *Server) getListenerLocalAddr(l *coapNet.UDPConn) (*net.UDPAddr, error) {
194+
localAddr, ok := l.LocalAddr().(*net.UDPAddr)
195+
if !ok || localAddr == nil {
196+
return nil, fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr())
197+
}
198+
laddrVal := *localAddr
199+
return &laddrVal, nil
200+
}
201+
181202
// Stop stops server without wait of ends Serve function.
182203
func (s *Server) Stop() {
183204
s.cancel()
@@ -254,10 +275,21 @@ func getClose(cc *client.Conn) func() {
254275
return closeFn
255276
}
256277

257-
func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) {
278+
func getConnKey(raddr *net.UDPAddr, laddr *net.UDPAddr) string {
279+
normalizedLocalAddr := *laddr
280+
if len(normalizedLocalAddr.IP) > 0 && normalizedLocalAddr.IP.IsMulticast() {
281+
// Multicast destination address does not identify a unique server-side source address.
282+
// Normalize it to avoid creating one conn key per multicast group.
283+
normalizedLocalAddr.IP = nil
284+
normalizedLocalAddr.Zone = ""
285+
}
286+
return raddr.String() + "-" + normalizedLocalAddr.String()
287+
}
288+
289+
func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) {
258290
s.connsMutex.Lock()
259291
defer s.connsMutex.Unlock()
260-
key := raddr.String()
292+
key := getConnKey(raddr, laddr)
261293
cc = s.conns[key]
262294

263295
if cc != nil {
@@ -345,8 +377,19 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (
345377
return cc, true
346378
}
347379

348-
func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
349-
cc, created := s.getOrCreateConn(l, raddr)
380+
func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
381+
if raddr == nil {
382+
return nil, errors.New("invalid remote address")
383+
}
384+
if laddr == nil {
385+
var err error
386+
laddr, err = s.getListenerLocalAddr(l)
387+
if err != nil {
388+
return nil, err
389+
}
390+
}
391+
392+
cc, created := s.getOrCreateConn(l, raddr, laddr)
350393
if created {
351394
if s.cfg.OnNewConn != nil {
352395
s.cfg.OnNewConn(cc)
@@ -367,18 +410,30 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool)
367410
closeFn()
368411
}
369412
if firstTime {
370-
return s.getConn(l, raddr, false)
413+
return s.getConn(l, raddr, laddr, false)
371414
}
372415
return nil, errors.New("connection is closed")
373416
}
374417
return cc, nil
375418
}
376419

377-
func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) {
420+
// NewConn creates or gets a connection for the provided remote address.
421+
//
422+
// Optional laddr may be used to pin a concrete local address when the listener is bound to a wildcard address.
423+
// If laddr is omitted or nil, listener's local address is used.
424+
func (s *Server) NewConn(addr *net.UDPAddr, laddr ...*net.UDPAddr) (*client.Conn, error) {
425+
if len(laddr) > 1 {
426+
return nil, fmt.Errorf("invalid number of local addresses: %d", len(laddr))
427+
}
428+
var localAddr *net.UDPAddr
429+
if len(laddr) == 1 {
430+
localAddr = laddr[0]
431+
}
432+
378433
l := s.getListener()
379434
if l == nil {
380435
// server is not started/stopped
381436
return nil, errors.New("server is not running")
382437
}
383-
return s.getConn(l, addr, true)
438+
return s.getConn(l, addr, localAddr, true)
384439
}

udp/server/server_key_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package server
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestGetConnKeyIgnoresMulticastLocalAddress(t *testing.T) {
11+
raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830}
12+
mcastV6 := &net.UDPAddr{IP: net.ParseIP("ff02::fd"), Port: 5683}
13+
mcastV4 := &net.UDPAddr{IP: net.ParseIP("224.0.1.187"), Port: 5683}
14+
normalized := &net.UDPAddr{Port: 5683}
15+
16+
require.Equal(t, getConnKey(raddr, mcastV6), getConnKey(raddr, normalized))
17+
require.Equal(t, getConnKey(raddr, mcastV4), getConnKey(raddr, normalized))
18+
}
19+
20+
func TestGetConnKeyKeepsUnicastLocalAddressDistinct(t *testing.T) {
21+
raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830}
22+
laddrA := &net.UDPAddr{IP: net.ParseIP("2001:db8::10"), Port: 5683}
23+
laddrB := &net.UDPAddr{IP: net.ParseIP("2001:db8::11"), Port: 5683}
24+
25+
require.NotEqual(t, getConnKey(raddr, laddrA), getConnKey(raddr, laddrB))
26+
}

udp/server_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,9 @@ func TestServerNewClient(t *testing.T) {
467467

468468
time.Sleep(time.Second)
469469

470+
_, err = s1.NewConn(nil)
471+
require.ErrorContains(t, err, "invalid remote address")
472+
470473
cc, err := s1.NewConn(peer)
471474
require.NoError(t, err)
472475

@@ -478,7 +481,7 @@ func TestServerNewClient(t *testing.T) {
478481
require.NoError(t, err)
479482

480483
// repeat ping - new client should be created
481-
cc, err = s1.NewConn(peer)
484+
cc, err = s1.NewConn(peer, nil)
482485
require.NoError(t, err)
483486
err = cc.Ping(ctx)
484487
require.NoError(t, err)
@@ -626,7 +629,7 @@ func TestServerReconnectNewClient(t *testing.T) {
626629
}
627630

628631
// new client
629-
cc, err = s1.NewConn(peer)
632+
cc, err = s1.NewConn(peer, nil)
630633
require.NoError(t, err)
631634
ctx, cancel = context.WithTimeout(context.Background(), time.Second*1)
632635
defer cancel()

0 commit comments

Comments
 (0)