@@ -21,6 +21,7 @@ import (
2121 "os/signal"
2222 "strconv"
2323 "strings"
24+ "sync"
2425 "time"
2526
2627 "github.com/mendersoftware/go-lib-micro/ws"
@@ -94,7 +95,8 @@ type PortForwardCmd struct {
9495 sessionID string
9596 bindingHost string
9697 portMappings []portMapping
97- recvChans map [string ]chan * ws.ProtoMsg
98+ recvChanMu sync.RWMutex
99+ recvChans map [string ]chan <- * ws.ProtoMsg
98100 running bool
99101 stop chan struct {}
100102 err error
@@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro
184186 deviceID : args [0 ],
185187 bindingHost : bindingHost ,
186188 portMappings : portMappings ,
187- recvChans : make (map [string ]chan * ws.ProtoMsg ),
189+ recvChans : make (map [string ]chan <- * ws.ProtoMsg ),
188190 stop : make (chan struct {}),
189191 }, nil
190192}
@@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error {
198200 }
199201}
200202
203+ func (c * PortForwardCmd ) registerRecvChan (connectionID string , recvChan chan <- * ws.ProtoMsg ) {
204+ c .recvChanMu .Lock ()
205+ defer c .recvChanMu .Unlock ()
206+ c .recvChans [connectionID ] = recvChan
207+ }
208+
201209func (c * PortForwardCmd ) run () error {
202210 ctx , cancelContext := context .WithCancel (context .Background ())
203211 defer cancelContext ()
@@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error {
239247 if err != nil {
240248 return err
241249 }
242- go forwarder .Run (ctx , c .sessionID , msgChan , c .recvChans )
250+ go forwarder .Run (ctx , c .sessionID , msgChan , c .registerRecvChan )
243251 case protocolUDP :
244252 forwarder , err := NewUDPPortForwarder (c .bindingHost , portMapping .LocalPort ,
245253 portMapping .RemoteHost , portMapping .RemotePort )
246254 if err != nil {
247255 return err
248256 }
249- go forwarder .Run (ctx , c .sessionID , msgChan , c .recvChans )
257+ go forwarder .Run (ctx , c .sessionID , msgChan , c .registerRecvChan )
250258 default :
251259 return errors .New ("unknown protocol: " + portMapping .Protocol )
252260 }
@@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages(
411419 m .Header .MsgType == wspf .MessageTypePortForwardStop ) {
412420 connectionID , _ := m .Header .Properties [wspf .PropertyConnectionID ].(string )
413421 if connectionID != "" {
422+ c .recvChanMu .RLock ()
414423 if recvChan , ok := c .recvChans [connectionID ]; ok {
415424 recvChan <- m
416425 }
426+ c .recvChanMu .RUnlock ()
417427 }
418428 }
419429 }
0 commit comments