diff --git a/client.go b/client.go index f947cd08e..da2616618 100644 --- a/client.go +++ b/client.go @@ -21,6 +21,7 @@ import ( "github.com/lightninglabs/loop/sweep" "github.com/lightninglabs/loop/sweepbatcher" "github.com/lightninglabs/loop/utils" + "github.com/lightninglabs/loop/utils/chainhashutil" "github.com/lightninglabs/taproot-assets/rpcutils" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -272,14 +273,9 @@ func NewClient(dbDir string, loopDB loopdb.SwapStore, } if len(cfg.SkippedTxns) != 0 { - skippedTxns := make(map[chainhash.Hash]struct{}) - for _, txid := range cfg.SkippedTxns { - txid, err := chainhash.NewHashFromStr(txid) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse "+ - "txid to skip %v: %w", txid, err) - } - skippedTxns[*txid] = struct{}{} + skippedTxns, err := parseSkippedTxns(cfg.SkippedTxns) + if err != nil { + return nil, nil, err } batcherOpts = append(batcherOpts, sweepbatcher.WithSkippedTxns( skippedTxns, @@ -323,6 +319,24 @@ func NewClient(dbDir string, loopDB loopdb.SwapStore, return client, cleanup, nil } +// parseSkippedTxns parses the configured skipped transaction IDs and rejects +// any txid that is not fully specified. +func parseSkippedTxns(txids []string) (map[chainhash.Hash]struct{}, error) { + skippedTxns := make(map[chainhash.Hash]struct{}, len(txids)) + + for _, txid := range txids { + hash, err := chainhashutil.NewHashFromStrExact(txid) + if err != nil { + return nil, fmt.Errorf("failed to parse txid to skip %v: %w", + txid, err) + } + + skippedTxns[hash] = struct{}{} + } + + return skippedTxns, nil +} + // GetConn returns the gRPC connection to the server. func (s *Client) GetConn() *grpc.ClientConn { return s.clientConfig.Conn diff --git a/client_test.go b/client_test.go index 15b9c7c9d..ede6e2ff6 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "errors" + "strings" "testing" "github.com/btcsuite/btcd/btcutil" @@ -46,6 +47,50 @@ var ( defaultConfirmations = int32(loopdb.DefaultLoopOutHtlcConfirmations) ) +// TestParseSkippedTxns verifies that skipped txids must be fully specified. +func TestParseSkippedTxns(t *testing.T) { + t.Parallel() + + validTxid := strings.Repeat("01", 32) + validHash, err := chainhash.NewHashFromStr(validTxid) + require.NoError(t, err) + + tests := []struct { + name string + txids []string + expected map[chainhash.Hash]struct{} + expectedErr string + }{ + { + name: "valid", + txids: []string{validTxid}, + expected: map[chainhash.Hash]struct{}{ + *validHash: {}, + }, + }, + { + name: "short", + txids: []string{"abcd"}, + expectedErr: "failed to parse txid to skip abcd", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + skippedTxns, err := parseSkippedTxns(test.txids) + if test.expectedErr != "" { + require.ErrorContains(t, err, test.expectedErr) + return + } + + require.NoError(t, err) + require.Equal(t, test.expected, skippedTxns) + }) + } +} + var htlcKeys = func() loopdb.HtlcKeys { var senderKey, receiverKey [33]byte diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 62187d3e5..c69ad2a64 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -1836,8 +1836,10 @@ func (s *swapClientServer) ListStaticAddressDeposits(ctx context.Context, }, nil } -// ListStaticAddressWithdrawals returns a list of all finalized withdrawal -// transactions. +// ListStaticAddressWithdrawals returns a list of all static address +// withdrawals, including pending withdrawals. Pending withdrawals expose +// default empty or zero values for fields that are only known after +// confirmation. func (s *swapClientServer) ListStaticAddressWithdrawals(ctx context.Context, _ *looprpc.ListStaticAddressWithdrawalRequest) ( *looprpc.ListStaticAddressWithdrawalResponse, error) { @@ -1855,6 +1857,11 @@ func (s *swapClientServer) ListStaticAddressWithdrawals(ctx context.Context, []*looprpc.StaticAddressWithdrawal, 0, len(withdrawals), ) for _, w := range withdrawals { + txID := "" + if w.TxID != nil { + txID = w.TxID.String() + } + deposits := make([]*looprpc.Deposit, 0, len(w.Deposits)) for _, d := range w.Deposits { deposits = append(deposits, &looprpc.Deposit{ @@ -1868,7 +1875,7 @@ func (s *swapClientServer) ListStaticAddressWithdrawals(ctx context.Context, }) } withdrawal := &looprpc.StaticAddressWithdrawal{ - TxId: w.TxID.String(), + TxId: txID, Deposits: deposits, TotalDepositAmountSatoshis: int64(w.TotalDepositAmount), WithdrawnAmountSatoshis: int64(w.WithdrawnAmount), diff --git a/loopdb/sql_store.go b/loopdb/sql_store.go index da83f49ad..6679fd889 100644 --- a/loopdb/sql_store.go +++ b/loopdb/sql_store.go @@ -4,14 +4,15 @@ import ( "context" "database/sql" "errors" + "fmt" "strconv" "strings" "time" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightninglabs/loop/loopdb/sqlc" + "github.com/lightninglabs/loop/utils/chainhashutil" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/routing/route" @@ -754,12 +755,15 @@ func getSwapEvents(updates []sqlc.SwapUpdate) ([]*LoopEvent, error) { } if updates[i].HtlcTxhash != "" { - chainHash, err := chainhash.NewHashFromStr(updates[i].HtlcTxhash) + chainHash, err := chainhashutil.NewHashFromStrExact( + updates[i].HtlcTxhash, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid htlc tx hash "+ + "%q: %w", updates[i].HtlcTxhash, err) } - events[i].HtlcTxHash = chainHash + events[i].HtlcTxHash = &chainHash } } diff --git a/loopdb/sql_test.go b/loopdb/sql_test.go index 8312a7937..4369394e5 100644 --- a/loopdb/sql_test.go +++ b/loopdb/sql_test.go @@ -535,6 +535,58 @@ func TestBatchUpdateCost(t *testing.T) { require.Equal(t, updateMap[hash2], swapsMap[hash2].State().Cost) } +// TestSqliteRejectsTruncatedHtlcTxHash verifies that a persisted short HTLC +// txid is rejected instead of being accepted as a padded hash. +func TestSqliteRejectsTruncatedHtlcTxHash(t *testing.T) { + store := NewTestDB(t) + + destAddr := test.GetDestAddr(t, 0) + pendingSwap := LoopOutContract{ + SwapContract: SwapContract{ + AmountRequested: 100, + Preimage: testPreimage, + CltvExpiry: 144, + HtlcKeys: HtlcKeys{ + SenderScriptKey: senderKey, + ReceiverScriptKey: receiverKey, + SenderInternalPubKey: senderInternalKey, + ReceiverInternalPubKey: receiverInternalKey, + ClientScriptKeyLocator: keychain.KeyLocator{ + Family: 1, + Index: 2, + }, + }, + MaxMinerFee: 10, + MaxSwapFee: 20, + InitiationHeight: 99, + InitiationTime: testTime, + ProtocolVersion: ProtocolVersionMuSig2, + }, + PrepayInvoice: "prepayinvoice", + DestAddr: destAddr, + SwapInvoice: "swapinvoice", + SweepConfTarget: 2, + HtlcConfirmations: 2, + } + + ctxb := t.Context() + hash := pendingSwap.Preimage.Hash() + + err := store.CreateLoopOut(ctxb, hash, &pendingSwap) + require.NoError(t, err) + + err = store.Queries.InsertSwapUpdate(ctxb, sqlc.InsertSwapUpdateParams{ + SwapHash: hash[:], + UpdateTimestamp: testTime, + UpdateState: int32(StatePreimageRevealed), + HtlcTxhash: "abcd", + }) + require.NoError(t, err) + + _, err = store.FetchLoopOutSwap(ctxb, hash) + require.ErrorContains(t, err, "invalid htlc tx hash") +} + // TestMigrationTracker tests the migration tracker functionality. func TestMigrationTracker(t *testing.T) { ctxb := context.Background() diff --git a/loopdb/sqlc/migrations/000015_static_address_withdrawals.up.sql b/loopdb/sqlc/migrations/000015_static_address_withdrawals.up.sql index 5400be7fa..5308b8c2c 100644 --- a/loopdb/sqlc/migrations/000015_static_address_withdrawals.up.sql +++ b/loopdb/sqlc/migrations/000015_static_address_withdrawals.up.sql @@ -1,4 +1,4 @@ --- withdrawals stores finalized static address withdrawals. +-- withdrawals stores pending and finalized static address withdrawals. CREATE TABLE IF NOT EXISTS withdrawals ( -- id is the auto-incrementing primary key for a withdrawal. id INTEGER PRIMARY KEY, @@ -6,7 +6,8 @@ CREATE TABLE IF NOT EXISTS withdrawals ( -- withdrawal_id is the unique identifier for the withdrawal. withdrawal_id BLOB NOT NULL UNIQUE, - -- withdrawal_tx_id is the transaction tx id of the withdrawal. + -- withdrawal_tx_id is the confirmed transaction txid of the withdrawal. + -- It remains NULL while the withdrawal is still pending. withdrawal_tx_id TEXT UNIQUE, -- total_deposit_amount is the total amount of the deposits in satoshis. diff --git a/looprpc/client.pb.go b/looprpc/client.pb.go index 35ae71c82..154d92d9b 100644 --- a/looprpc/client.pb.go +++ b/looprpc/client.pb.go @@ -5695,7 +5695,8 @@ func (x *Deposit) GetSwapHash() []byte { type StaticAddressWithdrawal struct { state protoimpl.MessageState `protogen:"open.v1"` - // The transaction id of the withdrawal transaction. + // The transaction id of the withdrawal transaction. It is empty until the + // confirmed transaction is persisted. TxId string `protobuf:"bytes,1,opt,name=tx_id,json=txId,proto3" json:"tx_id,omitempty"` // The selected deposits that is withdrawn from. Deposits []*Deposit `protobuf:"bytes,2,rep,name=deposits,proto3" json:"deposits,omitempty"` @@ -5703,11 +5704,13 @@ type StaticAddressWithdrawal struct { TotalDepositAmountSatoshis int64 `protobuf:"varint,3,opt,name=total_deposit_amount_satoshis,json=totalDepositAmountSatoshis,proto3" json:"total_deposit_amount_satoshis,omitempty"` // The actual amount that was withdrawn from the selected deposits. This value // represents the sum of selected deposit values minus tx fees minus optional - // change output. + // change output. It is zero until the confirmed transaction is persisted. WithdrawnAmountSatoshis int64 `protobuf:"varint,4,opt,name=withdrawn_amount_satoshis,json=withdrawnAmountSatoshis,proto3" json:"withdrawn_amount_satoshis,omitempty"` - // An optional change. + // An optional change. It is zero until the confirmed transaction is + // persisted. ChangeAmountSatoshis int64 `protobuf:"varint,5,opt,name=change_amount_satoshis,json=changeAmountSatoshis,proto3" json:"change_amount_satoshis,omitempty"` - // The confirmation block height of the withdrawal transaction. + // The confirmation block height of the withdrawal transaction. It is zero + // until the withdrawal is confirmed. ConfirmationHeight uint32 `protobuf:"varint,6,opt,name=confirmation_height,json=confirmationHeight,proto3" json:"confirmation_height,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache diff --git a/looprpc/client.proto b/looprpc/client.proto index cf14ffa0a..a41b87287 100644 --- a/looprpc/client.proto +++ b/looprpc/client.proto @@ -188,7 +188,8 @@ service SwapClient { returns (ListStaticAddressDepositsResponse); /* loop:`listwithdrawals` - ListStaticAddressWithdrawals returns a list of static address withdrawals. + ListStaticAddressWithdrawals returns a list of static address withdrawals, + including pending withdrawals that have not yet been confirmed. */ rpc ListStaticAddressWithdrawals (ListStaticAddressWithdrawalRequest) returns (ListStaticAddressWithdrawalResponse); @@ -2047,7 +2048,8 @@ message Deposit { message StaticAddressWithdrawal { /* - The transaction id of the withdrawal transaction. + The transaction id of the withdrawal transaction. It is empty until the + confirmed transaction is persisted. */ string tx_id = 1; @@ -2064,17 +2066,19 @@ message StaticAddressWithdrawal { /* The actual amount that was withdrawn from the selected deposits. This value represents the sum of selected deposit values minus tx fees minus optional - change output. + change output. It is zero until the confirmed transaction is persisted. */ int64 withdrawn_amount_satoshis = 4; /* - An optional change. + An optional change. It is zero until the confirmed transaction is + persisted. */ int64 change_amount_satoshis = 5; /* - The confirmation block height of the withdrawal transaction. + The confirmation block height of the withdrawal transaction. It is zero + until the withdrawal is confirmed. */ uint32 confirmation_height = 6; } diff --git a/looprpc/client.swagger.json b/looprpc/client.swagger.json index 3d75d15da..3b4e459c9 100644 --- a/looprpc/client.swagger.json +++ b/looprpc/client.swagger.json @@ -1115,7 +1115,7 @@ }, "/v1/staticaddr/withdrawals": { "get": { - "summary": "loop:`listwithdrawals`\nListStaticAddressWithdrawals returns a list of static address withdrawals.", + "summary": "loop:`listwithdrawals`\nListStaticAddressWithdrawals returns a list of static address withdrawals,\nincluding pending withdrawals that have not yet been confirmed.", "operationId": "SwapClient_ListStaticAddressWithdrawals", "responses": { "200": { @@ -2872,7 +2872,7 @@ "properties": { "tx_id": { "type": "string", - "description": "The transaction id of the withdrawal transaction." + "description": "The transaction id of the withdrawal transaction. It is empty until the\nconfirmed transaction is persisted." }, "deposits": { "type": "array", @@ -2890,17 +2890,17 @@ "withdrawn_amount_satoshis": { "type": "string", "format": "int64", - "description": "The actual amount that was withdrawn from the selected deposits. This value\nrepresents the sum of selected deposit values minus tx fees minus optional\nchange output." + "description": "The actual amount that was withdrawn from the selected deposits. This value\nrepresents the sum of selected deposit values minus tx fees minus optional\nchange output. It is zero until the confirmed transaction is persisted." }, "change_amount_satoshis": { "type": "string", "format": "int64", - "description": "An optional change." + "description": "An optional change. It is zero until the confirmed transaction is\npersisted." }, "confirmation_height": { "type": "integer", "format": "int64", - "description": "The confirmation block height of the withdrawal transaction." + "description": "The confirmation block height of the withdrawal transaction. It is zero\nuntil the withdrawal is confirmed." } } }, diff --git a/looprpc/client_grpc.pb.go b/looprpc/client_grpc.pb.go index b03cc9e87..abb50dd84 100644 --- a/looprpc/client_grpc.pb.go +++ b/looprpc/client_grpc.pb.go @@ -127,7 +127,8 @@ type SwapClientClient interface { // deposits. ListStaticAddressDeposits(ctx context.Context, in *ListStaticAddressDepositsRequest, opts ...grpc.CallOption) (*ListStaticAddressDepositsResponse, error) // loop:`listwithdrawals` - // ListStaticAddressWithdrawals returns a list of static address withdrawals. + // ListStaticAddressWithdrawals returns a list of static address withdrawals, + // including pending withdrawals that have not yet been confirmed. ListStaticAddressWithdrawals(ctx context.Context, in *ListStaticAddressWithdrawalRequest, opts ...grpc.CallOption) (*ListStaticAddressWithdrawalResponse, error) // loop:`listswaps` // ListStaticAddressSwaps returns a list of filtered static address @@ -587,7 +588,8 @@ type SwapClientServer interface { // deposits. ListStaticAddressDeposits(context.Context, *ListStaticAddressDepositsRequest) (*ListStaticAddressDepositsResponse, error) // loop:`listwithdrawals` - // ListStaticAddressWithdrawals returns a list of static address withdrawals. + // ListStaticAddressWithdrawals returns a list of static address withdrawals, + // including pending withdrawals that have not yet been confirmed. ListStaticAddressWithdrawals(context.Context, *ListStaticAddressWithdrawalRequest) (*ListStaticAddressWithdrawalResponse, error) // loop:`listswaps` // ListStaticAddressSwaps returns a list of filtered static address diff --git a/staticaddr/loopin/sql_store.go b/staticaddr/loopin/sql_store.go index 1b70bbc48..ab4da4275 100644 --- a/staticaddr/loopin/sql_store.go +++ b/staticaddr/loopin/sql_store.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "fmt" "strings" "github.com/btcsuite/btcd/btcec/v2" @@ -15,6 +16,7 @@ import ( "github.com/lightninglabs/loop/loopdb/sqlc" "github.com/lightninglabs/loop/staticaddr/deposit" "github.com/lightninglabs/loop/staticaddr/version" + "github.com/lightninglabs/loop/utils/chainhashutil" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -498,13 +500,20 @@ func toStaticAddressLoopIn(_ context.Context, network *chaincfg.Params, } var htlcTimeoutSweepTxHash *chainhash.Hash - if swap.HtlcTimeoutSweepTxID.Valid { - htlcTimeoutSweepTxHash, err = chainhash.NewHashFromStr( + // Loop never writes empty timeout sweep txids, but tolerate them on read + // so a malformed row does not prevent swap recovery. + if swap.HtlcTimeoutSweepTxID.Valid && + swap.HtlcTimeoutSweepTxID.String != "" { + + hash, err := chainhashutil.NewHashFromStrExact( swap.HtlcTimeoutSweepTxID.String, ) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid htlc timeout sweep txid %q: %w", + swap.HtlcTimeoutSweepTxID.String, err) } + + htlcTimeoutSweepTxHash = &hash } depositOutpoints := strings.Split( diff --git a/staticaddr/loopin/sql_store_test.go b/staticaddr/loopin/sql_store_test.go index 356049bc7..68aef2a6b 100644 --- a/staticaddr/loopin/sql_store_test.go +++ b/staticaddr/loopin/sql_store_test.go @@ -2,6 +2,7 @@ package loopin import ( "context" + "database/sql" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/loopdb/sqlc" "github.com/lightninglabs/loop/staticaddr/deposit" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/clock" @@ -271,3 +273,150 @@ func TestCreateLoopIn(t *testing.T) { require.Equal(t, d2.Value, swap.Deposits[1].Value) require.Equal(t, deposit.LoopingIn, swap.Deposits[1].GetState()) } + +// TestGetLoopInByHashRejectsTruncatedTimeoutSweepTxID verifies that a +// persisted short timeout sweep txid is rejected during swap loading. +func TestGetLoopInByHashRejectsTruncatedTimeoutSweepTxID(t *testing.T) { + ctxb := t.Context() + testDb := loopdb.NewTestDB(t) + testClock := clock.NewTestClock(time.Now()) + defer testDb.Close() + + depositStore := deposit.NewSqlStore(testDb.BaseDB) + swapStore := NewSqlStore( + loopdb.NewTypedStore[Querier](testDb), testClock, + &chaincfg.RegressionNetParams, + ) + + depositID, err := deposit.GetRandomDepositID() + require.NoError(t, err) + + depositRecord := &deposit.Deposit{ + ID: depositID, + OutPoint: wire.OutPoint{ + Hash: chainhash.Hash{0x1a, 0x2b, 0x3c, 0x4d}, + Index: 0, + }, + Value: btcutil.Amount(100_000), + TimeOutSweepPkScript: []byte{ + 0x00, 0x14, 0x1a, 0x2b, 0x3c, 0x41, + }, + } + + err = depositStore.CreateDeposit(ctxb, depositRecord) + require.NoError(t, err) + + depositRecord.SetState(deposit.LoopingIn) + err = depositStore.UpdateDeposit(ctxb, depositRecord) + require.NoError(t, err) + + _, clientPubKey := test.CreateKey(1) + _, serverPubKey := test.CreateKey(2) + addr, err := btcutil.DecodeAddress(P2wkhAddr, nil) + require.NoError(t, err) + + swapHash := lntypes.Hash{0x1, 0x2, 0x3, 0x4} + loopIn := StaticAddressLoopIn{ + SwapHash: swapHash, + SwapPreimage: lntypes.Preimage{0x1, 0x2, 0x3, 0x4}, + DepositOutpoints: []string{ + depositRecord.OutPoint.String(), + }, + Deposits: []*deposit.Deposit{depositRecord}, + ClientPubkey: clientPubKey, + ServerPubkey: serverPubKey, + HtlcTimeoutSweepAddress: addr, + } + loopIn.SetState(SignHtlcTx) + + err = swapStore.CreateLoopIn(ctxb, &loopIn) + require.NoError(t, err) + + err = testDb.Queries.UpdateStaticAddressLoopIn( + ctxb, sqlc.UpdateStaticAddressLoopInParams{ + SwapHash: swapHash[:], + HtlcTimeoutSweepTxID: sql.NullString{ + String: "abcd", + Valid: true, + }, + }, + ) + require.NoError(t, err) + + _, err = swapStore.GetLoopInByHash(ctxb, swapHash) + require.ErrorContains(t, err, "invalid htlc timeout sweep txid") +} + +// TestGetLoopInByHashAllowsEmptyTimeoutSweepTxID verifies that empty persisted +// timeout sweep txids are tolerated for recovery robustness. +func TestGetLoopInByHashAllowsEmptyTimeoutSweepTxID(t *testing.T) { + ctxb := t.Context() + testDb := loopdb.NewTestDB(t) + testClock := clock.NewTestClock(time.Now()) + defer testDb.Close() + + depositStore := deposit.NewSqlStore(testDb.BaseDB) + swapStore := NewSqlStore( + loopdb.NewTypedStore[Querier](testDb), testClock, + &chaincfg.RegressionNetParams, + ) + + depositID, err := deposit.GetRandomDepositID() + require.NoError(t, err) + + depositRecord := &deposit.Deposit{ + ID: depositID, + OutPoint: wire.OutPoint{ + Hash: chainhash.Hash{0x1a, 0x2b, 0x3c, 0x4d}, + Index: 0, + }, + Value: btcutil.Amount(100_000), + TimeOutSweepPkScript: []byte{ + 0x00, 0x14, 0x1a, 0x2b, 0x3c, 0x41, + }, + } + + err = depositStore.CreateDeposit(ctxb, depositRecord) + require.NoError(t, err) + + depositRecord.SetState(deposit.LoopingIn) + err = depositStore.UpdateDeposit(ctxb, depositRecord) + require.NoError(t, err) + + _, clientPubKey := test.CreateKey(1) + _, serverPubKey := test.CreateKey(2) + addr, err := btcutil.DecodeAddress(P2wkhAddr, nil) + require.NoError(t, err) + + swapHash := lntypes.Hash{0x1, 0x2, 0x3, 0x4} + loopIn := StaticAddressLoopIn{ + SwapHash: swapHash, + SwapPreimage: lntypes.Preimage{0x1, 0x2, 0x3, 0x4}, + DepositOutpoints: []string{ + depositRecord.OutPoint.String(), + }, + Deposits: []*deposit.Deposit{depositRecord}, + ClientPubkey: clientPubKey, + ServerPubkey: serverPubKey, + HtlcTimeoutSweepAddress: addr, + } + loopIn.SetState(SignHtlcTx) + + err = swapStore.CreateLoopIn(ctxb, &loopIn) + require.NoError(t, err) + + err = testDb.Queries.UpdateStaticAddressLoopIn( + ctxb, sqlc.UpdateStaticAddressLoopInParams{ + SwapHash: swapHash[:], + HtlcTimeoutSweepTxID: sql.NullString{ + String: "", + Valid: true, + }, + }, + ) + require.NoError(t, err) + + swap, err := swapStore.GetLoopInByHash(ctxb, swapHash) + require.NoError(t, err) + require.Nil(t, swap.HtlcTimeoutSweepTxHash) +} diff --git a/staticaddr/withdraw/manager.go b/staticaddr/withdraw/manager.go index 99fddd267..a8f98ffc8 100644 --- a/staticaddr/withdraw/manager.go +++ b/staticaddr/withdraw/manager.go @@ -92,8 +92,8 @@ type ManagerConfig struct { // Signer is the signer client that is used to sign transactions. Signer lndclient.SignerClient - // Store is the store that is used to persist the finalized withdrawal - // transactions. + // Store is the store that is used to persist pending and finalized + // withdrawal records. Store *SqlStore } @@ -1187,7 +1187,8 @@ func (m *Manager) DeliverWithdrawalRequest(ctx context.Context, } } -// GetAllWithdrawals returns all finalized withdrawals from the store. +// GetAllWithdrawals returns all pending and finalized withdrawals from the +// store. func (m *Manager) GetAllWithdrawals(ctx context.Context) ([]Withdrawal, error) { return m.cfg.Store.GetAllWithdrawals(ctx) } diff --git a/staticaddr/withdraw/sql_store.go b/staticaddr/withdraw/sql_store.go index df6f27f69..e5aa0930c 100644 --- a/staticaddr/withdraw/sql_store.go +++ b/staticaddr/withdraw/sql_store.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql" + "fmt" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -11,6 +12,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/loopdb/sqlc" "github.com/lightninglabs/loop/staticaddr/deposit" + "github.com/lightninglabs/loop/utils/chainhashutil" "github.com/lightningnetwork/lnd/clock" ) @@ -37,7 +39,8 @@ type Querier interface { GetWithdrawalDeposits(ctx context.Context, withdrawalID []byte) ( [][]byte, error) - // GetAllWithdrawals retrieves all withdrawals from the database. + // GetAllWithdrawals retrieves all pending and finalized withdrawals from + // the database. GetAllWithdrawals(ctx context.Context) ([]sqlc.Withdrawal, error) } @@ -69,7 +72,8 @@ func NewSqlStore(db BaseDB, depositStore deposit.Store) *SqlStore { } } -// CreateWithdrawal creates a static address withdrawal record in the database. +// CreateWithdrawal creates a pending static address withdrawal record in the +// database. func (s *SqlStore) CreateWithdrawal(ctx context.Context, deposits []*deposit.Deposit) error { @@ -110,8 +114,8 @@ func (s *SqlStore) CreateWithdrawal(ctx context.Context, }) } -// UpdateWithdrawal updates a withdrawal record with the transaction -// information, including the withdrawn amount, change amount, and +// UpdateWithdrawal finalizes a pending withdrawal record with the confirmed +// transaction information, including the withdrawn amount, change amount, and // confirmation height. It is expected that the withdrawal has already been // created with CreateWithdrawal, and that the deposits slice contains the // deposits associated with the withdrawal. @@ -169,9 +173,9 @@ func (s *SqlStore) UpdateWithdrawal(ctx context.Context, }) } -// GetAllWithdrawals retrieves all static address withdrawals from the -// database. It returns a slice of Withdrawal structs, each containing a list -// of associated deposits. +// GetAllWithdrawals retrieves all pending and finalized static address +// withdrawals from the database. Pending withdrawals return default zero +// values for fields that are only known after confirmation, and a nil TxID. func (s *SqlStore) GetAllWithdrawals(ctx context.Context) ([]Withdrawal, error) { @@ -200,14 +204,22 @@ func (s *SqlStore) GetAllWithdrawals(ctx context.Context) ([]Withdrawal, deposits = append(deposits, deposit) } - txID, err := chainhash.NewHashFromStr(w.WithdrawalTxID.String) - if err != nil { - return nil, err + var txID *chainhash.Hash + if w.WithdrawalTxID.Valid { + hash, err := chainhashutil.NewHashFromStrExact( + w.WithdrawalTxID.String, + ) + if err != nil { + return nil, fmt.Errorf("invalid withdrawal txid %q: %w", + w.WithdrawalTxID.String, err) + } + + txID = &hash } result = append(result, Withdrawal{ ID: ID(w.WithdrawalID), - TxID: *txID, + TxID: txID, Deposits: deposits, TotalDepositAmount: btcutil.Amount(w.TotalDepositAmount), WithdrawnAmount: btcutil.Amount(w.WithdrawnAmount.Int64), diff --git a/staticaddr/withdraw/sql_store_test.go b/staticaddr/withdraw/sql_store_test.go index 5897f20ef..4994104f3 100644 --- a/staticaddr/withdraw/sql_store_test.go +++ b/staticaddr/withdraw/sql_store_test.go @@ -2,11 +2,13 @@ package withdraw import ( "context" + "database/sql" "testing" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/loopdb/sqlc" "github.com/lightninglabs/loop/staticaddr/deposit" "github.com/stretchr/testify/require" ) @@ -83,6 +85,10 @@ func TestSqlStore(t *testing.T) { t, d2.Value, withdrawals[0].Deposits[1].Value, ) require.NotEmpty(t, withdrawals[0].InitiationTime) + require.Nil(t, withdrawals[0].TxID) + require.Zero(t, withdrawals[0].WithdrawnAmount) + require.Zero(t, withdrawals[0].ChangeAmount) + require.Zero(t, withdrawals[0].ConfirmationHeight) err = store.UpdateWithdrawal( ctxb, []*deposit.Deposit{d1, d2}, withdrawalTx, 6, []byte{0x01}, @@ -92,10 +98,57 @@ func TestSqlStore(t *testing.T) { withdrawals, err = store.GetAllWithdrawals(ctxb) require.NoError(t, err) require.Len(t, withdrawals, 1) - require.NotEmpty(t, withdrawals[0].TxID) + require.NotNil(t, withdrawals[0].TxID) + require.Equal(t, withdrawalTx.TxHash(), *withdrawals[0].TxID) require.EqualValues( t, d1.Value+d2.Value-100, withdrawals[0].WithdrawnAmount, ) require.EqualValues(t, 100, withdrawals[0].ChangeAmount) require.EqualValues(t, 6, withdrawals[0].ConfirmationHeight) } + +// TestGetAllWithdrawalsRejectsInvalidTxID verifies that a malformed persisted +// withdrawal txid is rejected, while pending withdrawals remain readable via +// NULL values. +func TestGetAllWithdrawalsRejectsInvalidTxID(t *testing.T) { + ctxb := context.Background() + testDb := loopdb.NewTestDB(t) + defer testDb.Close() + + depositStore := deposit.NewSqlStore(testDb.BaseDB) + store := NewSqlStore(loopdb.NewTypedStore[Querier](testDb), depositStore) + + depositID, err := deposit.GetRandomDepositID() + require.NoError(t, err) + + d := &deposit.Deposit{ + ID: depositID, + Value: btcutil.Amount(100_000), + TimeOutSweepPkScript: []byte{ + 0x00, 0x14, 0x1a, 0x2b, 0x3c, 0x41, + }, + } + + err = depositStore.CreateDeposit(ctxb, d) + require.NoError(t, err) + + err = store.CreateWithdrawal(ctxb, []*deposit.Deposit{d}) + require.NoError(t, err) + + withdrawalID, err := testDb.Queries.GetWithdrawalIDByDepositID( + ctxb, d.ID[:], + ) + require.NoError(t, err) + + err = testDb.Queries.UpdateWithdrawal(ctxb, sqlc.UpdateWithdrawalParams{ + WithdrawalID: withdrawalID, + WithdrawalTxID: sql.NullString{ + String: "abcd", + Valid: true, + }, + }) + require.NoError(t, err) + + _, err = store.GetAllWithdrawals(ctxb) + require.ErrorContains(t, err, "invalid withdrawal txid") +} diff --git a/staticaddr/withdraw/withdrawal.go b/staticaddr/withdraw/withdrawal.go index 9c32ec736..a1525f697 100644 --- a/staticaddr/withdraw/withdrawal.go +++ b/staticaddr/withdraw/withdrawal.go @@ -29,14 +29,15 @@ func (r *ID) FromByteSlice(b []byte) error { return nil } -// Withdrawal represents a finalized static address withdrawal record in the -// database. +// Withdrawal represents a static address withdrawal record in the database. +// The record may be pending or finalized. type Withdrawal struct { // ID is the unique identifier of the deposit. ID ID - // TxID is the transaction ID of the withdrawal. - TxID chainhash.Hash + // TxID is the transaction ID of the withdrawal. It is nil until the + // confirmed withdrawal transaction is persisted. + TxID *chainhash.Hash // Deposits is a list of deposits used to fund the withdrawal. Deposits []*deposit.Deposit @@ -46,17 +47,19 @@ type Withdrawal struct { TotalDepositAmount btcutil.Amount // WithdrawnAmount is the amount withdrawn. It represents the total - // value of selected deposits minus fees and change. + // value of selected deposits minus fees and change. It is zero until the + // confirmed withdrawal transaction is persisted. WithdrawnAmount btcutil.Amount - // ChangeAmount is the optional change returned to the static address. + // ChangeAmount is the optional change returned to the static address. It + // is zero until the confirmed withdrawal transaction is persisted. ChangeAmount btcutil.Amount // InitiationTime is the time at which the withdrawal was initiated. InitiationTime time.Time // ConfirmationHeight is the block height at which the withdrawal was - // confirmed. + // confirmed. It is zero until the withdrawal is confirmed. ConfirmationHeight int64 } diff --git a/sweepbatcher/store.go b/sweepbatcher/store.go index 4cf899c85..1f86645d2 100644 --- a/sweepbatcher/store.go +++ b/sweepbatcher/store.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/loopdb/sqlc" + "github.com/lightninglabs/loop/utils/chainhashutil" "github.com/lightningnetwork/lnd/lntypes" ) @@ -91,7 +92,7 @@ func (s *SQLStore) FetchUnconfirmedSweepBatches(ctx context.Context) ( } for _, dbBatch := range dbBatches { - batch := convertBatchRow(dbBatch) + batch, err := convertBatchRow(dbBatch) if err != nil { return nil, err } @@ -99,7 +100,7 @@ func (s *SQLStore) FetchUnconfirmedSweepBatches(ctx context.Context) ( batches = append(batches, batch) } - return batches, err + return batches, nil } // InsertSweepBatch inserts a batch into the database, returning the id of the @@ -198,7 +199,7 @@ func (s *SQLStore) GetParentBatch(ctx context.Context, outpoint wire.OutPoint) ( return nil, err } - return convertBatchRow(batch), nil + return convertBatchRow(batch) } // UpsertSweep inserts a sweep into the database, or updates an existing sweep @@ -258,17 +259,24 @@ type dbSweep struct { } // convertBatchRow converts a batch row from db to a sweepbatcher.Batch struct. -func convertBatchRow(row sqlc.SweepBatch) *dbBatch { +func convertBatchRow(row sqlc.SweepBatch) (*dbBatch, error) { batch := dbBatch{ ID: row.ID, Confirmed: row.Confirmed, } - if row.BatchTxID.Valid { - err := chainhash.Decode(&batch.BatchTxid, row.BatchTxID.String) + // Loop never writes empty batch txids, but tolerate them on read so a + // malformed row does not prevent batcher recovery. + if row.BatchTxID.Valid && row.BatchTxID.String != "" { + hash, err := chainhashutil.NewHashFromStrExact( + row.BatchTxID.String, + ) if err != nil { - return nil + return nil, fmt.Errorf("invalid batch txid %q: %w", + row.BatchTxID.String, err) } + + batch.BatchTxid = hash } batch.BatchPkScript = row.BatchPkScript @@ -283,7 +291,7 @@ func convertBatchRow(row sqlc.SweepBatch) *dbBatch { batch.MaxTimeoutDistance = row.MaxTimeoutDistance - return &batch + return &batch, nil } // batchToInsertArgs converts a Batch struct to the arguments needed to insert diff --git a/sweepbatcher/store_test.go b/sweepbatcher/store_test.go new file mode 100644 index 000000000..b9f02cfd4 --- /dev/null +++ b/sweepbatcher/store_test.go @@ -0,0 +1,147 @@ +package sweepbatcher + +import ( + "context" + "database/sql" + "strings" + "testing" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/loopdb/sqlc" + "github.com/stretchr/testify/require" +) + +// TestFetchUnconfirmedSweepBatchesRejectsInvalidBatchTxID verifies that +// malformed persisted batch txids are rejected during batch loading. +func TestFetchUnconfirmedSweepBatchesRejectsInvalidBatchTxID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + txid string + errMsg string + }{ + { + name: "short", + txid: "abcd", + errMsg: "invalid batch txid", + }, + { + name: "non-hex", + txid: strings.Repeat("z", 64), + errMsg: "invalid batch txid", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctxb := t.Context() + testDb := loopdb.NewTestDB(t) + defer testDb.Close() + + store := NewSQLStore( + loopdb.NewTypedStore[Querier](testDb), + &chaincfg.RegressionNetParams, + ) + + _, err := testDb.Queries.InsertBatch( + ctxb, sqlc.InsertBatchParams{ + BatchTxID: sql.NullString{ + String: test.txid, + Valid: true, + }, + BatchPkScript: []byte{0x00}, + LastRbfHeight: sql.NullInt32{ + Int32: 1, + Valid: true, + }, + LastRbfSatPerKw: sql.NullInt32{ + Int32: 1000, + Valid: true, + }, + MaxTimeoutDistance: 1, + }, + ) + require.NoError(t, err) + + _, err = store.FetchUnconfirmedSweepBatches(ctxb) + require.ErrorContains(t, err, test.errMsg) + }) + } +} + +// TestFetchUnconfirmedSweepBatchesAllowsEmptyBatchTxID verifies that empty +// persisted batch txids are tolerated for recovery robustness. +func TestFetchUnconfirmedSweepBatchesAllowsEmptyBatchTxID(t *testing.T) { + t.Parallel() + + ctxb := t.Context() + testDb := loopdb.NewTestDB(t) + defer testDb.Close() + + store := NewSQLStore( + loopdb.NewTypedStore[Querier](testDb), + &chaincfg.RegressionNetParams, + ) + + _, err := testDb.Queries.InsertBatch( + ctxb, sqlc.InsertBatchParams{ + BatchTxID: sql.NullString{ + String: "", + Valid: true, + }, + BatchPkScript: []byte{0x00}, + LastRbfHeight: sql.NullInt32{ + Int32: 1, + Valid: true, + }, + LastRbfSatPerKw: sql.NullInt32{ + Int32: 1000, + Valid: true, + }, + MaxTimeoutDistance: 1, + }, + ) + require.NoError(t, err) + + _, err = store.FetchUnconfirmedSweepBatches(ctxb) + require.NoError(t, err) +} + +// parentBatchDB is a test double that overrides GetParentBatch while reusing +// the rest of the SQL store interface from an embedded BaseDB. +type parentBatchDB struct { + BaseDB + + batch sqlc.SweepBatch + err error +} + +// GetParentBatch returns the preconfigured batch row for store tests. +func (s parentBatchDB) GetParentBatch(ctx context.Context, + outpoint string) (sqlc.SweepBatch, error) { + + return s.batch, s.err +} + +// TestGetParentBatchRejectsInvalidBatchTxID verifies that malformed persisted +// batch txids are rejected through the GetParentBatch path as well. +func TestGetParentBatchRejectsInvalidBatchTxID(t *testing.T) { + t.Parallel() + + store := NewSQLStore(parentBatchDB{ + batch: sqlc.SweepBatch{ + BatchTxID: sql.NullString{ + String: "abcd", + Valid: true, + }, + }, + }, &chaincfg.RegressionNetParams) + + _, err := store.GetParentBatch(t.Context(), wire.OutPoint{}) + require.ErrorContains(t, err, "invalid batch txid") +} diff --git a/utils.go b/utils.go index 69d79fda9..f7fdf2325 100644 --- a/utils.go +++ b/utils.go @@ -4,12 +4,9 @@ import ( "bytes" "context" "fmt" - "strconv" - "strings" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/lndclient" "github.com/lightningnetwork/lnd/graph/db/models" @@ -85,30 +82,6 @@ func fetchChannelEdgesByID(ctx context.Context, return edgeInfo, policy1, policy2, nil } -// parseOutPoint attempts to parse an outpoint from the passed in string. -func parseOutPoint(s string) (*wire.OutPoint, error) { - split := strings.Split(s, ":") - if len(split) != 2 { - return nil, fmt.Errorf("expecting outpoint to be in format "+ - "of txid:index: %s", s) - } - - index, err := strconv.ParseInt(split[1], 10, 32) - if err != nil { - return nil, fmt.Errorf("unable to decode output index: %v", err) - } - - txid, err := chainhash.NewHashFromStr(split[0]) - if err != nil { - return nil, fmt.Errorf("unable to parse hex string: %v", err) - } - - return &wire.OutPoint{ - Hash: *txid, - Index: uint32(index), - }, nil -} - // getAlias tries to get the ShortChannelId from the passed ChannelId and // aliasCache. func getAlias(aliasCache map[lnwire.ChannelID]lnwire.ShortChannelID, @@ -143,9 +116,12 @@ func SelectHopHints(ctx context.Context, lndClient lndclient.LightningClient, } } - outPoint, err := parseOutPoint(channel.ChannelPoint) + outPoint, err := wire.NewOutPointFromString( + channel.ChannelPoint, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to parse outpoint: %w", + err) } remotePubkey, err := btcec.ParsePubKey(channel.PubKeyBytes[:]) diff --git a/utils/chainhashutil/strict.go b/utils/chainhashutil/strict.go new file mode 100644 index 000000000..202601792 --- /dev/null +++ b/utils/chainhashutil/strict.go @@ -0,0 +1,26 @@ +package chainhashutil + +import ( + "fmt" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// NewHashFromStrExact parses a chainhash string that must be fully specified. +func NewHashFromStrExact(hash string) (chainhash.Hash, error) { + if len(hash) != chainhash.MaxHashStringSize { + return chainhash.Hash{}, fmt.Errorf( + "invalid hash string length of %v, want %v", + len(hash), chainhash.MaxHashStringSize) + } + + parsed, err := chainhash.NewHashFromStr(hash) + if err != nil { + return chainhash.Hash{}, err + } + + // chainhash.NewHashFromStr uses a pointer return, but on success it + // returns a populated hash, not (nil, nil), so dereferencing here is + // safe. + return *parsed, nil +} diff --git a/utils/chainhashutil/strict_test.go b/utils/chainhashutil/strict_test.go new file mode 100644 index 000000000..916964028 --- /dev/null +++ b/utils/chainhashutil/strict_test.go @@ -0,0 +1,64 @@ +package chainhashutil + +import ( + "strings" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/stretchr/testify/require" +) + +// TestNewHashFromStrExact verifies that strict hash parsing rejects any +// non-fully-specified chainhash string. +func TestNewHashFromStrExact(t *testing.T) { + t.Parallel() + + validHash := strings.Repeat("01", 32) + + testCases := []struct { + name string + hash string + wantErr string + }{ + { + name: "valid", + hash: validHash, + wantErr: "", + }, + { + name: "short", + hash: validHash[:62], + wantErr: "invalid hash string length", + }, + { + name: "odd length", + hash: validHash[:63], + wantErr: "invalid hash string length", + }, + { + name: "empty", + hash: "", + wantErr: "invalid hash string length", + }, + { + name: "non hex", + hash: strings.Repeat("z", 64), + wantErr: "invalid byte", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + hash, err := NewHashFromStrExact(testCase.hash) + if testCase.wantErr != "" { + require.ErrorContains(t, err, testCase.wantErr) + require.Equal(t, chainhash.Hash{}, hash) + } else { + require.NoError(t, err) + require.Equal(t, testCase.hash, hash.String()) + } + }) + } +}