Skip to content

Commit 727cbe2

Browse files
committed
staticaddr: add restore and reconciliation hooks
1 parent d95da38 commit 727cbe2

3 files changed

Lines changed: 207 additions & 16 deletions

File tree

staticaddr/address/manager.go

Lines changed: 132 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package address
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
8+
"strings"
79
"sync"
810
"sync/atomic"
911

@@ -21,6 +23,12 @@ import (
2123
"github.com/lightningnetwork/lnd/lnwallet"
2224
)
2325

26+
var (
27+
// ErrNoStaticAddress is returned when no static address parameters are
28+
// present in the store.
29+
ErrNoStaticAddress = errors.New("no static address parameters found")
30+
)
31+
2432
// ManagerConfig holds the configuration for the address manager.
2533
type ManagerConfig struct {
2634
// AddressClient is the client that communicates with the loop server
@@ -45,6 +53,10 @@ type ManagerConfig struct {
4553
// ChainNotifier is the chain notifier that is used to listen for new
4654
// blocks.
4755
ChainNotifier lndclient.ChainNotifierClient
56+
57+
// OnStaticAddressCreated is called after a new static address has been
58+
// stored locally and its tapscript has been imported into lnd.
59+
OnStaticAddressCreated func(context.Context) error
4860
}
4961

5062
// Manager manages the address state machines.
@@ -199,27 +211,139 @@ func (m *Manager) NewAddress(ctx context.Context) (*btcutil.AddressTaproot,
199211
return nil, 0, err
200212
}
201213

214+
err = m.importAddressTapscript(ctx, staticAddress)
215+
if err != nil {
216+
return nil, 0, err
217+
}
218+
219+
address, err := m.GetTaprootAddress(
220+
clientPubKey.PubKey, serverPubKey, int64(serverParams.Expiry),
221+
)
222+
if err != nil {
223+
return nil, 0, err
224+
}
225+
226+
if m.cfg.OnStaticAddressCreated != nil {
227+
err = m.cfg.OnStaticAddressCreated(ctx)
228+
if err != nil {
229+
return nil, 0, err
230+
}
231+
}
232+
233+
return address, int64(serverParams.Expiry), nil
234+
}
235+
236+
// RestoreAddress recreates a static address record locally and makes sure the
237+
// corresponding tapscript is imported into lnd. If the same address already
238+
// exists locally, the call is idempotent.
239+
func (m *Manager) RestoreAddress(ctx context.Context,
240+
addrParams *Parameters) (*btcutil.AddressTaproot, error) {
241+
242+
if addrParams == nil {
243+
return nil, fmt.Errorf("missing static address parameters")
244+
}
245+
246+
staticAddress, err := script.NewStaticAddress(
247+
input.MuSig2Version100RC2, int64(addrParams.Expiry),
248+
addrParams.ClientPubkey, addrParams.ServerPubkey,
249+
)
250+
if err != nil {
251+
return nil, err
252+
}
253+
254+
pkScript, err := staticAddress.StaticAddressScript()
255+
if err != nil {
256+
return nil, err
257+
}
258+
259+
if len(addrParams.PkScript) != 0 &&
260+
!bytes.Equal(addrParams.PkScript, pkScript) {
261+
262+
return nil, fmt.Errorf("static address pk script mismatch")
263+
}
264+
265+
addrParams.PkScript = pkScript
266+
if addrParams.InitiationHeight <= 0 {
267+
addrParams.InitiationHeight = m.currentHeight.Load()
268+
}
269+
270+
m.Lock()
271+
existing, err := m.cfg.Store.GetAllStaticAddresses(ctx)
272+
if err != nil {
273+
m.Unlock()
274+
275+
return nil, err
276+
}
277+
switch {
278+
case len(existing) == 0:
279+
err = m.cfg.Store.CreateStaticAddress(ctx, addrParams)
280+
if err != nil {
281+
m.Unlock()
282+
283+
return nil, err
284+
}
285+
286+
case len(existing) > 1:
287+
m.Unlock()
288+
289+
return nil, fmt.Errorf("more than one static address found")
290+
291+
case !sameAddressParameters(existing[0], addrParams):
292+
m.Unlock()
293+
294+
return nil, fmt.Errorf("existing static address differs from " +
295+
"backup")
296+
}
297+
m.Unlock()
298+
299+
err = m.importAddressTapscript(ctx, staticAddress)
300+
if err != nil {
301+
return nil, err
302+
}
303+
304+
return m.GetTaprootAddress(
305+
addrParams.ClientPubkey, addrParams.ServerPubkey,
306+
int64(addrParams.Expiry),
307+
)
308+
}
309+
310+
func (m *Manager) importAddressTapscript(ctx context.Context,
311+
staticAddress *script.StaticAddress) error {
312+
202313
// Import the static address tapscript into our lnd wallet, so we can
203314
// track unspent outputs of it.
204315
tapScript := input.TapscriptFullTree(
205316
staticAddress.InternalPubKey, *staticAddress.TimeoutLeaf,
206317
)
207318
addr, err := m.cfg.WalletKit.ImportTaprootScript(ctx, tapScript)
208319
if err != nil {
209-
return nil, 0, err
320+
// Restoring into an lnd instance that already imported the script is
321+
// expected. Treat the duplicate import as success.
322+
if strings.Contains(err.Error(), "already exists") {
323+
log.Infof("Static address tapscript already imported")
324+
return nil
325+
}
326+
327+
return err
210328
}
211329

212330
log.Infof("Imported static address taproot script to lnd wallet: %v",
213331
addr)
214332

215-
address, err := m.GetTaprootAddress(
216-
clientPubKey.PubKey, serverPubKey, int64(serverParams.Expiry),
217-
)
218-
if err != nil {
219-
return nil, 0, err
333+
return nil
334+
}
335+
336+
func sameAddressParameters(a, b *Parameters) bool {
337+
if a == nil || b == nil {
338+
return false
220339
}
221340

222-
return address, int64(serverParams.Expiry), nil
341+
return a.ClientPubkey.IsEqual(b.ClientPubkey) &&
342+
a.ServerPubkey.IsEqual(b.ServerPubkey) &&
343+
a.Expiry == b.Expiry &&
344+
bytes.Equal(a.PkScript, b.PkScript) &&
345+
a.KeyLocator == b.KeyLocator &&
346+
a.ProtocolVersion == b.ProtocolVersion
223347
}
224348

225349
// GetTaprootAddress returns a taproot address for the given client and server
@@ -297,7 +421,7 @@ func (m *Manager) GetStaticAddressParameters(ctx context.Context) (*Parameters,
297421
}
298422

299423
if len(params) == 0 {
300-
return nil, fmt.Errorf("no static address parameters found")
424+
return nil, ErrNoStaticAddress
301425
}
302426

303427
return params[0], nil

staticaddr/address/manager_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,57 @@ func TestManager(t *testing.T) {
128128
require.EqualValues(t, defaultExpiry, expiry)
129129
}
130130

131+
// TestRestoreAddress verifies that restoring an address recreates the same
132+
// static address locally without requiring a server call.
133+
func TestRestoreAddress(t *testing.T) {
134+
ctxb := t.Context()
135+
136+
testContext := NewAddressManagerTestContext(t)
137+
138+
keyDesc, err := testContext.mockLnd.WalletKit.DeriveKey(
139+
ctxb, &keychain.KeyLocator{
140+
Family: keychain.KeyFamily(swap.StaticAddressKeyFamily),
141+
Index: 7,
142+
},
143+
)
144+
require.NoError(t, err)
145+
146+
staticAddress, err := script.NewStaticAddress(
147+
input.MuSig2Version100RC2, int64(defaultExpiry),
148+
keyDesc.PubKey, defaultServerPubkey,
149+
)
150+
require.NoError(t, err)
151+
152+
pkScript, err := staticAddress.StaticAddressScript()
153+
require.NoError(t, err)
154+
155+
addressParams := &Parameters{
156+
ClientPubkey: keyDesc.PubKey,
157+
ServerPubkey: defaultServerPubkey,
158+
Expiry: defaultExpiry,
159+
PkScript: pkScript,
160+
KeyLocator: keyDesc.KeyLocator,
161+
ProtocolVersion: 0,
162+
InitiationHeight: 123,
163+
}
164+
165+
taprootAddress, err := testContext.manager.RestoreAddress(
166+
ctxb, addressParams,
167+
)
168+
require.NoError(t, err)
169+
170+
expectedAddress, err := btcutil.NewAddressTaproot(
171+
schnorr.SerializePubKey(staticAddress.TaprootKey),
172+
testContext.manager.cfg.ChainParams,
173+
)
174+
require.NoError(t, err)
175+
require.Equal(t, expectedAddress.String(), taprootAddress.String())
176+
177+
storedParams, err := testContext.manager.GetStaticAddressParameters(ctxb)
178+
require.NoError(t, err)
179+
require.True(t, sameAddressParameters(storedParams, addressParams))
180+
}
181+
131182
// GenerateExpectedTaprootAddress generates the expected taproot address that
132183
// the predefined parameters are supposed to generate.
133184
func GenerateExpectedTaprootAddress(t *ManagerTestContext) (

staticaddr/deposit/manager.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ type Manager struct {
6464
// mu guards access to the activeDeposits map.
6565
mu sync.Mutex
6666

67+
// reconcileMu serializes deposit recovery and reconciliation so restore
68+
// requests can't race the background polling loop.
69+
reconcileMu sync.Mutex
70+
6771
// activeDeposits contains all the active static address outputs.
6872
activeDeposits map[wire.OutPoint]*FSM
6973

@@ -108,7 +112,7 @@ func (m *Manager) Run(ctx context.Context, initChan chan struct{}) error {
108112

109113
// Reconcile immediately on startup so deposits are available
110114
// before the first ticker fires.
111-
err = m.reconcileDeposits(ctx)
115+
_, err = m.ReconcileDeposits(ctx)
112116
if err != nil {
113117
log.Errorf("unable to reconcile deposits: %v", err)
114118
}
@@ -162,6 +166,9 @@ func (m *Manager) Run(ctx context.Context, initChan chan struct{}) error {
162166
// recoverDeposits recovers static address parameters, previous deposits and
163167
// state machines from the database and starts the deposit notifier.
164168
func (m *Manager) recoverDeposits(ctx context.Context) error {
169+
m.reconcileMu.Lock()
170+
defer m.reconcileMu.Unlock()
171+
165172
log.Infof("Recovering static address parameters and deposits...")
166173

167174
// Recover deposits.
@@ -218,7 +225,7 @@ func (m *Manager) pollDeposits(ctx context.Context) {
218225
for {
219226
select {
220227
case <-ticker.C:
221-
err := m.reconcileDeposits(ctx)
228+
_, err := m.ReconcileDeposits(ctx)
222229
if err != nil {
223230
log.Errorf("unable to reconcile "+
224231
"deposits: %v", err)
@@ -235,38 +242,47 @@ func (m *Manager) pollDeposits(ctx context.Context) {
235242
// wallet and matches it against the deposits in our memory that we've seen so
236243
// far. It picks the newly identified deposits and starts a state machine per
237244
// deposit to track its progress.
238-
func (m *Manager) reconcileDeposits(ctx context.Context) error {
245+
func (m *Manager) reconcileDeposits(ctx context.Context) (int, error) {
239246
log.Tracef("Reconciling new deposits...")
240247

241248
utxos, err := m.cfg.AddressManager.ListUnspent(
242249
ctx, MinConfs, MaxConfs,
243250
)
244251
if err != nil {
245-
return fmt.Errorf("unable to list new deposits: %w", err)
252+
return 0, fmt.Errorf("unable to list new deposits: %w", err)
246253
}
247254

248255
newDeposits := m.filterNewDeposits(utxos)
249256
if len(newDeposits) == 0 {
250257
log.Tracef("No new deposits...")
251-
return nil
258+
return 0, nil
252259
}
253260

254261
for _, utxo := range newDeposits {
255262
deposit, err := m.createNewDeposit(ctx, utxo)
256263
if err != nil {
257-
return fmt.Errorf("unable to retain new deposit: %w",
264+
return 0, fmt.Errorf("unable to retain new deposit: %w",
258265
err)
259266
}
260267

261268
log.Debugf("Received deposit: %v", deposit)
262269
err = m.startDepositFsm(ctx, deposit)
263270
if err != nil {
264-
return fmt.Errorf("unable to start new deposit FSM: %w",
271+
return 0, fmt.Errorf("unable to start new deposit FSM: %w",
265272
err)
266273
}
267274
}
268275

269-
return nil
276+
return len(newDeposits), nil
277+
}
278+
279+
// ReconcileDeposits triggers a best-effort reconciliation pass and returns the
280+
// number of newly discovered deposits.
281+
func (m *Manager) ReconcileDeposits(ctx context.Context) (int, error) {
282+
m.reconcileMu.Lock()
283+
defer m.reconcileMu.Unlock()
284+
285+
return m.reconcileDeposits(ctx)
270286
}
271287

272288
// createNewDeposit transforms the wallet utxo into a deposit struct and stores

0 commit comments

Comments
 (0)