Skip to content

Commit 12c8ca6

Browse files
committed
fix: protect port forward receive chan map from concurrent read/write
Signed-off-by: Alf-Rune Siqveland <alf.rune@northern.tech>
1 parent abf50b7 commit 12c8ca6

3 files changed

Lines changed: 33 additions & 12 deletions

File tree

cmd/portforward.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"os/signal"
2222
"strconv"
2323
"strings"
24+
"sync"
2425
"time"
2526

2627
"github.com/mendersoftware/go-lib-micro/ws"
@@ -94,7 +95,8 @@ type PortForwardCmd struct {
9495
sessionID string
9596
bindingHost string
9697
portMappings []portMapping
97-
recvChans map[string]chan *ws.ProtoMsg
98+
recvChanMu sync.RWMutex
99+
recvChans map[string]chan<- *ws.ProtoMsg
98100
running bool
99101
stop chan struct{}
100102
err error
@@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro
184186
deviceID: args[0],
185187
bindingHost: bindingHost,
186188
portMappings: portMappings,
187-
recvChans: make(map[string]chan *ws.ProtoMsg),
189+
recvChans: make(map[string]chan<- *ws.ProtoMsg),
188190
stop: make(chan struct{}),
189191
}, nil
190192
}
@@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error {
198200
}
199201
}
200202

203+
func (c *PortForwardCmd) registerRecvChan(connectionID string, recvChan chan<- *ws.ProtoMsg) {
204+
c.recvChanMu.Lock()
205+
defer c.recvChanMu.Unlock()
206+
c.recvChans[connectionID] = recvChan
207+
}
208+
201209
func (c *PortForwardCmd) run() error {
202210
ctx, cancelContext := context.WithCancel(context.Background())
203211
defer cancelContext()
@@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error {
239247
if err != nil {
240248
return err
241249
}
242-
go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
250+
go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan)
243251
case protocolUDP:
244252
forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort,
245253
portMapping.RemoteHost, portMapping.RemotePort)
246254
if err != nil {
247255
return err
248256
}
249-
go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
257+
go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan)
250258
default:
251259
return errors.New("unknown protocol: " + portMapping.Protocol)
252260
}
@@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages(
411419
m.Header.MsgType == wspf.MessageTypePortForwardStop) {
412420
connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string)
413421
if connectionID != "" {
422+
c.recvChanMu.RLock()
414423
if recvChan, ok := c.recvChans[connectionID]; ok {
415424
recvChan <- m
416425
}
426+
c.recvChanMu.RUnlock()
417427
}
418428
}
419429
}

cmd/portforward_tcp.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (p *TCPPortForwarder) Run(
6060
ctx context.Context,
6161
sessionID string,
6262
msgChan chan *ws.ProtoMsg,
63-
recvChans map[string]chan *ws.ProtoMsg,
63+
registerRecvChan func(string, chan<- *ws.ProtoMsg),
6464
) {
6565
// listen for new connections
6666
defer p.listen.Close()
@@ -101,7 +101,7 @@ func (p *TCPPortForwarder) Run(
101101
connectionUUID, _ := uuid.NewUUID()
102102
connectionID := connectionUUID.String()
103103
recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize)
104-
recvChans[connectionID] = recvChan
104+
registerRecvChan(connectionID, recvChan)
105105
go p.handleRequest(ctx, conn, sessionID, connectionID, recvChan, msgChan)
106106
case <-ctx.Done():
107107
return
@@ -135,8 +135,15 @@ func (p *TCPPortForwarder) handleInboundMessages(
135135
}()
136136

137137
for {
138+
var (
139+
m *ws.ProtoMsg
140+
open bool
141+
)
138142
select {
139-
case m := <-recvChan:
143+
case m, open = <-recvChan:
144+
if !open {
145+
return
146+
}
140147
if m.Header.Proto == ws.ProtoTypePortForward &&
141148
m.Header.MsgType == wspf.MessageTypePortForwardStop {
142149
sendStopMessage = false
@@ -145,12 +152,13 @@ func (p *TCPPortForwarder) handleInboundMessages(
145152
m.Header.MsgType == wspf.MessageTypePortForward {
146153
_, err := conn.Write(m.Body)
147154
if err != nil {
148-
if errors.Unwrap(err) != net.ErrClosed {
155+
if !errors.Is(err, net.ErrClosed) {
149156
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
150157
}
158+
return
151159
} else {
152160
// send the ack
153-
m := &ws.ProtoMsg{
161+
m = &ws.ProtoMsg{
154162
Header: ws.ProtoHdr{
155163
Proto: ws.ProtoTypePortForward,
156164
MsgType: wspf.MessageTypePortForwardAck,
@@ -164,7 +172,10 @@ func (p *TCPPortForwarder) handleInboundMessages(
164172
}
165173
} else if m.Header.Proto == ws.ProtoTypePortForward &&
166174
m.Header.MsgType == wspf.MessageTypePortForwardAck {
167-
<-ackChan
175+
_, open = <-ackChan
176+
if !open {
177+
return
178+
}
168179
}
169180
case <-ctx.Done():
170181
return

cmd/portforward_udp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ func (p *UDPPortForwarder) Run(
7373
ctx context.Context,
7474
sessionID string,
7575
msgChan chan *ws.ProtoMsg,
76-
recvChans map[string]chan *ws.ProtoMsg,
76+
registerRecvChan func(string, chan<- *ws.ProtoMsg),
7777
) {
7878
// listen for new connections
7979
defer p.conn.Close()
8080

8181
connectionUUID, _ := uuid.NewUUID()
8282
connectionID := connectionUUID.String()
8383
recvChan := make(chan *ws.ProtoMsg, portForwardUDPChannelSize)
84-
recvChans[connectionID] = recvChan
84+
registerRecvChan(connectionID, recvChan)
8585

8686
protocol := portforward.PortForwardProtocol(wspf.PortForwardProtocolUDP)
8787
portforwardNew := &wspf.PortForwardNew{

0 commit comments

Comments
 (0)