Skip to content

Commit 6dba2fa

Browse files
committed
routing: serialize payment lifecycles by hash
1 parent a8a3e13 commit 6dba2fa

3 files changed

Lines changed: 151 additions & 18 deletions

File tree

lnrpc/routerrpc/router_server.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
375375
payHash := payment.Identifier()
376376

377377
// Init the payment in db.
378-
paySession, shardTracker, err := s.cfg.Router.PreparePayment(payment)
378+
paySession, shardTracker, releaseLifecycle, err := s.cfg.Router.
379+
PreparePayment(payment)
379380
if err != nil {
380381
log.Errorf("SendPayment async error for payment %x: %v",
381382
payment.Identifier(), err)
@@ -397,6 +398,8 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
397398
// miss events.
398399
sub, err := s.subscribePayment(payHash)
399400
if err != nil {
401+
releaseLifecycle()
402+
400403
return err
401404
}
402405

@@ -415,7 +418,9 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
415418
}
416419

417420
// Send the payment asynchronously.
418-
s.cfg.Router.SendPaymentAsync(ctx, payment, paySession, shardTracker)
421+
s.cfg.Router.SendPaymentAsync(
422+
ctx, payment, paySession, shardTracker, releaseLifecycle,
423+
)
419424

420425
// Track the payment and return.
421426
return s.trackPayment(sub, payHash, stream, req.NoInflightUpdates)

routing/router.go

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/lightningnetwork/lnd/lnutils"
2323
"github.com/lightningnetwork/lnd/lnwallet"
2424
"github.com/lightningnetwork/lnd/lnwire"
25+
"github.com/lightningnetwork/lnd/multimutex"
2526
paymentsdb "github.com/lightningnetwork/lnd/payments/db"
2627
"github.com/lightningnetwork/lnd/record"
2728
"github.com/lightningnetwork/lnd/routing/route"
@@ -334,6 +335,13 @@ type ChannelRouter struct {
334335

335336
quit chan struct{}
336337
wg sync.WaitGroup
338+
339+
// paymentLifecycleMtx ensures only one payment lifecycle for a given
340+
// payment hash can be active at a time. Failed payments are retryable,
341+
// so the lock must be held until resumePayment fully exits rather than
342+
// only until the terminal status update is published.
343+
paymentLifecycleMtx *multimutex.Mutex[lntypes.Hash]
344+
paymentLifecycleMtxOnce sync.Once
337345
}
338346

339347
// New creates a new instance of the ChannelRouter with the specified
@@ -343,8 +351,9 @@ type ChannelRouter struct {
343351
// to fully sync to the latest state of the UTXO set.
344352
func New(cfg Config) (*ChannelRouter, error) {
345353
return &ChannelRouter{
346-
cfg: &cfg,
347-
quit: make(chan struct{}),
354+
cfg: &cfg,
355+
quit: make(chan struct{}),
356+
paymentLifecycleMtx: multimutex.NewMutex[lntypes.Hash](),
348357
}, nil
349358
}
350359

@@ -903,7 +912,9 @@ func (l *LightningPayment) Identifier() [32]byte {
903912
func (r *ChannelRouter) SendPayment(ctx context.Context,
904913
payment *LightningPayment) ([32]byte, *route.Route, error) {
905914

906-
paySession, shardTracker, err := r.PreparePayment(payment)
915+
paySession, shardTracker, releaseLifecycle, err := r.PreparePayment(
916+
payment,
917+
)
907918
if err != nil {
908919
return [32]byte{}, nil, err
909920
}
@@ -914,14 +925,15 @@ func (r *ChannelRouter) SendPayment(ctx context.Context,
914925
return r.sendPayment(
915926
ctx, payment.FeeLimit, payment.Identifier(),
916927
payment.PayAttemptTimeout, paySession, shardTracker,
917-
payment.FirstHopCustomRecords,
928+
payment.FirstHopCustomRecords, releaseLifecycle,
918929
)
919930
}
920931

921932
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
922933
// result needs to be retrieved via the control tower.
923934
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
924-
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {
935+
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker,
936+
releaseLifecycle func()) {
925937

926938
// Since this is the first time this payment is being made, we pass nil
927939
// for the existing attempt.
@@ -935,7 +947,7 @@ func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
935947
_, _, err := r.sendPayment(
936948
ctx, payment.FeeLimit, payment.Identifier(),
937949
payment.PayAttemptTimeout, ps, st,
938-
payment.FirstHopCustomRecords,
950+
payment.FirstHopCustomRecords, releaseLifecycle,
939951
)
940952
if err != nil {
941953
log.Errorf("Payment %x failed: %v",
@@ -966,24 +978,43 @@ func spewPayment(payment *LightningPayment) lnutils.LogClosure {
966978
})
967979
}
968980

981+
// lockPaymentLifecycle locks the payment lifecycle mutex for the given payment
982+
// hash and returns a release function.
983+
func (r *ChannelRouter) lockPaymentLifecycle(paymentHash lntypes.Hash) func() {
984+
r.paymentLifecycleMtxOnce.Do(func() {
985+
if r.paymentLifecycleMtx == nil {
986+
r.paymentLifecycleMtx = multimutex.
987+
NewMutex[lntypes.Hash]()
988+
}
989+
})
990+
991+
r.paymentLifecycleMtx.Lock(paymentHash)
992+
993+
return func() {
994+
r.paymentLifecycleMtx.Unlock(paymentHash)
995+
}
996+
}
997+
969998
// PreparePayment creates the payment session and registers the payment with the
970-
// control tower.
999+
// control tower. The returned release function must be called after the payment
1000+
// lifecycle has fully exited.
9711001
func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
972-
PaymentSession, shards.ShardTracker, error) {
1002+
PaymentSession, shards.ShardTracker, func(), error) {
9731003

9741004
ctx := context.TODO()
9751005

9761006
// Assemble any custom data we want to send to the first hop only.
9771007
var firstHopData fn.Option[tlv.Blob]
9781008
if len(payment.FirstHopCustomRecords) > 0 {
9791009
if err := payment.FirstHopCustomRecords.Validate(); err != nil {
980-
return nil, nil, fmt.Errorf("invalid first hop custom "+
981-
"records: %w", err)
1010+
return nil, nil, nil, fmt.Errorf(
1011+
"invalid first hop custom records: %w", err,
1012+
)
9821013
}
9831014

9841015
firstHopBlob, err := payment.FirstHopCustomRecords.Serialize()
9851016
if err != nil {
986-
return nil, nil, fmt.Errorf("unable to serialize "+
1017+
return nil, nil, nil, fmt.Errorf("unable to serialize "+
9871018
"first hop custom records: %w", err)
9881019
}
9891020

@@ -997,7 +1028,7 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
9971028
payment, firstHopData, r.cfg.TrafficShaper,
9981029
)
9991030
if err != nil {
1000-
return nil, nil, err
1031+
return nil, nil, nil, err
10011032
}
10021033

10031034
// Record this payment hash with the ControlTower, ensuring it is not
@@ -1032,12 +1063,16 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
10321063
)
10331064
}
10341065

1066+
releaseLifecycle := r.lockPaymentLifecycle(payment.Identifier())
1067+
10351068
err = r.cfg.Control.InitPayment(ctx, payment.Identifier(), info)
10361069
if err != nil {
1037-
return nil, nil, err
1070+
releaseLifecycle()
1071+
1072+
return nil, nil, nil, err
10381073
}
10391074

1040-
return paySession, shardTracker, nil
1075+
return paySession, shardTracker, releaseLifecycle, nil
10411076
}
10421077

10431078
// SendToRoute sends a payment using the provided route and fails the payment
@@ -1264,8 +1299,12 @@ func (r *ChannelRouter) sendPayment(ctx context.Context,
12641299
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
12651300
paymentAttemptTimeout time.Duration, paySession PaymentSession,
12661301
shardTracker shards.ShardTracker,
1267-
firstHopCustomRecords lnwire.CustomRecords) ([32]byte, *route.Route,
1268-
error) {
1302+
firstHopCustomRecords lnwire.CustomRecords,
1303+
releaseLifecycle func()) ([32]byte, *route.Route, error) {
1304+
1305+
if releaseLifecycle != nil {
1306+
defer releaseLifecycle()
1307+
}
12691308

12701309
// If the user provides a timeout, we will additionally wrap the context
12711310
// in a deadline.
@@ -1511,9 +1550,11 @@ func (r *ChannelRouter) resumePayments() error {
15111550
// attempt has finished anyway. We also set a zero fee limit,
15121551
// as no more routes should be tried.
15131552
noTimeout := time.Duration(0)
1553+
releaseLifecycle := r.lockPaymentLifecycle(payHash)
15141554
_, _, err := r.sendPayment(
15151555
context.Background(), 0, payHash, noTimeout, paySession,
15161556
shardTracker, payment.Info.FirstHopCustomRecords,
1557+
releaseLifecycle,
15171558
)
15181559
if err != nil {
15191560
log.Errorf("Resuming payment %v failed: %v", payHash,

routing/router_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,93 @@ func TestUnknownErrorSource(t *testing.T) {
13571357
}
13581358
}
13591359

1360+
// TestPreparePaymentWaitsForActiveLifecycle asserts that a retry with the same
1361+
// payment hash cannot create a new control tower payment until the existing
1362+
// payment lifecycle has fully exited.
1363+
func TestPreparePaymentWaitsForActiveLifecycle(t *testing.T) {
1364+
t.Parallel()
1365+
1366+
payment := createDummyLightningPayment(
1367+
t, route.Vertex{}, lnwire.MilliSatoshi(1000),
1368+
)
1369+
payHash := lntypes.Hash(payment.Identifier())
1370+
1371+
controlTower := &mockControlTower{}
1372+
sessionSource := &mockPaymentSessionSource{}
1373+
r := &ChannelRouter{
1374+
cfg: &Config{
1375+
Clock: clock.NewDefaultClock(),
1376+
Control: controlTower,
1377+
SessionSource: sessionSource,
1378+
TrafficShaper: fn.None[htlcswitch.AuxTrafficShaper](),
1379+
},
1380+
quit: make(chan struct{}),
1381+
}
1382+
1383+
sessionCreated := make(chan struct{})
1384+
sessionSource.On(
1385+
"NewPaymentSession", payment, mock.Anything, mock.Anything,
1386+
).Run(func(_ mock.Arguments) {
1387+
close(sessionCreated)
1388+
}).Return(&mockPaymentSession{}, nil).Once()
1389+
1390+
initCalled := make(chan struct{})
1391+
controlTower.On(
1392+
"InitPayment", payHash,
1393+
mock.MatchedBy(func(info *paymentsdb.PaymentCreationInfo) bool {
1394+
return info.PaymentIdentifier == payHash
1395+
}),
1396+
).Run(func(_ mock.Arguments) {
1397+
close(initCalled)
1398+
}).Return(nil).Once()
1399+
1400+
releaseActiveLifecycle := r.lockPaymentLifecycle(payHash)
1401+
1402+
result := make(chan error, 1)
1403+
go func() {
1404+
_, _, releaseLifecycle, err := r.PreparePayment(payment)
1405+
if err == nil {
1406+
releaseLifecycle()
1407+
}
1408+
1409+
result <- err
1410+
}()
1411+
1412+
select {
1413+
case <-sessionCreated:
1414+
1415+
case <-time.After(testTimeout):
1416+
require.Fail(t, "payment session not created")
1417+
}
1418+
1419+
select {
1420+
case <-initCalled:
1421+
require.Fail(t, "payment initialized before lifecycle exit")
1422+
1423+
case <-time.After(50 * time.Millisecond):
1424+
}
1425+
1426+
releaseActiveLifecycle()
1427+
1428+
select {
1429+
case err := <-result:
1430+
require.NoError(t, err)
1431+
1432+
case <-time.After(testTimeout):
1433+
require.Fail(t, "prepare payment did not complete")
1434+
}
1435+
1436+
select {
1437+
case <-initCalled:
1438+
1439+
default:
1440+
require.Fail(t, "payment was not initialized")
1441+
}
1442+
1443+
controlTower.AssertExpectations(t)
1444+
sessionSource.AssertExpectations(t)
1445+
}
1446+
13601447
// TestSendToRouteStructuredError asserts that SendToRoute returns a structured
13611448
// error.
13621449
func TestSendToRouteStructuredError(t *testing.T) {

0 commit comments

Comments
 (0)