Skip to content

Commit ec42d48

Browse files
committed
notifications: deduplicate risk fanout
Risk accepted and rejected notifications had nearly identical parsing, persistence, cache update, and subscriber fanout code. The duplication made it easy for the two decision paths to drift. Introduce a shared risk-decision handler that validates the swap hash, persists the decision, updates the matching cache, clears the opposite cache, and fans out only to the subscriber for that swap. Keep risk-decision delivery best-effort for slow subscribers, while queued delivery remains limited to notification types that must not be dropped. Fold the queue tests through a common helper so both queued notification paths keep the same behavior.
1 parent 7a137df commit ec42d48

2 files changed

Lines changed: 159 additions & 151 deletions

File tree

notifications/manager.go

Lines changed: 111 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,19 @@ func queueNotification[T any](sub subscriber, recvChan chan T, ntfn T) {
207207
}
208208
}
209209

210+
// dropNotification sends a best-effort notification to a subscriber.
211+
func dropNotification[T any](sub subscriber, recvChan chan T, ntfn T,
212+
description string) {
213+
214+
select {
215+
case recvChan <- ntfn:
216+
case <-sub.subCtx.Done():
217+
default:
218+
log.Debugf("Dropping %s notification for slow subscriber",
219+
description)
220+
}
221+
}
222+
210223
// SubscribeReservations subscribes to the reservation notifications.
211224
func (m *Manager) SubscribeReservations(ctx context.Context,
212225
) <-chan *swapserverrpc.ServerReservationNotification {
@@ -474,6 +487,66 @@ func (m *Manager) subscribeNotifications(ctx context.Context) error {
474487
}
475488
}
476489

490+
// staticLoopInRiskDecisionName returns the log label for a risk decision.
491+
func staticLoopInRiskDecisionName(accepted bool) string {
492+
if accepted {
493+
return "accepted"
494+
}
495+
496+
return "rejected"
497+
}
498+
499+
// handleStaticLoopInRiskDecision persists, caches, and forwards a risk
500+
// decision notification to the matching subscriber.
501+
func (m *Manager) handleStaticLoopInRiskDecision(ctx context.Context,
502+
swapHashBytes []byte, accepted bool, notifType NotificationType,
503+
cacheDecision func(lntypes.Hash), notifySubscriber func(subscriber)) {
504+
505+
decision := staticLoopInRiskDecisionName(accepted)
506+
507+
var (
508+
swapHash lntypes.Hash
509+
hasSwapHash bool
510+
)
511+
if swapHashBytes != nil {
512+
hash, err := lntypes.MakeHash(swapHashBytes)
513+
if err != nil {
514+
log.Warnf("Received invalid static loop in risk "+
515+
"%s notification: %v", decision, err)
516+
} else {
517+
swapHash = hash
518+
hasSwapHash = true
519+
}
520+
}
521+
522+
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
523+
err := m.cfg.PersistStaticLoopInRiskDecision(
524+
ctx, swapHash, accepted,
525+
)
526+
if err != nil {
527+
log.Errorf("Unable to persist static loop in risk "+
528+
"%s notification: %v", decision, err)
529+
}
530+
}
531+
532+
m.Lock()
533+
defer m.Unlock()
534+
535+
if hasSwapHash {
536+
cacheDecision(swapHash)
537+
}
538+
539+
for _, sub := range m.subscribers[notifType] {
540+
if !hasSwapHash || sub.swapHash == nil ||
541+
*sub.swapHash != swapHash {
542+
543+
continue
544+
}
545+
546+
notifySubscriber(sub)
547+
}
548+
}
549+
477550
// handleNotification handles an incoming notification from the server,
478551
// forwarding it to the appropriate subscribers.
479552
func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
@@ -516,115 +589,55 @@ func (m *Manager) handleNotification(ctx context.Context, ntfn *swapserverrpc.
516589
// We'll forward the static loop in risk accepted notification to the
517590
// subscriber for the matching swap.
518591
riskAcceptedNtfn := ntfn.GetStaticLoopInRiskAccepted()
519-
var (
520-
swapHash lntypes.Hash
521-
hasSwapHash bool
522-
)
592+
var swapHashBytes []byte
523593
if riskAcceptedNtfn != nil {
524-
hash, err := lntypes.MakeHash(riskAcceptedNtfn.SwapHash)
525-
if err != nil {
526-
log.Warnf("Received invalid static loop in risk "+
527-
"accepted notification: %v", err)
528-
} else {
529-
swapHash = hash
530-
hasSwapHash = true
531-
}
532-
}
533-
534-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
535-
err := m.cfg.PersistStaticLoopInRiskDecision(
536-
ctx, swapHash, true,
537-
)
538-
if err != nil {
539-
log.Errorf("Unable to persist static loop in "+
540-
"risk accepted notification: %v", err)
541-
}
594+
swapHashBytes = riskAcceptedNtfn.SwapHash
542595
}
543596

544-
m.Lock()
545-
defer m.Unlock()
546-
547-
if hasSwapHash {
548-
m.staticLoopInRiskAccepted[swapHash] =
549-
riskAcceptedNtfn
550-
delete(m.staticLoopInRiskRejected, swapHash)
551-
}
552-
553-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskAccepted] { // nolint: lll
554-
if !hasSwapHash || sub.swapHash == nil ||
555-
*sub.swapHash != swapHash {
556-
557-
continue
558-
}
559-
560-
recvChan := sub.recvChan.(chan *swapserverrpc.
561-
ServerStaticLoopInRiskAcceptedNotification)
562-
563-
select {
564-
case recvChan <- riskAcceptedNtfn:
565-
case <-sub.subCtx.Done():
566-
default:
567-
log.Debugf("Dropping static loop in risk " +
568-
"accepted notification for slow subscriber")
569-
}
570-
}
597+
m.handleStaticLoopInRiskDecision(
598+
ctx, swapHashBytes, true,
599+
NotificationTypeStaticLoopInRiskAccepted,
600+
func(swapHash lntypes.Hash) {
601+
m.staticLoopInRiskAccepted[swapHash] =
602+
riskAcceptedNtfn
603+
delete(m.staticLoopInRiskRejected, swapHash)
604+
},
605+
func(sub subscriber) {
606+
recvChan := sub.recvChan.(chan *swapserverrpc.
607+
ServerStaticLoopInRiskAcceptedNotification)
608+
dropNotification(
609+
sub, recvChan, riskAcceptedNtfn,
610+
"static loop in risk accepted",
611+
)
612+
},
613+
)
571614

572615
case *swapserverrpc.SubscribeNotificationsResponse_StaticLoopInRiskRejected: // nolint: lll
573616
// We'll forward the static loop in risk rejected notification to the
574617
// subscriber for the matching swap.
575618
riskRejectedNtfn := ntfn.GetStaticLoopInRiskRejected()
576-
var (
577-
swapHash lntypes.Hash
578-
hasSwapHash bool
579-
)
619+
var swapHashBytes []byte
580620
if riskRejectedNtfn != nil {
581-
hash, err := lntypes.MakeHash(riskRejectedNtfn.SwapHash)
582-
if err != nil {
583-
log.Warnf("Received invalid static loop in risk "+
584-
"rejected notification: %v", err)
585-
} else {
586-
swapHash = hash
587-
hasSwapHash = true
588-
}
621+
swapHashBytes = riskRejectedNtfn.SwapHash
589622
}
590623

591-
if hasSwapHash && m.cfg.PersistStaticLoopInRiskDecision != nil {
592-
err := m.cfg.PersistStaticLoopInRiskDecision(
593-
ctx, swapHash, false,
594-
)
595-
if err != nil {
596-
log.Errorf("Unable to persist static loop in "+
597-
"risk rejected notification: %v", err)
598-
}
599-
}
600-
601-
m.Lock()
602-
defer m.Unlock()
603-
604-
if hasSwapHash {
605-
m.staticLoopInRiskRejected[swapHash] =
606-
riskRejectedNtfn
607-
delete(m.staticLoopInRiskAccepted, swapHash)
608-
}
609-
610-
for _, sub := range m.subscribers[NotificationTypeStaticLoopInRiskRejected] { // nolint: lll
611-
if !hasSwapHash || sub.swapHash == nil ||
612-
*sub.swapHash != swapHash {
613-
614-
continue
615-
}
616-
617-
recvChan := sub.recvChan.(chan *swapserverrpc.
618-
ServerStaticLoopInRiskRejectedNotification)
619-
620-
select {
621-
case recvChan <- riskRejectedNtfn:
622-
case <-sub.subCtx.Done():
623-
default:
624-
log.Debugf("Dropping static loop in risk " +
625-
"rejected notification for slow subscriber")
626-
}
627-
}
624+
m.handleStaticLoopInRiskDecision(
625+
ctx, swapHashBytes, false,
626+
NotificationTypeStaticLoopInRiskRejected,
627+
func(swapHash lntypes.Hash) {
628+
m.staticLoopInRiskRejected[swapHash] =
629+
riskRejectedNtfn
630+
delete(m.staticLoopInRiskAccepted, swapHash)
631+
},
632+
func(sub subscriber) {
633+
recvChan := sub.recvChan.(chan *swapserverrpc.
634+
ServerStaticLoopInRiskRejectedNotification)
635+
dropNotification(
636+
sub, recvChan, riskRejectedNtfn,
637+
"static loop in risk rejected",
638+
)
639+
},
640+
)
628641

629642
case *swapserverrpc.SubscribeNotificationsResponse_UnfinishedSwap: // nolint: lll
630643
// We'll forward the unfinished swap notification to all

notifications/manager_test.go

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -355,48 +355,21 @@ func TestManager_SlowSubscriberDoesNotBlock(t *testing.T) {
355355
func TestManager_UnfinishedSwapNotificationWaitsForSubscriber(t *testing.T) {
356356
t.Parallel()
357357

358-
mgr := NewManager(&Config{})
359-
360-
subCtx, subCancel := context.WithCancel(t.Context())
361-
defer subCancel()
362-
363-
subChan := mgr.SubscribeUnfinishedSwaps(subCtx)
364-
365-
swapHashA := lntypes.Hash{0x02, 0x03}
366-
swapHashB := lntypes.Hash{0x04, 0x05}
367-
368-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashA))
369-
370-
done := make(chan struct{})
371-
go func() {
372-
mgr.handleNotification(t.Context(), unfinishedSwapNotification(swapHashB))
373-
close(done)
374-
}()
375-
376-
require.Eventually(t, func() bool {
377-
select {
378-
case <-done:
379-
return true
380-
default:
381-
return false
382-
}
383-
}, time.Second, 10*time.Millisecond)
384-
385-
select {
386-
case received := <-subChan:
387-
require.Equal(t, swapHashA[:], received.SwapHash)
358+
assertQueuedSwapHashNotifications(
359+
t,
360+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
361+
ServerUnfinishedSwapNotification {
388362

389-
case <-time.After(time.Second):
390-
t.Fatal("did not receive first unfinished swap notification")
391-
}
392-
393-
select {
394-
case received := <-subChan:
395-
require.Equal(t, swapHashB[:], received.SwapHash)
396-
397-
case <-time.After(time.Second):
398-
t.Fatal("second unfinished swap notification was dropped")
399-
}
363+
return mgr.SubscribeUnfinishedSwaps(ctx)
364+
},
365+
unfinishedSwapNotification,
366+
func(ntfn *swapserverrpc.ServerUnfinishedSwapNotification) []byte {
367+
return ntfn.SwapHash
368+
},
369+
lntypes.Hash{0x02, 0x03}, lntypes.Hash{0x04, 0x05},
370+
"did not receive first unfinished swap notification",
371+
"second unfinished swap notification was dropped",
372+
)
400373
}
401374

402375
// TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber verifies
@@ -407,23 +380,45 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
407380

408381
t.Parallel()
409382

383+
assertQueuedSwapHashNotifications(
384+
t,
385+
func(mgr *Manager, ctx context.Context) <-chan *swapserverrpc.
386+
ServerStaticLoopInSweepNotification {
387+
388+
return mgr.SubscribeStaticLoopInSweepRequests(ctx)
389+
},
390+
staticLoopInSweepNotification,
391+
func(ntfn *swapserverrpc.ServerStaticLoopInSweepNotification) []byte {
392+
return ntfn.SwapHash
393+
},
394+
lntypes.Hash{0x12, 0x13}, lntypes.Hash{0x14, 0x15},
395+
"did not receive first sweep notification",
396+
"second sweep notification was not queued",
397+
)
398+
}
399+
400+
// assertQueuedSwapHashNotifications checks queued delivery for swap hashes.
401+
func assertQueuedSwapHashNotifications[T any](t *testing.T,
402+
subscribe func(*Manager, context.Context) <-chan T,
403+
notification func(lntypes.Hash) *swapserverrpc.
404+
SubscribeNotificationsResponse,
405+
swapHash func(T) []byte, swapHashA, swapHashB lntypes.Hash,
406+
firstFailureMsg, secondFailureMsg string) {
407+
408+
t.Helper()
409+
410410
mgr := NewManager(&Config{})
411411

412412
subCtx, subCancel := context.WithCancel(t.Context())
413413
defer subCancel()
414414

415-
subChan := mgr.SubscribeStaticLoopInSweepRequests(subCtx)
415+
subChan := subscribe(mgr, subCtx)
416416

417-
swapHashA := lntypes.Hash{0x12, 0x13}
418-
swapHashB := lntypes.Hash{0x14, 0x15}
419-
420-
mgr.handleNotification(t.Context(), staticLoopInSweepNotification(swapHashA))
417+
mgr.handleNotification(t.Context(), notification(swapHashA))
421418

422419
done := make(chan struct{})
423420
go func() {
424-
mgr.handleNotification(
425-
t.Context(), staticLoopInSweepNotification(swapHashB),
426-
)
421+
mgr.handleNotification(t.Context(), notification(swapHashB))
427422
close(done)
428423
}()
429424

@@ -438,18 +433,18 @@ func TestManager_StaticLoopInSweepNotificationQueuesForSlowSubscriber(
438433

439434
select {
440435
case received := <-subChan:
441-
require.Equal(t, swapHashA[:], received.SwapHash)
436+
require.Equal(t, swapHashA[:], swapHash(received))
442437

443438
case <-time.After(time.Second):
444-
t.Fatal("did not receive first sweep notification")
439+
t.Fatal(firstFailureMsg)
445440
}
446441

447442
select {
448443
case received := <-subChan:
449-
require.Equal(t, swapHashB[:], received.SwapHash)
444+
require.Equal(t, swapHashB[:], swapHash(received))
450445

451446
case <-time.After(time.Second):
452-
t.Fatal("second sweep notification was not queued")
447+
t.Fatal(secondFailureMsg)
453448
}
454449
}
455450

0 commit comments

Comments
 (0)