Skip to content

Commit 8ae163c

Browse files
committed
notifications: deduplicate risk fanout
Risk accepted and rejected notifications had nearly identical parsing, persistence, cache update, and subscriber fanout code. The duplication made it easy for the two decision paths to drift. Introduce a shared risk-decision handler that validates the swap hash, persists the decision, updates the matching cache, clears the opposite cache, and fans out only to the subscriber for that swap. Keep risk-decision delivery best-effort for slow subscribers, while queued delivery remains limited to notification types that must not be dropped. Fold the queue tests through a common helper so both queued notification paths keep the same behavior.
1 parent b3fd80e commit 8ae163c

2 files changed

Lines changed: 159 additions & 151 deletions

File tree

notifications/manager.go

Lines changed: 111 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,19 @@ func queueNotification[T any](sub subscriber, recvChan chan T, ntfn T) {
215215
}
216216
}
217217

218+
// dropNotification sends a best-effort notification to a subscriber.
219+
func dropNotification[T any](sub subscriber, recvChan chan T, ntfn T,
220+
description string) {
221+
222+
select {
223+
case recvChan <- ntfn:
224+
case <-sub.subCtx.Done():
225+
default:
226+
log.Debugf("Dropping %s notification for slow subscriber",
227+
description)
228+
}
229+
}
230+
218231
// SubscribeReservations subscribes to the reservation notifications.
219232
func (m *Manager) SubscribeReservations(ctx context.Context,
220233
) <-chan *swapserverrpc.ServerReservationNotification {
@@ -503,6 +516,66 @@ func (m *Manager) subscribeNotifications(ctx context.Context) error {
503516
}
504517
}
505518

519+
// staticLoopInRiskDecisionName returns the log label for a risk decision.
520+
func staticLoopInRiskDecisionName(accepted bool) string {
521+
if accepted {
522+
return "accepted"
523+
}
524+
525+
return "rejected"
526+
}
527+
528+
// handleStaticLoopInRiskDecision persists, caches, and forwards a risk
529+
// decision notification to the matching subscriber.
530+
func (m *Manager) handleStaticLoopInRiskDecision(ctx context.Context,
531+
swapHashBytes []byte, accepted bool, notifType NotificationType,
532+
cacheDecision func(lntypes.Hash), notifySubscriber func(subscriber)) {
533+
534+
decision := staticLoopInRiskDecisionName(accepted)
535+
536+
var (
537+
swapHash lntypes.Hash
538+
hasSwapHash bool
539+
)
540+
if swapHashBytes != nil {
541+
hash, err := lntypes.MakeHash(swapHashBytes)
542+
if err != nil {
543+
log.Warnf("Received invalid static loop in risk "+
544+
"%s notification: %v", decision, err)
545+
} else {
546+
swapHash = hash
547+
hasSwapHash = true
548+
}
549+
}
550+
551+
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
552+
err := m.cfg.PersistStaticLoopInRiskDecision(
553+
ctx, swapHash, accepted,
554+
)
555+
if err != nil {
556+
log.Errorf("Unable to persist static loop in risk "+
557+
"%s notification: %v", decision, err)
558+
}
559+
}
560+
561+
m.Lock()
562+
defer m.Unlock()
563+
564+
if hasSwapHash {
565+
cacheDecision(swapHash)
566+
}
567+
568+
for _, sub := range m.subscribers[notifType] {
569+
if !hasSwapHash || sub.swapHash == nil ||
570+
*sub.swapHash != swapHash {
571+
572+
continue
573+
}
574+
575+
notifySubscriber(sub)
576+
}
577+
}
578+
506579
// handleNotification handles an incoming notification from the server,
507580
// forwarding it to the appropriate subscribers.
508581
func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
@@ -545,115 +618,55 @@ func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
545618
// We'll forward the static loop in risk accepted notification to the
546619
// subscriber for the matching swap.
547620
riskAcceptedNtfn := ntfn.GetStaticLoopInRiskAccepted()
548-
var (
549-
swapHash lntypes.Hash
550-
hasSwapHash bool
551-
)
621+
var swapHashBytes []byte
552622
if riskAcceptedNtfn != nil {
553-
hash, err := lntypes.MakeHash(riskAcceptedNtfn.SwapHash)
554-
if err != nil {
555-
log.Warnf("Received invalid static loop in risk "+
556-
"accepted notification: %v", err)
557-
} else {
558-
swapHash = hash
559-
hasSwapHash = true
560-
}
561-
}
562-
563-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
564-
err := m.cfg.PersistStaticLoopInRiskDecision(
565-
ctx, swapHash, true,
566-
)
567-
if err != nil {
568-
log.Errorf("Unable to persist static loop in "+
569-
"risk accepted notification: %v", err)
570-
}
623+
swapHashBytes = riskAcceptedNtfn.SwapHash
571624
}
572625

573-
m.Lock()
574-
defer m.Unlock()
575-
576-
if hasSwapHash {
577-
m.staticLoopInRiskAccepted[swapHash] =
578-
riskAcceptedNtfn
579-
delete(m.staticLoopInRiskRejected, swapHash)
580-
}
581-
582-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskAccepted] { // nolint: lll
583-
if !hasSwapHash || sub.swapHash == nil ||
584-
*sub.swapHash != swapHash {
585-
586-
continue
587-
}
588-
589-
recvChan := sub.recvChan.(chan *swapserverrpc.
590-
ServerStaticLoopInRiskAcceptedNotification)
591-
592-
select {
593-
case recvChan <- riskAcceptedNtfn:
594-
case <-sub.subCtx.Done():
595-
default:
596-
log.Debugf("Dropping static loop in risk " +
597-
"accepted notification for slow subscriber")
598-
}
599-
}
626+
m.handleStaticLoopInRiskDecision(
627+
ctx, swapHashBytes, true,
628+
NotificationTypeStaticLoopInRiskAccepted,
629+
func(swapHash lntypes.Hash) {
630+
m.staticLoopInRiskAccepted[swapHash] =
631+
riskAcceptedNtfn
632+
delete(m.staticLoopInRiskRejected, swapHash)
633+
},
634+
func(sub subscriber) {
635+
recvChan := sub.recvChan.(chan *swapserverrpc.
636+
ServerStaticLoopInRiskAcceptedNotification)
637+
dropNotification(
638+
sub, recvChan, riskAcceptedNtfn,
639+
"static loop in risk accepted",
640+
)
641+
},
642+
)
600643

601644
case *swapserverrpc.SubscribeNotificationsResponse_StaticLoopInRiskRejected: // nolint: lll
602645
// We'll forward the static loop in risk rejected notification to the
603646
// subscriber for the matching swap.
604647
riskRejectedNtfn := ntfn.GetStaticLoopInRiskRejected()
605-
var (
606-
swapHash lntypes.Hash
607-
hasSwapHash bool
608-
)
648+
var swapHashBytes []byte
609649
if riskRejectedNtfn != nil {
610-
hash, err := lntypes.MakeHash(riskRejectedNtfn.SwapHash)
611-
if err != nil {
612-
log.Warnf("Received invalid static loop in risk "+
613-
"rejected notification: %v", err)
614-
} else {
615-
swapHash = hash
616-
hasSwapHash = true
617-
}
650+
swapHashBytes = riskRejectedNtfn.SwapHash
618651
}
619652

620-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
621-
err := m.cfg.PersistStaticLoopInRiskDecision(
622-
ctx, swapHash, false,
623-
)
624-
if err != nil {
625-
log.Errorf("Unable to persist static loop in "+
626-
"risk rejected notification: %v", err)
627-
}
628-
}
629-
630-
m.Lock()
631-
defer m.Unlock()
632-
633-
if hasSwapHash {
634-
m.staticLoopInRiskRejected[swapHash] =
635-
riskRejectedNtfn
636-
delete(m.staticLoopInRiskAccepted, swapHash)
637-
}
638-
639-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskRejected] { // nolint: lll
640-
if !hasSwapHash || sub.swapHash == nil ||
641-
*sub.swapHash != swapHash {
642-
643-
continue
644-
}
645-
646-
recvChan := sub.recvChan.(chan *swapserverrpc.
647-
ServerStaticLoopInRiskRejectedNotification)
648-
649-
select {
650-
case recvChan <- riskRejectedNtfn:
651-
case <-sub.subCtx.Done():
652-
default:
653-
log.Debugf("Dropping static loop in risk " +
654-
"rejected notification for slow subscriber")
655-
}
656-
}
653+
m.handleStaticLoopInRiskDecision(
654+
ctx, swapHashBytes, false,
655+
NotificationTypeStaticLoopInRiskRejected,
656+
func(swapHash lntypes.Hash) {
657+
m.staticLoopInRiskRejected[swapHash] =
658+
riskRejectedNtfn
659+
delete(m.staticLoopInRiskAccepted, swapHash)
660+
},
661+
func(sub subscriber) {
662+
recvChan := sub.recvChan.(chan *swapserverrpc.
663+
ServerStaticLoopInRiskRejectedNotification)
664+
dropNotification(
665+
sub, recvChan, riskRejectedNtfn,
666+
"static loop in risk rejected",
667+
)
668+
},
669+
)
657670

658671
case *swapserverrpc.SubscribeNotificationsResponse_UnfinishedSwap: // nolint: lll
659672
// We'll forward the unfinished swap notification to all

notifications/manager_test.go

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -357,48 +357,21 @@ func TestManager_SlowSubscriberDoesNotBlock(t *testing.T) {
357357
func TestManager_UnfinishedSwapNotificationWaitsForSubscriber(t *testing.T) {
358358
t.Parallel()
359359

360-
mgr := NewManager(&Config{})
361-
362-
subCtx, subCancel := context.WithCancel(t.Context())
363-
defer subCancel()
364-
365-
subChan := mgr.SubscribeUnfinishedSwaps(subCtx)
366-
367-
swapHashA := lntypes.Hash{0x02, 0x03}
368-
swapHashB := lntypes.Hash{0x04, 0x05}
369-
370-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashA))
371-
372-
done := make(chan struct{})
373-
go func() {
374-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashB))
375-
close(done)
376-
}()
377-
378-
require.Eventually(t, func() bool {
379-
select {
380-
case <-done:
381-
return true
382-
default:
383-
return false
384-
}
385-
}, time.Second, 10*time.Millisecond)
386-
387-
select {
388-
case received := <-subChan:
389-
require.Equal(t, swapHashA[:], received.SwapHash)
360+
assertQueuedSwapHashNotifications(
361+
t,
362+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
363+
ServerUnfinishedSwapNotification {
390364

391-
case <-time.After(time.Second):
392-
t.Fatal("did not receive first unfinished swap notification")
393-
}
394-
395-
select {
396-
case received := <-subChan:
397-
require.Equal(t, swapHashB[:], received.SwapHash)
398-
399-
case <-time.After(time.Second):
400-
t.Fatal("second unfinished swap notification was dropped")
401-
}
365+
return mgr.SubscribeUnfinishedSwaps(ctx)
366+
},
367+
unfinishedSwapNotification,
368+
func(ntfn *swapserverrpc.ServerUnfinishedSwapNotification) []byte {
369+
return ntfn.SwapHash
370+
},
371+
lntypes.Hash{0x02, 0x03}, lntypes.Hash{0x04, 0x05},
372+
"did not receive first unfinished swap notification",
373+
"second unfinished swap notification was dropped",
374+
)
402375
}
403376

404377
// TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber verifies
@@ -409,23 +382,45 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
409382

410383
t.Parallel()
411384

385+
assertQueuedSwapHashNotifications(
386+
t,
387+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
388+
ServerStaticLoopInSweepNotification {
389+
390+
return mgr.SubscribeStaticLoopInSweepRequests(ctx)
391+
},
392+
staticLoopInSweepNotification,
393+
func(ntfn *swapserverrpc.ServerStaticLoopInSweepNotification) []byte {
394+
return ntfn.SwapHash
395+
},
396+
lntypes.Hash{0x12, 0x13}, lntypes.Hash{0x14, 0x15},
397+
"did not receive first sweep notification",
398+
"second sweep notification was not queued",
399+
)
400+
}
401+
402+
// assertQueuedSwapHashNotifications checks queued delivery for swap hashes.
403+
func assertQueuedSwapHashNotifications[T any](t *testing.T,
404+
subscribe func(*Manager, context.Context) <-chan T,
405+
notification func(lntypes.Hash) *swapserverrpc.
406+
SubscribeNotificationsResponse,
407+
swapHash func(T) []byte, swapHashA, swapHashB lntypes.Hash,
408+
firstFailureMsg, secondFailureMsg string) {
409+
410+
t.Helper()
411+
412412
mgr := NewManager(&Config{})
413413

414414
subCtx, subCancel := context.WithCancel(t.Context())
415415
defer subCancel()
416416

417-
subChan := mgr.SubscribeStaticLoopInSweepRequests(subCtx)
417+
subChan := subscribe(mgr, subCtx)
418418

419-
swapHashA := lntypes.Hash{0x12, 0x13}
420-
swapHashB := lntypes.Hash{0x14, 0x15}
421-
422-
mgr.handleNotification(t.Context(), staticLoopInSweepNotification(swapHashA))
419+
mgr.handleNotification(t.Context(), notification(swapHashA))
423420

424421
done := make(chan struct{})
425422
go func() {
426-
mgr.handleNotification(
427-
t.Context(), staticLoopInSweepNotification(swapHashB),
428-
)
423+
mgr.handleNotification(t.Context(), notification(swapHashB))
429424
close(done)
430425
}()
431426

@@ -440,18 +435,18 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
440435

441436
select {
442437
case received := <-subChan:
443-
require.Equal(t, swapHashA[:], received.SwapHash)
438+
require.Equal(t, swapHashA[:], swapHash(received))
444439

445440
case <-time.After(time.Second):
446-
t.Fatal("did not receive first sweep notification")
441+
t.Fatal(firstFailureMsg)
447442
}
448443

449444
select {
450445
case received := <-subChan:
451-
require.Equal(t, swapHashB[:], received.SwapHash)
446+
require.Equal(t, swapHashB[:], swapHash(received))
452447

453448
case <-time.After(time.Second):
454-
t.Fatal("second sweep notification was not queued")
449+
t.Fatal(secondFailureMsg)
455450
}
456451
}
457452

0 commit comments

Comments
 (0)