Skip to content

Commit 12d9d05

Browse files
committed
drpc: fix concurrent large message corruption in multiplexed streams
With stream multiplexing, multiple streams write concurrently to a shared buffer. The stream's rawWriteLocked used to split large messages into multiple frames (via SplitData) and call WriteFrame for each chunk. Each WriteFrame call acquires the shared mutex independently, so frames from different streams can interleave in the buffer.

The reader on the other side doesn't handle interleaved frames from different streams mid-packet. When it sees a frame from a different stream, it resets the partial packet, silently corrupting data for messages larger than SplitSize.

The fix changes the StreamWriter interface from WriteFrame(Frame) to WritePacket(Packet). The stream hands off the full message data in a single call, and the writer serializes it atomically under one mutex hold. rawWriteLocked no longer splits messages into frames, so there is nothing to interleave.

Splitting may have been useful before multiplexing, but the manageWriter goroutine already batches all pending data from the shared buffer into a single transport write, so splitting at the stream level now adds no value. If we ever need to limit per-write size, that belongs in the writer implementation, not in the stream's rawWrite path.
1 parent ffc17bf commit 12d9d05

6 files changed

Lines changed: 155 additions & 75 deletions

File tree

drpcmanager/frame_queue_test.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,19 @@ import (
77
"testing"
88

99
"github.com/zeebo/assert"
10-
1110
"storj.io/drpc/drpcwire"
1211
)
1312

1413
func TestSharedWriteBuf_AppendDrain(t *testing.T) {
1514
sw := newSharedWriteBuf()
1615

17-
fr := drpcwire.Frame{
16+
pkt := drpcwire.Packet{
1817
Data: []byte("hello"),
1918
ID: drpcwire.ID{Stream: 1, Message: 2},
2019
Kind: drpcwire.KindMessage,
21-
Done: true,
2220
}
2321

24-
assert.NoError(t, sw.Append(fr))
22+
assert.NoError(t, sw.Append(pkt))
2523

2624
// Drain should return serialized bytes.
2725
data := sw.Drain(nil)
@@ -31,11 +29,11 @@ func TestSharedWriteBuf_AppendDrain(t *testing.T) {
3129
_, got, ok, err := drpcwire.ParseFrame(data)
3230
assert.NoError(t, err)
3331
assert.That(t, ok)
34-
assert.DeepEqual(t, got.Data, fr.Data)
35-
assert.Equal(t, got.ID.Stream, fr.ID.Stream)
36-
assert.Equal(t, got.ID.Message, fr.ID.Message)
37-
assert.Equal(t, got.Kind, fr.Kind)
38-
assert.Equal(t, got.Done, fr.Done)
32+
assert.DeepEqual(t, got.Data, pkt.Data)
33+
assert.Equal(t, got.ID.Stream, pkt.ID.Stream)
34+
assert.Equal(t, got.ID.Message, pkt.ID.Message)
35+
assert.Equal(t, got.Kind, pkt.Kind)
36+
assert.Equal(t, got.Done, true)
3937
}
4038

4139
func TestSharedWriteBuf_CloseIdempotent(t *testing.T) {
@@ -48,7 +46,7 @@ func TestSharedWriteBuf_AppendAfterClose(t *testing.T) {
4846
sw := newSharedWriteBuf()
4947
sw.Close()
5048

51-
err := sw.Append(drpcwire.Frame{})
49+
err := sw.Append(drpcwire.Packet{})
5250
assert.Error(t, err)
5351
}
5452

@@ -64,7 +62,7 @@ func TestSharedWriteBuf_WaitAndDrainBlocks(t *testing.T) {
6462
}()
6563

6664
// Append should wake the blocked WaitAndDrain.
67-
assert.NoError(t, sw.Append(drpcwire.Frame{Data: []byte("a")}))
65+
assert.NoError(t, sw.Append(drpcwire.Packet{Data: []byte("a")}))
6866
<-done
6967
}
7068

drpcmanager/manager_test.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/zeebo/assert"
1616
grpcmetadata "google.golang.org/grpc/metadata"
1717
"storj.io/drpc/drpcmetadata"
18-
18+
"storj.io/drpc/drpcstream"
1919
"storj.io/drpc/drpctest"
2020
"storj.io/drpc/drpcwire"
2121
)
@@ -245,6 +245,86 @@ func TestNewServerStreamUnreadMessageDoesNotBlockOtherStreams(t *testing.T) {
245245
defer func() { _ = srvStream2.Close() }()
246246
}
247247

248+
// TestConcurrentLargeMessages verifies that two streams writing messages larger
249+
// than SplitSize concurrently do not corrupt each other's data. It is a
250+
// regression test: rawWriteLocked used to split messages into multiple frames and
251+
// append each one to the shared write buffer independently, so frames from
252+
// different streams could interleave, and the reader resets a partial packet when
253+
// it sees a frame from a different stream, silently corrupting data.
254+
func TestConcurrentLargeMessages(t *testing.T) {
255+
ctx := drpctest.NewTracker(t)
256+
defer ctx.Close()
257+
258+
cconn, sconn := net.Pipe()
259+
defer func() { _ = cconn.Close() }()
260+
defer func() { _ = sconn.Close() }()
261+
262+
// Use a small SplitSize so that any frame splitting on the write path would
263+
// cut even these small messages into several frames, exposing interleaving.
264+
streamOpts := drpcstream.Options{SplitSize: 5}
265+
266+
cman := NewWithOptions(cconn, Options{Stream: streamOpts})
267+
defer func() { _ = cman.Close() }()
268+
269+
sman := NewWithOptions(sconn, Options{Stream: streamOpts})
270+
defer func() { _ = sman.Close() }()
271+
272+
// Create two client streams and send invoke + message concurrently.
273+
stream1, err := cman.NewClientStream(ctx, "rpc-1")
274+
assert.NoError(t, err)
275+
defer func() { _ = stream1.Close() }()
276+
277+
stream2, err := cman.NewClientStream(ctx, "rpc-2")
278+
assert.NoError(t, err)
279+
defer func() { _ = stream2.Close() }()
280+
281+
msg1 := []byte("AAAAAAAAAAAAAAAAAAAA") // 20 bytes, split into 4 frames of 5 bytes
282+
msg2 := []byte("BBBBBBBBBBBBBBBBBBBB") // 20 bytes, split into 4 frames of 5 bytes
283+
284+
// Send invokes first (these are small, no splitting).
285+
assert.NoError(t, stream1.RawWrite(drpcwire.KindInvoke, []byte("rpc-1")))
286+
assert.NoError(t, stream2.RawWrite(drpcwire.KindInvoke, []byte("rpc-2")))
287+
288+
// Accept both server streams before sending messages, so the streams are
289+
// registered and the reader can route packets.
290+
srvStream1, rpc1, err := sman.NewServerStream(ctx)
291+
assert.NoError(t, err)
292+
assert.Equal(t, "rpc-1", rpc1)
293+
defer func() { _ = srvStream1.Close() }()
294+
295+
srvStream2, rpc2, err := sman.NewServerStream(ctx)
296+
assert.NoError(t, err)
297+
assert.Equal(t, "rpc-2", rpc2)
298+
defer func() { _ = srvStream2.Close() }()
299+
300+
// Write messages concurrently from both streams. Each 20-byte message must
301+
// arrive intact: its bytes must not interleave with the other stream's.
302+
ready := make(chan struct{})
303+
var wg sync.WaitGroup
304+
wg.Add(2)
305+
go func() {
306+
defer wg.Done()
307+
<-ready
308+
assert.NoError(t, stream1.RawWrite(drpcwire.KindMessage, msg1))
309+
}()
310+
go func() {
311+
defer wg.Done()
312+
<-ready
313+
assert.NoError(t, stream2.RawWrite(drpcwire.KindMessage, msg2))
314+
}()
315+
close(ready)
316+
wg.Wait()
317+
318+
// Read from both server streams and verify correctness.
319+
got1, err := srvStream1.RawRecv()
320+
assert.NoError(t, err)
321+
assert.DeepEqual(t, got1, msg1)
322+
323+
got2, err := srvStream2.RawRecv()
324+
assert.NoError(t, err)
325+
assert.DeepEqual(t, got2, msg2)
326+
}
327+
248328
type blockingTransport chan struct{}
249329

250330
func (b blockingTransport) Read(p []byte) (n int, err error) { <-b; return 0, io.EOF }

drpcmanager/mux_writer.go

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,20 @@ import (
99
"storj.io/drpc/drpcwire"
1010
)
1111

12-
// muxWriter implements drpcwire.StreamWriter by serializing frame bytes into a
12+
// muxWriter implements drpcwire.StreamWriter by serializing packet bytes into a
1313
// shared write buffer. The manageWriter goroutine drains the buffer and writes
1414
// directly to the transport.
1515
//
16-
// Compared to the previous frameQueue approach, this avoids:
17-
// - copying frame payload into an intermediate queue slot,
18-
// - drpcwire.Writer mutex overhead in the writer goroutine.
19-
//
20-
// Frames are serialized (via AppendFrame) into the shared buffer under a
21-
// short-held mutex. The frame's Data slice is consumed before WriteFrame
22-
// returns, so callers may safely reuse their buffers afterward.
16+
// The entire packet is serialized as a single frame (via AppendFrame) under one
17+
// mutex hold, so no packet is ever split by another stream's frames on the wire.
18+
// The packet's Data slice is consumed (copied) before WritePacket returns, so
19+
// callers may safely reuse their buffers afterward.
2320
type muxWriter struct {
2421
sw *sharedWriteBuf
2522
}
2623

27-
func (w *muxWriter) WriteFrame(fr drpcwire.Frame) error {
28-
return w.sw.Append(fr)
24+
func (w *muxWriter) WritePacket(pkt drpcwire.Packet) error {
25+
return w.sw.Append(pkt)
2926
}
3027

3128
// Flush is a no-op because the manageWriter goroutine flushes to the
@@ -50,15 +47,21 @@ func newSharedWriteBuf() *sharedWriteBuf {
5047
return sw
5148
}
5249

53-
// Append serializes fr into the shared buffer. The frame's Data slice is
54-
// consumed (copied by AppendFrame) before Append returns.
55-
func (sw *sharedWriteBuf) Append(fr drpcwire.Frame) error {
50+
// Append serializes pkt as a single frame into the shared buffer. The packet's
51+
// Data slice is consumed (copied by AppendFrame) before Append returns.
52+
func (sw *sharedWriteBuf) Append(pkt drpcwire.Packet) error {
5653
sw.mu.Lock()
5754
if sw.closed {
5855
sw.mu.Unlock()
5956
return managerClosed.New("enqueue")
6057
}
61-
sw.buf = drpcwire.AppendFrame(sw.buf, fr)
58+
sw.buf = drpcwire.AppendFrame(sw.buf, drpcwire.Frame{
59+
Data: pkt.Data,
60+
ID: pkt.ID,
61+
Kind: pkt.Kind,
62+
Control: pkt.Control,
63+
Done: true,
64+
})
6265
sw.mu.Unlock()
6366

6467
sw.cond.Signal()

drpcmanager/mux_writer_test.go

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,46 @@ import (
77
"testing"
88

99
"github.com/zeebo/assert"
10-
1110
"storj.io/drpc/drpcwire"
1211
)
1312

14-
func TestMuxWriter_WriteFrame(t *testing.T) {
13+
func TestMuxWriter_WritePacket(t *testing.T) {
1514
sw := newSharedWriteBuf()
1615
w := &muxWriter{sw: sw}
1716

18-
fr := drpcwire.Frame{
17+
pkt := drpcwire.Packet{
1918
Data: []byte("hello"),
2019
ID: drpcwire.ID{Stream: 1, Message: 2},
2120
Kind: drpcwire.KindMessage,
22-
Done: true,
2321
}
2422

25-
assert.NoError(t, w.WriteFrame(fr))
23+
assert.NoError(t, w.WritePacket(pkt))
2624

2725
data := sw.Drain(nil)
2826
_, got, ok, err := drpcwire.ParseFrame(data)
2927
assert.NoError(t, err)
3028
assert.That(t, ok)
31-
assert.DeepEqual(t, got.Data, fr.Data)
32-
assert.Equal(t, got.ID.Stream, fr.ID.Stream)
33-
assert.Equal(t, got.ID.Message, fr.ID.Message)
34-
assert.Equal(t, got.Kind, fr.Kind)
35-
assert.Equal(t, got.Done, fr.Done)
29+
assert.DeepEqual(t, got.Data, pkt.Data)
30+
assert.Equal(t, got.ID.Stream, pkt.ID.Stream)
31+
assert.Equal(t, got.ID.Message, pkt.ID.Message)
32+
assert.Equal(t, got.Kind, pkt.Kind)
33+
assert.Equal(t, got.Done, true)
3634
}
3735

38-
func TestMuxWriter_WriteFrameIsolatesData(t *testing.T) {
36+
func TestMuxWriter_WritePacketIsolatesData(t *testing.T) {
3937
sw := newSharedWriteBuf()
4038
w := &muxWriter{sw: sw}
4139

4240
data := []byte("hello")
43-
fr := drpcwire.Frame{
41+
pkt := drpcwire.Packet{
4442
Data: data,
4543
ID: drpcwire.ID{Stream: 1, Message: 2},
4644
Kind: drpcwire.KindMessage,
47-
Done: true,
4845
}
4946

50-
assert.NoError(t, w.WriteFrame(fr))
47+
assert.NoError(t, w.WritePacket(pkt))
5148

52-
// Mutate the original source buffer after WriteFrame.
49+
// Mutate the original source buffer after WritePacket.
5350
data[0] = 'j'
5451

5552
// The serialized data in the shared buffer should be unaffected because
@@ -73,11 +70,11 @@ func TestMuxWriter_Empty(t *testing.T) {
7370
assert.That(t, w.Empty())
7471
}
7572

76-
func TestMuxWriter_WriteFrameAfterClose(t *testing.T) {
73+
func TestMuxWriter_WritePacketAfterClose(t *testing.T) {
7774
sw := newSharedWriteBuf()
7875
w := &muxWriter{sw: sw}
7976
sw.Close()
8077

81-
err := w.WriteFrame(drpcwire.Frame{})
78+
err := w.WritePacket(drpcwire.Packet{})
8279
assert.Error(t, err)
8380
}

drpcstream/stream.go

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"sync"
1212

1313
"github.com/zeebo/errs"
14-
1514
"storj.io/drpc"
1615
"storj.io/drpc/drpcctx"
1716
"storj.io/drpc/drpcdebug"
@@ -80,7 +79,9 @@ func New(ctx context.Context, sid uint64, wr drpcwire.StreamWriter) *Stream {
8079
// stream id and will use the writer to write messages on. It is important use
8180
// monotonically increasing stream ids within a single transport. The options
8281
// are used to control details of how the Stream operates.
83-
func NewWithOptions(ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options) *Stream {
82+
func NewWithOptions(
83+
ctx context.Context, sid uint64, wr drpcwire.StreamWriter, opts Options,
84+
) *Stream {
8485
var task *trace.Task
8586
if trace.IsEnabled() {
8687
kind, rpc := drpcopts.GetStreamKind(&opts.Internal), drpcopts.GetStreamRPC(&opts.Internal)
@@ -312,26 +313,28 @@ func (s *Stream) checkCancelError(err error) error {
312313
return err
313314
}
314315

315-
// newFrameLocked bumps the internal message id and returns a frame. It must be
316+
// nextID bumps the internal message id and returns the new ID. It must be
316317
// called under a mutex.
317-
func (s *Stream) newFrameLocked(kind drpcwire.Kind) drpcwire.Frame {
318+
func (s *Stream) nextID() drpcwire.ID {
318319
s.id.Message++
319-
return drpcwire.Frame{ID: s.id, Kind: kind}
320+
return s.id
320321
}
321322

322323
// sendPacketLocked sends the packet in a single write and flushes. It does not
323324
// check for any conditions to stop it from writing and is meant for internal
324325
// stream use to do things like signal errors or closes to the remote side.
325326
func (s *Stream) sendPacketLocked(kind drpcwire.Kind, control bool, data []byte) (err error) {
326-
fr := s.newFrameLocked(kind)
327-
fr.Data = data
328-
fr.Control = control
329-
fr.Done = true
327+
pkt := drpcwire.Packet{
328+
ID: s.nextID(),
329+
Kind: kind,
330+
Data: data,
331+
Control: control,
332+
}
330333

331334
drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data)))
332-
s.log("SEND", fr.String)
335+
s.log("SEND", pkt.String)
333336

334-
if err := s.wr.WriteFrame(fr); err != nil {
337+
if err := s.wr.WritePacket(pkt); err != nil {
335338
return errs.Wrap(err)
336339
}
337340
if err := s.wr.Flush(); err != nil {
@@ -374,29 +377,26 @@ func (s *Stream) RawWrite(kind drpcwire.Kind, data []byte) (err error) {
374377
// rawWriteLocked does the body of RawWrite assuming the caller is holding the
375378
// appropriate locks.
376379
func (s *Stream) rawWriteLocked(kind drpcwire.Kind, data []byte) (err error) {
377-
fr := s.newFrameLocked(kind)
378-
n := s.opts.SplitSize
379-
380-
for {
381-
switch {
382-
case s.sigs.send.IsSet():
383-
return s.sigs.send.Err()
384-
case s.sigs.term.IsSet():
385-
return s.sigs.term.Err()
386-
}
380+
switch {
381+
case s.sigs.send.IsSet():
382+
return s.sigs.send.Err()
383+
case s.sigs.term.IsSet():
384+
return s.sigs.term.Err()
385+
}
387386

388-
fr.Data, data = drpcwire.SplitData(data, n)
389-
fr.Done = len(data) == 0
387+
pkt := drpcwire.Packet{
388+
ID: s.nextID(),
389+
Kind: kind,
390+
Data: data,
391+
}
390392

391-
drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(fr.Data)))
392-
s.log("SEND", fr.String)
393+
drpcopts.GetStreamStats(&s.opts.Internal).AddWritten(uint64(len(data)))
394+
s.log("SEND", pkt.String)
393395

394-
if err := s.wr.WriteFrame(fr); err != nil {
395-
return s.checkCancelError(errs.Wrap(err))
396-
} else if fr.Done {
397-
return nil
398-
}
396+
if err := s.wr.WritePacket(pkt); err != nil {
397+
return s.checkCancelError(errs.Wrap(err))
399398
}
399+
return nil
400400
}
401401

402402
// RawFlush flushes any buffers of data.

0 commit comments

Comments
 (0)