Skip to content

Commit d4e3dcc

Browse files
committed
notifications: deduplicate risk fanout
1 parent 8fd491d commit d4e3dcc

2 files changed

Lines changed: 155 additions & 153 deletions

File tree

notifications/manager.go

Lines changed: 108 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ func queueNotification[T any](sub subscriber, recvChan chan T, ntfn T) {
203203
}
204204
}
205205

206+
func dropNotification[T any](sub subscriber, recvChan chan T, ntfn T,
207+
description string) {
208+
209+
select {
210+
case recvChan <- ntfn:
211+
case <-sub.subCtx.Done():
212+
default:
213+
log.Debugf("Dropping %s notification for slow subscriber",
214+
description)
215+
}
216+
}
217+
206218
// SubscribeReservations subscribes to the reservation notifications.
207219
func (m *Manager) SubscribeReservations(ctx context.Context,
208220
) <-chan *swapserverrpc.ServerReservationNotification {
@@ -470,6 +482,64 @@ func (m *Manager) subscribeNotifications(ctx context.Context) error {
470482
}
471483
}
472484

485+
func staticLoopInRiskDecisionName(accepted bool) string {
486+
if accepted {
487+
return "accepted"
488+
}
489+
490+
return "rejected"
491+
}
492+
493+
func (m *Manager) handleStaticLoopInRiskDecision(ctx context.Context,
494+
swapHashBytes []byte, accepted bool, notifType NotificationType,
495+
cacheDecision func(lntypes.Hash), notifySubscriber func(subscriber)) {
496+
497+
decision := staticLoopInRiskDecisionName(accepted)
498+
499+
var (
500+
swapHash lntypes.Hash
501+
hasSwapHash bool
502+
)
503+
if swapHashBytes != nil {
504+
hash, err := lntypes.MakeHash(swapHashBytes)
505+
if err != nil {
506+
log.Warnf("Received invalid static loop in risk "+
507+
"%s notification: %v", decision, err)
508+
} else {
509+
swapHash = hash
510+
hasSwapHash = true
511+
}
512+
}
513+
514+
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
515+
err := m.cfg.PersistStaticLoopInRiskDecision(
516+
ctx, swapHash, accepted,
517+
)
518+
if err != nil {
519+
log.Errorf("Unable to persist static loop in risk "+
520+
"%s notification: %v", decision, err)
521+
return
522+
}
523+
}
524+
525+
m.Lock()
526+
defer m.Unlock()
527+
528+
if hasSwapHash {
529+
cacheDecision(swapHash)
530+
}
531+
532+
for _, sub := range m.subscribers[notifType] {
533+
if !hasSwapHash || sub.swapHash == nil ||
534+
*sub.swapHash != swapHash {
535+
536+
continue
537+
}
538+
539+
notifySubscriber(sub)
540+
}
541+
}
542+
473543
// handleNotification handles an incoming notification from the server,
474544
// forwarding it to the appropriate subscribers.
475545
func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
@@ -512,117 +582,55 @@ func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
512582
// We'll forward the static loop in risk accepted notification to the
513583
// subscriber for the matching swap.
514584
riskAcceptedNtfn := ntfn.GetStaticLoopInRiskAccepted()
515-
var (
516-
swapHash lntypes.Hash
517-
hasSwapHash bool
518-
)
585+
var swapHashBytes []byte
519586
if riskAcceptedNtfn != nil {
520-
hash, err := lntypes.MakeHash(riskAcceptedNtfn.SwapHash)
521-
if err != nil {
522-
log.Warnf("Received invalid static loop in risk "+
523-
"accepted notification: %v", err)
524-
} else {
525-
swapHash = hash
526-
hasSwapHash = true
527-
}
587+
swapHashBytes = riskAcceptedNtfn.SwapHash
528588
}
529589

530-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
531-
err := m.cfg.PersistStaticLoopInRiskDecision(
532-
ctx, swapHash, true,
533-
)
534-
if err != nil {
535-
log.Errorf("Unable to persist static loop in "+
536-
"risk accepted notification: %v", err)
537-
return
538-
}
539-
}
540-
541-
m.Lock()
542-
defer m.Unlock()
543-
544-
if hasSwapHash {
545-
m.staticLoopInRiskAccepted[swapHash] =
546-
riskAcceptedNtfn
547-
delete(m.staticLoopInRiskRejected, swapHash)
548-
}
549-
550-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskAccepted] { // nolint: lll
551-
if !hasSwapHash || sub.swapHash == nil ||
552-
*sub.swapHash != swapHash {
553-
554-
continue
555-
}
556-
557-
recvChan := sub.recvChan.(chan *swapserverrpc.
558-
ServerStaticLoopInRiskAcceptedNotification)
559-
560-
select {
561-
case recvChan <- riskAcceptedNtfn:
562-
case <-sub.subCtx.Done():
563-
default:
564-
log.Debugf("Dropping static loop in risk " +
565-
"accepted notification for slow subscriber")
566-
}
567-
}
590+
m.handleStaticLoopInRiskDecision(
591+
ctx, swapHashBytes, true,
592+
NotificationTypeStaticLoopInRiskAccepted,
593+
func(swapHash lntypes.Hash) {
594+
m.staticLoopInRiskAccepted[swapHash] =
595+
riskAcceptedNtfn
596+
delete(m.staticLoopInRiskRejected, swapHash)
597+
},
598+
func(sub subscriber) {
599+
recvChan := sub.recvChan.(chan *swapserverrpc.
600+
ServerStaticLoopInRiskAcceptedNotification)
601+
dropNotification(
602+
sub, recvChan, riskAcceptedNtfn,
603+
"static loop in risk accepted",
604+
)
605+
},
606+
)
568607

569608
case *swapserverrpc.SubscribeNotificationsResponse_StaticLoopInRiskRejected: // nolint: lll
570609
// We'll forward the static loop in risk rejected notification to the
571610
// subscriber for the matching swap.
572611
riskRejectedNtfn := ntfn.GetStaticLoopInRiskRejected()
573-
var (
574-
swapHash lntypes.Hash
575-
hasSwapHash bool
576-
)
612+
var swapHashBytes []byte
577613
if riskRejectedNtfn != nil {
578-
hash, err := lntypes.MakeHash(riskRejectedNtfn.SwapHash)
579-
if err != nil {
580-
log.Warnf("Received invalid static loop in risk "+
581-
"rejected notification: %v", err)
582-
} else {
583-
swapHash = hash
584-
hasSwapHash = true
585-
}
614+
swapHashBytes = riskRejectedNtfn.SwapHash
586615
}
587616

588-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
589-
err := m.cfg.PersistStaticLoopInRiskDecision(
590-
ctx, swapHash, false,
591-
)
592-
if err != nil {
593-
log.Errorf("Unable to persist static loop in "+
594-
"risk rejected notification: %v", err)
595-
return
596-
}
597-
}
598-
599-
m.Lock()
600-
defer m.Unlock()
601-
602-
if hasSwapHash {
603-
m.staticLoopInRiskRejected[swapHash] =
604-
riskRejectedNtfn
605-
delete(m.staticLoopInRiskAccepted, swapHash)
606-
}
607-
608-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskRejected] { // nolint: lll
609-
if !hasSwapHash || sub.swapHash == nil ||
610-
*sub.swapHash != swapHash {
611-
612-
continue
613-
}
614-
615-
recvChan := sub.recvChan.(chan *swapserverrpc.
616-
ServerStaticLoopInRiskRejectedNotification)
617-
618-
select {
619-
case recvChan <- riskRejectedNtfn:
620-
case <-sub.subCtx.Done():
621-
default:
622-
log.Debugf("Dropping static loop in risk " +
623-
"rejected notification for slow subscriber")
624-
}
625-
}
617+
m.handleStaticLoopInRiskDecision(
618+
ctx, swapHashBytes, false,
619+
NotificationTypeStaticLoopInRiskRejected,
620+
func(swapHash lntypes.Hash) {
621+
m.staticLoopInRiskRejected[swapHash] =
622+
riskRejectedNtfn
623+
delete(m.staticLoopInRiskAccepted, swapHash)
624+
},
625+
func(sub subscriber) {
626+
recvChan := sub.recvChan.(chan *swapserverrpc.
627+
ServerStaticLoopInRiskRejectedNotification)
628+
dropNotification(
629+
sub, recvChan, riskRejectedNtfn,
630+
"static loop in risk rejected",
631+
)
632+
},
633+
)
626634

627635
case *swapserverrpc.SubscribeNotificationsResponse_UnfinishedSwap: // nolint: lll
628636
// We'll forward the unfinished swap notification to all

notifications/manager_test.go

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -350,48 +350,21 @@ func TestManager_SlowSubscriberDoesNotBlock(t *testing.T) {
350350
func TestManager_UnfinishedSwapNotificationWaitsForSubscriber(t *testing.T) {
351351
t.Parallel()
352352

353-
mgr := NewManager(&Config{})
354-
355-
subCtx, subCancel := context.WithCancel(t.Context())
356-
defer subCancel()
357-
358-
subChan := mgr.SubscribeUnfinishedSwaps(subCtx)
359-
360-
swapHashA := lntypes.Hash{0x02, 0x03}
361-
swapHashB := lntypes.Hash{0x04, 0x05}
362-
363-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashA))
364-
365-
done := make(chan struct{})
366-
go func() {
367-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashB))
368-
close(done)
369-
}()
370-
371-
require.Eventually(t, func() bool {
372-
select {
373-
case <-done:
374-
return true
375-
default:
376-
return false
377-
}
378-
}, time.Second, 10*time.Millisecond)
379-
380-
select {
381-
case received := <-subChan:
382-
require.Equal(t, swapHashA[:], received.SwapHash)
353+
assertQueuedSwapHashNotifications(
354+
t,
355+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
356+
ServerUnfinishedSwapNotification {
383357

384-
case <-time.After(time.Second):
385-
t.Fatal("did not receive first unfinished swap notification")
386-
}
387-
388-
select {
389-
case received := <-subChan:
390-
require.Equal(t, swapHashB[:], received.SwapHash)
391-
392-
case <-time.After(time.Second):
393-
t.Fatal("second unfinished swap notification was dropped")
394-
}
358+
return mgr.SubscribeUnfinishedSwaps(ctx)
359+
},
360+
unfinishedSwapNotification,
361+
func(ntfn *swapserverrpc.ServerUnfinishedSwapNotification) []byte {
362+
return ntfn.SwapHash
363+
},
364+
lntypes.Hash{0x02, 0x03}, lntypes.Hash{0x04, 0x05},
365+
"did not receive first unfinished swap notification",
366+
"second unfinished swap notification was dropped",
367+
)
395368
}
396369

397370
// TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber verifies
@@ -402,23 +375,44 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
402375

403376
t.Parallel()
404377

378+
assertQueuedSwapHashNotifications(
379+
t,
380+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
381+
ServerStaticLoopInSweepNotification {
382+
383+
return mgr.SubscribeStaticLoopInSweepRequests(ctx)
384+
},
385+
staticLoopInSweepNotification,
386+
func(ntfn *swapserverrpc.ServerStaticLoopInSweepNotification) []byte {
387+
return ntfn.SwapHash
388+
},
389+
lntypes.Hash{0x12, 0x13}, lntypes.Hash{0x14, 0x15},
390+
"did not receive first sweep notification",
391+
"second sweep notification was not queued",
392+
)
393+
}
394+
395+
func assertQueuedSwapHashNotifications[T any](t *testing.T,
396+
subscribe func(*Manager, context.Context) <-chan T,
397+
notification func(lntypes.Hash) *swapserverrpc.
398+
SubscribeNotificationsResponse,
399+
swapHash func(T) []byte, swapHashA, swapHashB lntypes.Hash,
400+
firstFailureMsg, secondFailureMsg string) {
401+
402+
t.Helper()
403+
405404
mgr := NewManager(&Config{})
406405

407406
subCtx, subCancel := context.WithCancel(t.Context())
408407
defer subCancel()
409408

410-
subChan := mgr.SubscribeStaticLoopInSweepRequests(subCtx)
409+
subChan := subscribe(mgr, subCtx)
411410

412-
swapHashA := lntypes.Hash{0x12, 0x13}
413-
swapHashB := lntypes.Hash{0x14, 0x15}
414-
415-
mgr.handleNotification(t.Context(), staticLoopInSweepNotification(swapHashA))
411+
mgr.handleNotification(t.Context(), notification(swapHashA))
416412

417413
done := make(chan struct{})
418414
go func() {
419-
mgr.handleNotification(
420-
t.Context(), staticLoopInSweepNotification(swapHashB),
421-
)
415+
mgr.handleNotification(t.Context(), notification(swapHashB))
422416
close(done)
423417
}()
424418

@@ -433,18 +427,18 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
433427

434428
select {
435429
case received := <-subChan:
436-
require.Equal(t, swapHashA[:], received.SwapHash)
430+
require.Equal(t, swapHashA[:], swapHash(received))
437431

438432
case <-time.After(time.Second):
439-
t.Fatal("did not receive first sweep notification")
433+
t.Fatal(firstFailureMsg)
440434
}
441435

442436
select {
443437
case received := <-subChan:
444-
require.Equal(t, swapHashB[:], received.SwapHash)
438+
require.Equal(t, swapHashB[:], swapHash(received))
445439

446440
case <-time.After(time.Second):
447-
t.Fatal("second sweep notification was not queued")
441+
t.Fatal(secondFailureMsg)
448442
}
449443
}
450444

0 commit comments

Comments
 (0)