Skip to content

Commit 3a19dcb

Browse files
authored
fix(connection): response-scoped notification barrier (fixes WaitGroup reuse panic) (#30)
1 parent c180b9d commit 3a19dcb

4 files changed

Lines changed: 608 additions & 64 deletions

File tree

acp_test.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -864,20 +864,9 @@ func TestConnectionFailsFastOnNotificationQueueOverflow(t *testing.T) {
864864
t.Fatalf("expected overflow cancellation cause, got %v", cause)
865865
}
866866

867-
// Let queued work drain and ensure waitgroup accounting remains balanced.
867+
// Let queued work drain and ensure the notification barrier remains balanced.
868868
close(releaseFirst)
869-
870-
drained := make(chan struct{})
871-
go func() {
872-
c.notificationWg.Wait()
873-
close(drained)
874-
}()
875-
876-
select {
877-
case <-drained:
878-
case <-time.After(1 * time.Second):
879-
t.Fatalf("notification waitgroup did not drain after overflow")
880-
}
869+
waitForNotificationBarrierDrain(t, c, 1*time.Second)
881870
}
882871

883872
// Test initialize method behavior

connection.go

Lines changed: 196 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,18 @@ type anyMessage struct {
3030
Error *RequestError `json:"error,omitempty"`
3131
}
3232

33+
type queuedNotification struct {
34+
seq uint64
35+
msg *anyMessage
36+
}
37+
38+
type responseEnvelope struct {
39+
msg anyMessage
40+
notificationWatermark uint64
41+
}
42+
3343
type pendingResponse struct {
34-
ch chan anyMessage
44+
ch chan responseEnvelope
3545
}
3646

3747
type cancelRequestParams struct {
@@ -69,13 +79,16 @@ type Connection struct {
6979

7080
logger *slog.Logger
7181

72-
// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
73-
// for all notifications received before the response to complete processing.
74-
notificationWg sync.WaitGroup
82+
notifyMu sync.Mutex
83+
// notifyCond coordinates response-scoped waits for sequential notification processing.
84+
notifyCond *sync.Cond
85+
// invariant: completedNotificationSeq <= lastEnqueuedNotificationSeq.
86+
lastEnqueuedNotificationSeq uint64
87+
completedNotificationSeq uint64
7588

7689
// notificationQueue serializes notification processing to maintain order.
7790
// It is bounded to keep memory usage predictable.
78-
notificationQueue chan *anyMessage
91+
notificationQueue chan queuedNotification
7992
}
8093

8194
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
@@ -92,8 +105,15 @@ func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Rea
92105
cancel: cancel,
93106
inboundCtx: inboundCtx,
94107
inboundCancel: inboundCancel,
95-
notificationQueue: make(chan *anyMessage, defaultMaxQueuedNotifications),
108+
notificationQueue: make(chan queuedNotification, defaultMaxQueuedNotifications),
96109
}
110+
c.notifyCond = sync.NewCond(&c.notifyMu)
111+
go func() {
112+
<-c.ctx.Done()
113+
c.notifyMu.Lock()
114+
c.notifyCond.Broadcast()
115+
c.notifyMu.Unlock()
116+
}()
97117
go c.sendCancelRequests()
98118
go c.receive()
99119
go c.processNotifications()
@@ -402,15 +422,27 @@ func (c *Connection) receive() {
402422
continue
403423
}
404424

405-
c.notificationWg.Add(1)
406-
407-
// Queue the notification for sequential processing.
425+
// Queue the notification for sequential processing. The sequence number marks
426+
// the response-scoped barrier boundary for requests that observe later responses.
408427
m := msg
428+
c.notifyMu.Lock()
429+
c.lastEnqueuedNotificationSeq++
430+
seq := c.lastEnqueuedNotificationSeq
409431
select {
410-
case c.notificationQueue <- &m:
432+
case c.notificationQueue <- queuedNotification{seq: seq, msg: &m}:
433+
c.notifyMu.Unlock()
411434
default:
412-
// Balance Add above when the message was not accepted.
413-
c.notificationWg.Done()
435+
if c.lastEnqueuedNotificationSeq != seq {
436+
c.notifyMu.Unlock()
437+
panic("notification sequence advanced while receive goroutine was queueing")
438+
}
439+
c.lastEnqueuedNotificationSeq--
440+
// invariant: completedNotificationSeq never exceeds the highest accepted enqueue.
441+
if c.completedNotificationSeq > c.lastEnqueuedNotificationSeq {
442+
c.notifyMu.Unlock()
443+
panic("completed notification sequence exceeded enqueued notification sequence")
444+
}
445+
c.notifyMu.Unlock()
414446
c.loggerOrDefault().Error("failed to queue notification; closing connection", "err", errNotificationQueueOverflow, "capacity", cap(c.notificationQueue), "queued", len(c.notificationQueue))
415447
c.shutdownReceive(errNotificationQueueOverflow)
416448
return
@@ -440,30 +472,43 @@ func (c *Connection) shutdownReceive(cause error) {
440472
// notification handlers may legitimately block until their context is canceled.
441473
close(c.notificationQueue)
442474

475+
c.notifyMu.Lock()
476+
finalEnqueuedSeq := c.lastEnqueuedNotificationSeq
477+
if c.completedNotificationSeq > finalEnqueuedSeq {
478+
c.notifyMu.Unlock()
479+
panic("completed notification sequence exceeded final enqueued sequence during shutdown")
480+
}
481+
c.notifyMu.Unlock()
482+
443483
// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
444484
// handler blocks waiting for cancellation.
445-
go func() {
446-
done := make(chan struct{})
447-
go func() {
448-
c.notificationWg.Wait()
449-
close(done)
450-
}()
451-
select {
452-
case <-done:
453-
case <-time.After(notificationQueueDrainTimeout):
454-
}
485+
go func(finalEnqueuedSeq uint64) {
486+
c.waitForNotificationDrain(finalEnqueuedSeq, notificationQueueDrainTimeout)
455487
c.inboundCancel(cause)
456-
}()
488+
}(finalEnqueuedSeq)
457489

458490
c.loggerOrDefault().Info("connection closed", "cause", cause.Error())
459491
}
460492

461493
// processNotifications processes notifications sequentially to maintain order.
462494
// It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
463495
func (c *Connection) processNotifications() {
464-
for msg := range c.notificationQueue {
465-
c.handleInbound(c.inboundCtx, msg)
466-
c.notificationWg.Done()
496+
for queued := range c.notificationQueue {
497+
c.handleInbound(c.inboundCtx, queued.msg)
498+
499+
c.notifyMu.Lock()
500+
expectedSeq := c.completedNotificationSeq + 1
501+
if queued.seq != expectedSeq {
502+
c.notifyMu.Unlock()
503+
panic("notification sequence completed out of order")
504+
}
505+
c.completedNotificationSeq = queued.seq
506+
if c.completedNotificationSeq > c.lastEnqueuedNotificationSeq {
507+
c.notifyMu.Unlock()
508+
panic("completed notification sequence exceeded enqueued notification sequence")
509+
}
510+
c.notifyCond.Broadcast()
511+
c.notifyMu.Unlock()
467512
}
468513
}
469514

@@ -482,7 +527,14 @@ func (c *Connection) handleResponse(msg *anyMessage) {
482527
c.mu.Unlock()
483528

484529
if pr != nil {
485-
pr.ch <- *msg
530+
c.notifyMu.Lock()
531+
watermark := c.lastEnqueuedNotificationSeq
532+
if c.completedNotificationSeq > watermark {
533+
c.notifyMu.Unlock()
534+
panic("completed notification sequence exceeded response watermark")
535+
}
536+
c.notifyMu.Unlock()
537+
pr.ch <- responseEnvelope{msg: *msg, notificationWatermark: watermark}
486538
}
487539
}
488540

@@ -578,7 +630,7 @@ func SendRequest[T any](c *Connection, ctx context.Context, method string, param
578630
return result, err
579631
}
580632

581-
pr := &pendingResponse{ch: make(chan anyMessage, 1)}
633+
pr := &pendingResponse{ch: make(chan responseEnvelope, 1)}
582634
c.mu.Lock()
583635
c.pending[idKey] = pr
584636
c.mu.Unlock()
@@ -592,18 +644,16 @@ func SendRequest[T any](c *Connection, ctx context.Context, method string, param
592644
if err != nil {
593645
return result, err
594646
}
647+
if err := c.waitNotificationsUpTo(ctx, resp.notificationWatermark); err != nil {
648+
return result, err
649+
}
595650

596-
// Wait for all notification handlers that were spawned before this response to complete
597-
// processing. This ensures that when a request returns, all notifications sent by the
598-
// server before the response have been fully processed.
599-
c.notificationWg.Wait()
600-
601-
if resp.Error != nil {
602-
return result, resp.Error
651+
if resp.msg.Error != nil {
652+
return result, resp.msg.Error
603653
}
604654

605-
if len(resp.Result) > 0 {
606-
if err := json.Unmarshal(resp.Result, &result); err != nil {
655+
if len(resp.msg.Result) > 0 {
656+
if err := json.Unmarshal(resp.msg.Result, &result); err != nil {
607657
return result, NewInternalError(map[string]any{"error": err.Error()})
608658
}
609659
}
@@ -687,7 +737,7 @@ func (c *Connection) sendCancelRequest(idKey string) {
687737
}
688738
}
689739

690-
func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (anyMessage, error) {
740+
func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (responseEnvelope, error) {
691741
peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected before response"})
692742

693743
select {
@@ -699,7 +749,7 @@ func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, i
699749
select {
700750
case <-c.Done():
701751
c.cleanupPending(idKey)
702-
return anyMessage{}, peerDisconnectedErr
752+
return responseEnvelope{}, peerDisconnectedErr
703753
default:
704754
}
705755

@@ -711,12 +761,110 @@ func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, i
711761
cause = ctx.Err()
712762
}
713763
if cause != nil {
714-
return anyMessage{}, toReqErr(cause)
764+
return responseEnvelope{}, toReqErr(cause)
715765
}
716-
return anyMessage{}, NewInternalError(map[string]any{"error": "request context ended without cause"})
766+
return responseEnvelope{}, NewInternalError(map[string]any{"error": "request context ended without cause"})
717767
case <-c.Done():
718768
c.cleanupPending(idKey)
719-
return anyMessage{}, peerDisconnectedErr
769+
return responseEnvelope{}, peerDisconnectedErr
770+
}
771+
}
772+
773+
func (c *Connection) waitNotificationsUpTo(ctx context.Context, target uint64) error {
774+
if target == 0 {
775+
return nil
776+
}
777+
778+
peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected while waiting for pre-response notifications"})
779+
stopWake := make(chan struct{})
780+
defer close(stopWake)
781+
782+
c.notifyMu.Lock()
783+
defer c.notifyMu.Unlock()
784+
if target > c.lastEnqueuedNotificationSeq {
785+
panic("response watermark exceeded last enqueued notification sequence")
786+
}
787+
788+
go func() {
789+
select {
790+
case <-ctx.Done():
791+
case <-stopWake:
792+
return
793+
}
794+
c.notifyMu.Lock()
795+
c.notifyCond.Broadcast()
796+
c.notifyMu.Unlock()
797+
}()
798+
799+
for c.completedNotificationSeq < target {
800+
if c.completedNotificationSeq > c.lastEnqueuedNotificationSeq {
801+
panic("completed notification sequence exceeded enqueued notification sequence while waiting")
802+
}
803+
804+
select {
805+
case <-c.Done():
806+
return peerDisconnectedErr
807+
default:
808+
}
809+
select {
810+
case <-ctx.Done():
811+
select {
812+
case <-c.Done():
813+
return peerDisconnectedErr
814+
default:
815+
}
816+
cause := context.Cause(ctx)
817+
if cause == nil {
818+
cause = ctx.Err()
819+
}
820+
if cause != nil {
821+
return toReqErr(cause)
822+
}
823+
return NewInternalError(map[string]any{"error": "request context ended without cause while waiting for notifications"})
824+
default:
825+
}
826+
827+
c.notifyCond.Wait()
828+
}
829+
return nil
830+
}
831+
832+
func (c *Connection) waitForNotificationDrain(target uint64, timeout time.Duration) {
833+
if target == 0 {
834+
return
835+
}
836+
837+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
838+
defer cancel()
839+
840+
stopWake := make(chan struct{})
841+
defer close(stopWake)
842+
843+
c.notifyMu.Lock()
844+
defer c.notifyMu.Unlock()
845+
if target > c.lastEnqueuedNotificationSeq {
846+
panic("drain target exceeded last enqueued notification sequence")
847+
}
848+
849+
go func() {
850+
select {
851+
case <-ctx.Done():
852+
case <-stopWake:
853+
return
854+
}
855+
c.notifyMu.Lock()
856+
c.notifyCond.Broadcast()
857+
c.notifyMu.Unlock()
858+
}()
859+
860+
for c.completedNotificationSeq < target {
861+
if c.completedNotificationSeq > c.lastEnqueuedNotificationSeq {
862+
panic("completed notification sequence exceeded enqueued notification sequence during drain")
863+
}
864+
if ctx.Err() != nil {
865+
return
866+
}
867+
c.notifyCond.Wait()
720868
}
721869
}
722870

@@ -733,7 +881,7 @@ func (c *Connection) SendRequestNoResult(ctx context.Context, method string, par
733881
return err
734882
}
735883

736-
pr := &pendingResponse{ch: make(chan anyMessage, 1)}
884+
pr := &pendingResponse{ch: make(chan responseEnvelope, 1)}
737885
c.mu.Lock()
738886
c.pending[idKey] = pr
739887
c.mu.Unlock()
@@ -747,14 +895,12 @@ func (c *Connection) SendRequestNoResult(ctx context.Context, method string, par
747895
if err != nil {
748896
return err
749897
}
898+
if err := c.waitNotificationsUpTo(ctx, resp.notificationWatermark); err != nil {
899+
return err
900+
}
750901

751-
// Wait for all notification handlers that were spawned before this response to complete
752-
// processing. This ensures that when a request returns, all notifications sent by the
753-
// server before the response have been fully processed.
754-
c.notificationWg.Wait()
755-
756-
if resp.Error != nil {
757-
return resp.Error
902+
if resp.msg.Error != nil {
903+
return resp.msg.Error
758904
}
759905
return nil
760906
}

connection_cancel_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ func TestConnectionWaitForResponse_PeerDisconnectWinsOverDerivedContextCancel(t
631631
}
632632

633633
idKey := fmt.Sprintf("id-%d", i)
634-
pr := &pendingResponse{ch: make(chan anyMessage)}
634+
pr := &pendingResponse{ch: make(chan responseEnvelope)}
635635
c.pending[idKey] = pr
636636

637637
requestCtx, requestCancel := context.WithCancel(baseCtx)

0 commit comments

Comments
 (0)