@@ -159,7 +159,19 @@ func (s *Server) Serve(l *coapNet.UDPConn) error {
159159 }
160160 }
161161 buf = buf [:n ]
162- cc , err := s .getConn (l , raddr , true )
162+
163+ // UDPConn.LocalAddr() only takes into account the address it is bound to.
164+ // In the case of a wildcard address, the actual destination address is in the control message.
165+ // On server-initiated exchanges, listener's LocalAddr can be used as the client has no assumptions of the source.
166+ laddr , err := s .getListenerLocalAddr (l )
167+ if err != nil {
168+ return err
169+ }
170+ if cm != nil && len (cm .Dst ) > 0 && ! cm .Dst .IsMulticast () {
171+ laddr .IP = cm .Dst
172+ }
173+
174+ cc , err := s .getConn (l , raddr , laddr , true )
163175 if err != nil {
164176 s .cfg .Errors (fmt .Errorf ("%v: cannot get client connection: %w" , raddr , err ))
165177 continue
@@ -178,6 +190,15 @@ func (s *Server) getListener() *coapNet.UDPConn {
178190 return s .listen
179191}
180192
193+ func (s * Server ) getListenerLocalAddr (l * coapNet.UDPConn ) (* net.UDPAddr , error ) {
194+ localAddr , ok := l .LocalAddr ().(* net.UDPAddr )
195+ if ! ok || localAddr == nil {
196+ return nil , fmt .Errorf ("unexpected listener local addr type: %T" , l .LocalAddr ())
197+ }
198+ laddrVal := * localAddr
199+ return & laddrVal , nil
200+ }
201+
181202// Stop stops server without wait of ends Serve function.
182203func (s * Server ) Stop () {
183204 s .cancel ()
@@ -254,10 +275,21 @@ func getClose(cc *client.Conn) func() {
254275 return closeFn
255276}
256277
257- func (s * Server ) getOrCreateConn (udpConn * coapNet.UDPConn , raddr * net.UDPAddr ) (cc * client.Conn , created bool ) {
278+ func getConnKey (raddr * net.UDPAddr , laddr * net.UDPAddr ) string {
279+ normalizedLocalAddr := * laddr
280+ if len (normalizedLocalAddr .IP ) > 0 && normalizedLocalAddr .IP .IsMulticast () {
281+ // Multicast destination address does not identify a unique server-side source address.
282+ // Normalize it to avoid creating one conn key per multicast group.
283+ normalizedLocalAddr .IP = nil
284+ normalizedLocalAddr .Zone = ""
285+ }
286+ return raddr .String () + "-" + normalizedLocalAddr .String ()
287+ }
288+
289+ func (s * Server ) getOrCreateConn (udpConn * coapNet.UDPConn , raddr * net.UDPAddr , laddr * net.UDPAddr ) (cc * client.Conn , created bool ) {
258290 s .connsMutex .Lock ()
259291 defer s .connsMutex .Unlock ()
260- key := raddr . String ( )
292+ key := getConnKey ( raddr , laddr )
261293 cc = s .conns [key ]
262294
263295 if cc != nil {
@@ -345,8 +377,19 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (
345377 return cc , true
346378}
347379
348- func (s * Server ) getConn (l * coapNet.UDPConn , raddr * net.UDPAddr , firstTime bool ) (* client.Conn , error ) {
349- cc , created := s .getOrCreateConn (l , raddr )
380+ func (s * Server ) getConn (l * coapNet.UDPConn , raddr * net.UDPAddr , laddr * net.UDPAddr , firstTime bool ) (* client.Conn , error ) {
381+ if raddr == nil {
382+ return nil , errors .New ("invalid remote address" )
383+ }
384+ if laddr == nil {
385+ var err error
386+ laddr , err = s .getListenerLocalAddr (l )
387+ if err != nil {
388+ return nil , err
389+ }
390+ }
391+
392+ cc , created := s .getOrCreateConn (l , raddr , laddr )
350393 if created {
351394 if s .cfg .OnNewConn != nil {
352395 s .cfg .OnNewConn (cc )
@@ -367,18 +410,30 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool)
367410 closeFn ()
368411 }
369412 if firstTime {
370- return s .getConn (l , raddr , false )
413+ return s .getConn (l , raddr , laddr , false )
371414 }
372415 return nil , errors .New ("connection is closed" )
373416 }
374417 return cc , nil
375418}
376419
377- func (s * Server ) NewConn (addr * net.UDPAddr ) (* client.Conn , error ) {
420+ // NewConn creates or gets a connection for the provided remote address.
421+ //
422+ // Optional laddr may be used to pin a concrete local address when the listener is bound to a wildcard address.
423+ // If laddr is omitted or nil, listener's local address is used.
424+ func (s * Server ) NewConn (addr * net.UDPAddr , laddr ... * net.UDPAddr ) (* client.Conn , error ) {
425+ if len (laddr ) > 1 {
426+ return nil , fmt .Errorf ("invalid number of local addresses: %d" , len (laddr ))
427+ }
428+ var localAddr * net.UDPAddr
429+ if len (laddr ) == 1 {
430+ localAddr = laddr [0 ]
431+ }
432+
378433 l := s .getListener ()
379434 if l == nil {
380435 // server is not started/stopped
381436 return nil , errors .New ("server is not running" )
382437 }
383- return s .getConn (l , addr , true )
438+ return s .getConn (l , addr , localAddr , true )
384439}
0 commit comments