Skip to content

Commit 7bee626

Browse files
committed
notifications: deduplicate risk fanout
1 parent 7604be5 commit 7bee626

2 files changed

Lines changed: 154 additions & 151 deletions

File tree

notifications/manager.go

Lines changed: 107 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,18 @@ func queueNotification[T any](sub subscriber, recvChan chan T, ntfn T) {
205205
}
206206
}
207207

208+
func dropNotification[T any](sub subscriber, recvChan chan T, ntfn T,
209+
description string) {
210+
211+
select {
212+
case recvChan <- ntfn:
213+
case <-sub.subCtx.Done():
214+
default:
215+
log.Debugf("Dropping %s notification for slow subscriber",
216+
description)
217+
}
218+
}
219+
208220
// SubscribeReservations subscribes to the reservation notifications.
209221
func (m *Manager) SubscribeReservations(ctx context.Context,
210222
) <-chan *swapserverrpc.ServerReservationNotification {
@@ -472,6 +484,63 @@ func (m *Manager) subscribeNotifications(ctx context.Context) error {
472484
}
473485
}
474486

487+
func staticLoopInRiskDecisionName(accepted bool) string {
488+
if accepted {
489+
return "accepted"
490+
}
491+
492+
return "rejected"
493+
}
494+
495+
func (m *Manager) handleStaticLoopInRiskDecision(ctx context.Context,
496+
swapHashBytes []byte, accepted bool, notifType NotificationType,
497+
cacheDecision func(lntypes.Hash), notifySubscriber func(subscriber)) {
498+
499+
decision := staticLoopInRiskDecisionName(accepted)
500+
501+
var (
502+
swapHash lntypes.Hash
503+
hasSwapHash bool
504+
)
505+
if swapHashBytes != nil {
506+
hash, err := lntypes.MakeHash(swapHashBytes)
507+
if err != nil {
508+
log.Warnf("Received invalid static loop in risk "+
509+
"%s notification: %v", decision, err)
510+
} else {
511+
swapHash = hash
512+
hasSwapHash = true
513+
}
514+
}
515+
516+
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
517+
err := m.cfg.PersistStaticLoopInRiskDecision(
518+
ctx, swapHash, accepted,
519+
)
520+
if err != nil {
521+
log.Errorf("Unable to persist static loop in risk "+
522+
"%s notification: %v", decision, err)
523+
}
524+
}
525+
526+
m.Lock()
527+
defer m.Unlock()
528+
529+
if hasSwapHash {
530+
cacheDecision(swapHash)
531+
}
532+
533+
for _, sub := range m.subscribers[notifType] {
534+
if !hasSwapHash || sub.swapHash == nil ||
535+
*sub.swapHash != swapHash {
536+
537+
continue
538+
}
539+
540+
notifySubscriber(sub)
541+
}
542+
}
543+
475544
// handleNotification handles an incoming notification from the server,
476545
// forwarding it to the appropriate subscribers.
477546
func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
@@ -514,115 +583,55 @@ func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
514583
// We'll forward the static loop in risk accepted notification to the
515584
// subscriber for the matching swap.
516585
riskAcceptedNtfn := ntfn.GetStaticLoopInRiskAccepted()
517-
var (
518-
swapHash lntypes.Hash
519-
hasSwapHash bool
520-
)
586+
var swapHashBytes []byte
521587
if riskAcceptedNtfn != nil {
522-
hash, err := lntypes.MakeHash(riskAcceptedNtfn.SwapHash)
523-
if err != nil {
524-
log.Warnf("Received invalid static loop in risk "+
525-
"accepted notification: %v", err)
526-
} else {
527-
swapHash = hash
528-
hasSwapHash = true
529-
}
530-
}
531-
532-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
533-
err := m.cfg.PersistStaticLoopInRiskDecision(
534-
ctx, swapHash, true,
535-
)
536-
if err != nil {
537-
log.Errorf("Unable to persist static loop in "+
538-
"risk accepted notification: %v", err)
539-
}
588+
swapHashBytes = riskAcceptedNtfn.SwapHash
540589
}
541590

542-
m.Lock()
543-
defer m.Unlock()
544-
545-
if hasSwapHash {
546-
m.staticLoopInRiskAccepted[swapHash] =
547-
riskAcceptedNtfn
548-
delete(m.staticLoopInRiskRejected, swapHash)
549-
}
550-
551-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskAccepted] { // nolint: lll
552-
if !hasSwapHash || sub.swapHash == nil ||
553-
*sub.swapHash != swapHash {
554-
555-
continue
556-
}
557-
558-
recvChan := sub.recvChan.(chan *swapserverrpc.
559-
ServerStaticLoopInRiskAcceptedNotification)
560-
561-
select {
562-
case recvChan <- riskAcceptedNtfn:
563-
case <-sub.subCtx.Done():
564-
default:
565-
log.Debugf("Dropping static loop in risk " +
566-
"accepted notification for slow subscriber")
567-
}
568-
}
591+
m.handleStaticLoopInRiskDecision(
592+
ctx, swapHashBytes, true,
593+
NotificationTypeStaticLoopInRiskAccepted,
594+
func(swapHash lntypes.Hash) {
595+
m.staticLoopInRiskAccepted[swapHash] =
596+
riskAcceptedNtfn
597+
delete(m.staticLoopInRiskRejected, swapHash)
598+
},
599+
func(sub subscriber) {
600+
recvChan := sub.recvChan.(chan *swapserverrpc.
601+
ServerStaticLoopInRiskAcceptedNotification)
602+
dropNotification(
603+
sub, recvChan, riskAcceptedNtfn,
604+
"static loop in risk accepted",
605+
)
606+
},
607+
)
569608

570609
case *swapserverrpc.SubscribeNotificationsResponse_StaticLoopInRiskRejected: // nolint: lll
571610
// We'll forward the static loop in risk rejected notification to the
572611
// subscriber for the matching swap.
573612
riskRejectedNtfn := ntfn.GetStaticLoopInRiskRejected()
574-
var (
575-
swapHash lntypes.Hash
576-
hasSwapHash bool
577-
)
613+
var swapHashBytes []byte
578614
if riskRejectedNtfn != nil {
579-
hash, err := lntypes.MakeHash(riskRejectedNtfn.SwapHash)
580-
if err != nil {
581-
log.Warnf("Received invalid static loop in risk "+
582-
"rejected notification: %v", err)
583-
} else {
584-
swapHash = hash
585-
hasSwapHash = true
586-
}
615+
swapHashBytes = riskRejectedNtfn.SwapHash
587616
}
588617

589-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
590-
err := m.cfg.PersistStaticLoopInRiskDecision(
591-
ctx, swapHash, false,
592-
)
593-
if err != nil {
594-
log.Errorf("Unable to persist static loop in "+
595-
"risk rejected notification: %v", err)
596-
}
597-
}
598-
599-
m.Lock()
600-
defer m.Unlock()
601-
602-
if hasSwapHash {
603-
m.staticLoopInRiskRejected[swapHash] =
604-
riskRejectedNtfn
605-
delete(m.staticLoopInRiskAccepted, swapHash)
606-
}
607-
608-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskRejected] { // nolint: lll
609-
if !hasSwapHash || sub.swapHash == nil ||
610-
*sub.swapHash != swapHash {
611-
612-
continue
613-
}
614-
615-
recvChan := sub.recvChan.(chan *swapserverrpc.
616-
ServerStaticLoopInRiskRejectedNotification)
617-
618-
select {
619-
case recvChan <- riskRejectedNtfn:
620-
case <-sub.subCtx.Done():
621-
default:
622-
log.Debugf("Dropping static loop in risk " +
623-
"rejected notification for slow subscriber")
624-
}
625-
}
618+
m.handleStaticLoopInRiskDecision(
619+
ctx, swapHashBytes, false,
620+
NotificationTypeStaticLoopInRiskRejected,
621+
func(swapHash lntypes.Hash) {
622+
m.staticLoopInRiskRejected[swapHash] =
623+
riskRejectedNtfn
624+
delete(m.staticLoopInRiskAccepted, swapHash)
625+
},
626+
func(sub subscriber) {
627+
recvChan := sub.recvChan.(chan *swapserverrpc.
628+
ServerStaticLoopInRiskRejectedNotification)
629+
dropNotification(
630+
sub, recvChan, riskRejectedNtfn,
631+
"static loop in risk rejected",
632+
)
633+
},
634+
)
626635

627636
case *swapserverrpc.SubscribeNotificationsResponse_UnfinishedSwap: // nolint: lll
628637
// We'll forward the unfinished swap notification to all

notifications/manager_test.go

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -350,48 +350,21 @@ func TestManager_SlowSubscriberDoesNotBlock(t *testing.T) {
350350
func TestManager_UnfinishedSwapNotificationWaitsForSubscriber(t *testing.T) {
351351
t.Parallel()
352352

353-
mgr := NewManager(&Config{})
354-
355-
subCtx, subCancel := context.WithCancel(t.Context())
356-
defer subCancel()
357-
358-
subChan := mgr.SubscribeUnfinishedSwaps(subCtx)
359-
360-
swapHashA := lntypes.Hash{0x02, 0x03}
361-
swapHashB := lntypes.Hash{0x04, 0x05}
362-
363-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashA))
364-
365-
done := make(chan struct{})
366-
go func() {
367-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashB))
368-
close(done)
369-
}()
370-
371-
require.Eventually(t, func() bool {
372-
select {
373-
case <-done:
374-
return true
375-
default:
376-
return false
377-
}
378-
}, time.Second, 10*time.Millisecond)
379-
380-
select {
381-
case received := <-subChan:
382-
require.Equal(t, swapHashA[:], received.SwapHash)
353+
assertQueuedSwapHashNotifications(
354+
t,
355+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
356+
ServerUnfinishedSwapNotification {
383357

384-
case <-time.After(time.Second):
385-
t.Fatal("did not receive first unfinished swap notification")
386-
}
387-
388-
select {
389-
case received := <-subChan:
390-
require.Equal(t, swapHashB[:], received.SwapHash)
391-
392-
case <-time.After(time.Second):
393-
t.Fatal("second unfinished swap notification was dropped")
394-
}
358+
return mgr.SubscribeUnfinishedSwaps(ctx)
359+
},
360+
unfinishedSwapNotification,
361+
func(ntfn *swapserverrpc.ServerUnfinishedSwapNotification) []byte {
362+
return ntfn.SwapHash
363+
},
364+
lntypes.Hash{0x02, 0x03}, lntypes.Hash{0x04, 0x05},
365+
"did not receive first unfinished swap notification",
366+
"second unfinished swap notification was dropped",
367+
)
395368
}
396369

397370
// TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber verifies
@@ -402,23 +375,44 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
402375

403376
t.Parallel()
404377

378+
assertQueuedSwapHashNotifications(
379+
t,
380+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
381+
ServerStaticLoopInSweepNotification {
382+
383+
return mgr.SubscribeStaticLoopInSweepRequests(ctx)
384+
},
385+
staticLoopInSweepNotification,
386+
func(ntfn *swapserverrpc.ServerStaticLoopInSweepNotification) []byte {
387+
return ntfn.SwapHash
388+
},
389+
lntypes.Hash{0x12, 0x13}, lntypes.Hash{0x14, 0x15},
390+
"did not receive first sweep notification",
391+
"second sweep notification was not queued",
392+
)
393+
}
394+
395+
func assertQueuedSwapHashNotifications[T any](t *testing.T,
396+
subscribe func(*Manager, context.Context) <-chan T,
397+
notification func(lntypes.Hash) *swapserverrpc.
398+
SubscribeNotificationsResponse,
399+
swapHash func(T) []byte, swapHashA, swapHashB lntypes.Hash,
400+
firstFailureMsg, secondFailureMsg string) {
401+
402+
t.Helper()
403+
405404
mgr := NewManager(&Config{})
406405

407406
subCtx, subCancel := context.WithCancel(t.Context())
408407
defer subCancel()
409408

410-
subChan := mgr.SubscribeStaticLoopInSweepRequests(subCtx)
409+
subChan := subscribe(mgr, subCtx)
411410

412-
swapHashA := lntypes.Hash{0x12, 0x13}
413-
swapHashB := lntypes.Hash{0x14, 0x15}
414-
415-
mgr.handleNotification(t.Context(), staticLoopInSweepNotification(swapHashA))
411+
mgr.handleNotification(t.Context(), notification(swapHashA))
416412

417413
done := make(chan struct{})
418414
go func() {
419-
mgr.handleNotification(
420-
t.Context(), staticLoopInSweepNotification(swapHashB),
421-
)
415+
mgr.handleNotification(t.Context(), notification(swapHashB))
422416
close(done)
423417
}()
424418

@@ -433,18 +427,18 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
433427

434428
select {
435429
case received := <-subChan:
436-
require.Equal(t, swapHashA[:], received.SwapHash)
430+
require.Equal(t, swapHashA[:], swapHash(received))
437431

438432
case <-time.After(time.Second):
439-
t.Fatal("did not receive first sweep notification")
433+
t.Fatal(firstFailureMsg)
440434
}
441435

442436
select {
443437
case received := <-subChan:
444-
require.Equal(t, swapHashB[:], received.SwapHash)
438+
require.Equal(t, swapHashB[:], swapHash(received))
445439

446440
case <-time.After(time.Second):
447-
t.Fatal("second sweep notification was not queued")
441+
t.Fatal(secondFailureMsg)
448442
}
449443
}
450444

0 commit comments

Comments
 (0)