Skip to content

Commit e16d021

Browse files
committed
drpcmanager: replace streamBuffer with streamRegistry
Replace the single-stream streamBuffer with a stream registry that maps stream IDs to stream objects. The registry currently holds at most one active stream (two briefly during handoff), but provides the foundation for stream multiplexing where callers will look up streams by ID directly.
1 parent f204e04 commit e16d021

File tree

4 files changed

+266
-73
lines changed

4 files changed

+266
-73
lines changed

drpcmanager/manager.go

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net"
1212
"strings"
1313
"sync"
14+
"sync/atomic"
1415
"syscall"
1516
"time"
1617

@@ -81,8 +82,12 @@ type Manager struct {
8182

8283
wg sync.WaitGroup // tracks active manageStream goroutines
8384

84-
sem drpcsignal.Chan // held by the active stream
85-
sbuf streamBuffer // largest stream id created
85+
// reg tracks active streams. Currently holds at most one active stream;
86+
// a second may briefly coexist during stream handoff (old stream's
87+
// Unregister races with new stream's Register).
88+
reg *streamRegistry
89+
90+
sem drpcsignal.Chan // held by the active stream
8691

8792
pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream
8893
invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream
@@ -123,9 +128,6 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
123128
invokes: make(chan invokeInfo),
124129
}
125130

126-
// initialize the stream buffer
127-
m.sbuf.init()
128-
129131
// this semaphore controls the number of concurrent streams. it MUST be 1.
130132
m.sem.Make(1)
131133

@@ -134,6 +136,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
134136
m.pdone.Make(1)
135137

136138
m.pa = drpcwire.NewPacketAssembler()
139+
m.reg = newStreamRegistry()
137140

138141
// set the internal stream options
139142
drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr)
@@ -186,7 +189,7 @@ func (m *Manager) acquireSemaphore(ctx context.Context) error {
186189
// longer make any reads or writes on the transport. It exits early if the
187190
// context is canceled or the manager is terminated.
188191
func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) {
189-
prev := m.sbuf.Get()
192+
prev := m.reg.GetLatest()
190193
if prev == nil {
191194
return nil
192195
}
@@ -217,7 +220,8 @@ func (m *Manager) terminate(err error) {
217220
if m.sigs.term.Set(err) {
218221
m.log("TERM", func() string { return fmt.Sprint(err) })
219222
m.sigs.tport.Set(m.tr.Close())
220-
m.sbuf.Close()
223+
m.reg.ForEach(func(s *drpcstream.Stream) { s.Close() })
224+
m.reg.Close()
221225
}
222226
}
223227

@@ -249,7 +253,7 @@ func (m *Manager) manageReader() {
249253
return
250254
}
251255

252-
switch curr := m.sbuf.Get(); {
256+
switch curr := m.reg.GetLatest(); {
253257
// If the frame is for the current stream, deliver it.
254258
case curr != nil && incomingFrame.ID.Stream == curr.ID():
255259
if err := curr.HandleFrame(incomingFrame); err != nil {
@@ -319,9 +323,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error {
319323
// Invoke packet completes the sequence. Send to NewServerStream.
320324
select {
321325
case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}:
322-
// Wait for NewServerStream to finish stream creation (including
323-
// sbuf.Set) before reading the next frame. This guarantees curr
324-
// is set for subsequent non-invoke packets.
326+
// Wait for NewServerStream to finish stream creation before reading the
327+
// next frame. This guarantees curr is set for subsequent non-invoke
328+
// packets.
325329
m.pdone.Recv()
326330

327331
m.pa.Reset()
@@ -346,10 +350,13 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
346350

347351
stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts)
348352

353+
if err := m.reg.Register(sid, stream); err != nil {
354+
return nil, err
355+
}
356+
349357
m.wg.Add(1)
350358
go m.manageStream(ctx, stream)
351359

352-
m.sbuf.Set(stream)
353360
m.log("STREAM", stream.String)
354361

355362
return stream, nil
@@ -359,6 +366,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
359366
// is finished, canceling the stream if the context is canceled.
360367
func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) {
361368
defer m.wg.Done()
369+
defer m.reg.Unregister(stream.ID())
362370
select {
363371
case <-m.sigs.term.Signal():
364372
err := m.sigs.term.Err()
@@ -429,7 +437,7 @@ func (m *Manager) Closed() <-chan struct{} {
429437
// the return result is only valid until the next call to NewClientStream or
430438
// NewServerStream.
431439
func (m *Manager) Unblocked() <-chan struct{} {
432-
if prev := m.sbuf.Get(); prev != nil {
440+
if prev := m.reg.GetLatest(); prev != nil {
433441
return prev.Context().Done()
434442
}
435443
return closedCh
@@ -506,9 +514,8 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea
506514
}
507515
}
508516
stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc)
509-
// Signal pdone only after stream registration so that
510-
// manageReader sees the new stream via sbuf.Get() when it reads
511-
// the next frame.
517+
// Signal pdone only after stream registration so that manageReader sees
518+
// the new stream in the registry when it reads the next frame.
512519
m.pdone.Send()
513520
return stream, rpc, err
514521
}

drpcmanager/registry.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright (C) 2026 Cockroach Labs.
2+
// See LICENSE for copying information.
3+
4+
package drpcmanager
5+
6+
import (
7+
"sync"
8+
9+
"storj.io/drpc/drpcstream"
10+
)
11+
12+
// streamRegistry is a thread-safe map of stream IDs to stream objects.
13+
// It is used by the Manager to track active streams for lifecycle management.
14+
type streamRegistry struct {
15+
mu sync.RWMutex
16+
streams map[uint64]*drpcstream.Stream
17+
closed bool
18+
}
19+
20+
func newStreamRegistry() *streamRegistry {
21+
return &streamRegistry{
22+
streams: make(map[uint64]*drpcstream.Stream),
23+
}
24+
}
25+
26+
// Register adds a stream to the registry. It returns an error if the registry
27+
// is closed or if a stream with the same ID is already registered.
28+
func (r *streamRegistry) Register(id uint64, stream *drpcstream.Stream) error {
29+
r.mu.Lock()
30+
defer r.mu.Unlock()
31+
32+
if stream == nil {
33+
return managerClosed.New("stream can't be nil")
34+
}
35+
36+
if r.closed {
37+
return managerClosed.New("register")
38+
}
39+
if _, ok := r.streams[id]; ok {
40+
return managerClosed.New("duplicate stream id")
41+
}
42+
r.streams[id] = stream
43+
return nil
44+
}
45+
46+
// Unregister removes a stream from the registry. It is a no-op if the stream
47+
// is not registered or if the registry has been closed.
48+
func (r *streamRegistry) Unregister(id uint64) {
49+
r.mu.Lock()
50+
defer r.mu.Unlock()
51+
52+
if r.streams != nil {
53+
delete(r.streams, id)
54+
}
55+
}
56+
57+
// Get returns the stream for the given ID and whether it was found.
58+
func (r *streamRegistry) Get(id uint64) (*drpcstream.Stream, bool) {
59+
r.mu.RLock()
60+
defer r.mu.RUnlock()
61+
62+
s, ok := r.streams[id]
63+
return s, ok
64+
}
65+
66+
// GetLatest returns the stream with the highest ID, or nil if the registry is
67+
// empty. It iterates the map because the registry may briefly hold two streams
68+
// during stream handoff. This method should be removed once multiplexing is
69+
// supported and callers look up streams by ID directly.
70+
func (r *streamRegistry) GetLatest() *drpcstream.Stream {
71+
r.mu.RLock()
72+
defer r.mu.RUnlock()
73+
74+
var latest *drpcstream.Stream
75+
for _, s := range r.streams {
76+
if latest == nil || latest.ID() < s.ID() {
77+
latest = s
78+
}
79+
}
80+
return latest
81+
}
82+
83+
// Close marks the registry as closed, preventing future Register calls.
84+
// It does not cancel any streams.
85+
func (r *streamRegistry) Close() {
86+
r.mu.Lock()
87+
defer r.mu.Unlock()
88+
89+
r.closed = true
90+
}
91+
92+
// ForEach calls fn for each registered stream. The registry is read-locked
93+
// during iteration.
94+
func (r *streamRegistry) ForEach(fn func(*drpcstream.Stream)) {
95+
r.mu.RLock()
96+
defer r.mu.RUnlock()
97+
98+
for _, s := range r.streams {
99+
fn(s)
100+
}
101+
}
102+
103+
// Len returns the number of registered streams.
104+
func (r *streamRegistry) Len() int {
105+
r.mu.RLock()
106+
defer r.mu.RUnlock()
107+
108+
return len(r.streams)
109+
}

drpcmanager/registry_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright (C) 2026 Cockroach Labs.
2+
// See LICENSE for copying information.
3+
4+
package drpcmanager
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/zeebo/assert"
11+
12+
"storj.io/drpc/drpcstream"
13+
"storj.io/drpc/drpcwire"
14+
)
15+
16+
func testStream(id uint64) *drpcstream.Stream {
17+
return drpcstream.New(context.Background(), id, &drpcwire.Writer{})
18+
}
19+
20+
func TestStreamRegistry_RegisterAndGet(t *testing.T) {
21+
reg := newStreamRegistry()
22+
s := testStream(1)
23+
24+
assert.NoError(t, reg.Register(1, s))
25+
26+
got, ok := reg.Get(1)
27+
assert.That(t, ok)
28+
assert.Equal(t, got, s)
29+
}
30+
31+
func TestStreamRegistry_GetMissing(t *testing.T) {
32+
reg := newStreamRegistry()
33+
34+
got, ok := reg.Get(42)
35+
assert.That(t, !ok)
36+
assert.Nil(t, got)
37+
}
38+
39+
func TestStreamRegistry_Unregister(t *testing.T) {
40+
reg := newStreamRegistry()
41+
s := testStream(1)
42+
43+
assert.NoError(t, reg.Register(1, s))
44+
assert.Equal(t, reg.Len(), 1)
45+
46+
reg.Unregister(1)
47+
48+
_, ok := reg.Get(1)
49+
assert.That(t, !ok)
50+
assert.Equal(t, reg.Len(), 0)
51+
}
52+
53+
func TestStreamRegistry_UnregisterIdempotent(t *testing.T) {
54+
reg := newStreamRegistry()
55+
56+
// must not panic when unregistering a non-existent ID
57+
reg.Unregister(99)
58+
}
59+
60+
func TestStreamRegistry_DuplicateRegister(t *testing.T) {
61+
reg := newStreamRegistry()
62+
s1 := testStream(1)
63+
s2 := testStream(1)
64+
65+
assert.NoError(t, reg.Register(1, s1))
66+
assert.Error(t, reg.Register(1, s2))
67+
68+
// original stream is still registered
69+
got, ok := reg.Get(1)
70+
assert.That(t, ok)
71+
assert.Equal(t, got, s1)
72+
}
73+
74+
func TestStreamRegistry_RegisterAfterClose(t *testing.T) {
75+
reg := newStreamRegistry()
76+
reg.Close()
77+
78+
err := reg.Register(1, testStream(1))
79+
assert.Error(t, err)
80+
}
81+
82+
func TestStreamRegistry_UnregisterAfterClose(t *testing.T) {
83+
reg := newStreamRegistry()
84+
s := testStream(1)
85+
assert.NoError(t, reg.Register(1, s))
86+
87+
reg.Close()
88+
89+
// must not panic
90+
reg.Unregister(1)
91+
}
92+
93+
func TestStreamRegistry_Len(t *testing.T) {
94+
reg := newStreamRegistry()
95+
assert.Equal(t, reg.Len(), 0)
96+
97+
assert.NoError(t, reg.Register(1, testStream(1)))
98+
assert.Equal(t, reg.Len(), 1)
99+
100+
assert.NoError(t, reg.Register(2, testStream(2)))
101+
assert.Equal(t, reg.Len(), 2)
102+
103+
reg.Unregister(1)
104+
assert.Equal(t, reg.Len(), 1)
105+
}
106+
107+
func TestStreamRegistry_ForEach(t *testing.T) {
108+
reg := newStreamRegistry()
109+
s1 := testStream(1)
110+
s2 := testStream(2)
111+
s3 := testStream(3)
112+
113+
assert.NoError(t, reg.Register(1, s1))
114+
assert.NoError(t, reg.Register(2, s2))
115+
assert.NoError(t, reg.Register(3, s3))
116+
117+
seen := make(map[uint64]*drpcstream.Stream)
118+
reg.ForEach(func(s *drpcstream.Stream) {
119+
seen[s.ID()] = s
120+
})
121+
122+
assert.Equal(t, len(seen), 3)
123+
assert.Equal(t, seen[1], s1)
124+
assert.Equal(t, seen[2], s2)
125+
assert.Equal(t, seen[3], s3)
126+
}
127+
128+
func TestStreamRegistry_ForEach_Empty(t *testing.T) {
129+
reg := newStreamRegistry()
130+
131+
count := 0
132+
reg.ForEach(func(_ *drpcstream.Stream) { count++ })
133+
assert.Equal(t, count, 0)
134+
}

0 commit comments

Comments
 (0)