Skip to content

Commit f986b14

Browse files
committed
staticaddr: add restore and reconciliation hooks
1 parent 2f5e821 commit f986b14

4 files changed

Lines changed: 279 additions & 16 deletions

File tree

staticaddr/address/manager.go

Lines changed: 132 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package address
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
8+
"strings"
79
"sync"
810
"sync/atomic"
911

@@ -21,6 +23,12 @@ import (
2123
"github.com/lightningnetwork/lnd/lnwallet"
2224
)
2325

26+
var (
27+
// ErrNoStaticAddress is returned when no static address parameters are
28+
// present in the store.
29+
ErrNoStaticAddress = errors.New("no static address parameters found")
30+
)
31+
2432
// ManagerConfig holds the configuration for the address manager.
2533
type ManagerConfig struct {
2634
// AddressClient is the client that communicates with the loop server
@@ -199,27 +207,143 @@ func (m *Manager) NewAddress(ctx context.Context) (*btcutil.AddressTaproot,
199207
return nil, 0, err
200208
}
201209

210+
_, err = m.importAddressTapscript(ctx, staticAddress)
211+
if err != nil {
212+
return nil, 0, err
213+
}
214+
215+
address, err := m.GetTaprootAddress(
216+
clientPubKey.PubKey, serverPubKey, int64(serverParams.Expiry),
217+
)
218+
if err != nil {
219+
return nil, 0, err
220+
}
221+
222+
return address, int64(serverParams.Expiry), nil
223+
}
224+
225+
// RestoreAddress recreates a static address record locally and makes sure the
226+
// corresponding tapscript is imported into lnd. If the same address already
227+
// exists locally, the call is idempotent.
228+
func (m *Manager) RestoreAddress(ctx context.Context,
229+
addrParams *Parameters) (*btcutil.AddressTaproot, bool, error) {
230+
231+
if addrParams == nil {
232+
return nil, false, fmt.Errorf("missing static address parameters")
233+
}
234+
235+
staticAddress, err := script.NewStaticAddress(
236+
input.MuSig2Version100RC2, int64(addrParams.Expiry),
237+
addrParams.ClientPubkey, addrParams.ServerPubkey,
238+
)
239+
if err != nil {
240+
return nil, false, err
241+
}
242+
243+
pkScript, err := staticAddress.StaticAddressScript()
244+
if err != nil {
245+
return nil, false, err
246+
}
247+
248+
if len(addrParams.PkScript) != 0 &&
249+
!bytes.Equal(addrParams.PkScript, pkScript) {
250+
251+
return nil, false, fmt.Errorf("static address pk script mismatch")
252+
}
253+
254+
addrParams.PkScript = pkScript
255+
if addrParams.InitiationHeight <= 0 {
256+
addrParams.InitiationHeight = m.currentHeight.Load()
257+
}
258+
259+
m.Lock()
260+
existing, err := m.cfg.Store.GetAllStaticAddresses(ctx)
261+
if err != nil {
262+
m.Unlock()
263+
264+
return nil, false, err
265+
}
266+
267+
changed := false
268+
switch {
269+
case len(existing) == 0:
270+
err = m.cfg.Store.CreateStaticAddress(ctx, addrParams)
271+
if err != nil {
272+
m.Unlock()
273+
274+
return nil, false, err
275+
}
276+
changed = true
277+
278+
case len(existing) > 1:
279+
m.Unlock()
280+
281+
return nil, false, fmt.Errorf("more than one static address found")
282+
283+
case !sameAddressParameters(existing[0], addrParams):
284+
m.Unlock()
285+
286+
return nil, false, fmt.Errorf("existing static address differs from " +
287+
"backup")
288+
}
289+
m.Unlock()
290+
291+
imported, err := m.importAddressTapscript(ctx, staticAddress)
292+
if err != nil {
293+
return nil, false, err
294+
}
295+
296+
changed = changed || imported
297+
298+
addr, err := m.GetTaprootAddress(
299+
addrParams.ClientPubkey, addrParams.ServerPubkey,
300+
int64(addrParams.Expiry),
301+
)
302+
if err != nil {
303+
return nil, false, err
304+
}
305+
306+
return addr, changed, nil
307+
}
308+
309+
func (m *Manager) importAddressTapscript(ctx context.Context,
310+
staticAddress *script.StaticAddress) (bool, error) {
311+
202312
// Import the static address tapscript into our lnd wallet, so we can
203313
// track unspent outputs of it.
204314
tapScript := input.TapscriptFullTree(
205315
staticAddress.InternalPubKey, *staticAddress.TimeoutLeaf,
206316
)
207317
addr, err := m.cfg.WalletKit.ImportTaprootScript(ctx, tapScript)
208318
if err != nil {
209-
return nil, 0, err
319+
// Restoring into an lnd instance that already imported the script is
320+
// expected. Treat the duplicate import as success.
321+
if strings.Contains(err.Error(), "already exists") {
322+
log.Infof("Static address tapscript already imported")
323+
return false, nil
324+
}
325+
326+
return false, err
210327
}
211328

212329
log.Infof("Imported static address taproot script to lnd wallet: %v",
213330
addr)
214331

215-
address, err := m.GetTaprootAddress(
216-
clientPubKey.PubKey, serverPubKey, int64(serverParams.Expiry),
217-
)
218-
if err != nil {
219-
return nil, 0, err
332+
return true, nil
333+
}
334+
335+
func sameAddressParameters(a, b *Parameters) bool {
336+
if a == nil || b == nil {
337+
return false
220338
}
221339

222-
return address, int64(serverParams.Expiry), nil
340+
return a.ClientPubkey.IsEqual(b.ClientPubkey) &&
341+
a.ServerPubkey.IsEqual(b.ServerPubkey) &&
342+
a.Expiry == b.Expiry &&
343+
bytes.Equal(a.PkScript, b.PkScript) &&
344+
a.KeyLocator == b.KeyLocator &&
345+
a.ProtocolVersion == b.ProtocolVersion &&
346+
a.InitiationHeight == b.InitiationHeight
223347
}
224348

225349
// GetTaprootAddress returns a taproot address for the given client and server
@@ -297,7 +421,7 @@ func (m *Manager) GetStaticAddressParameters(ctx context.Context) (*Parameters,
297421
}
298422

299423
if len(params) == 0 {
300-
return nil, fmt.Errorf("no static address parameters found")
424+
return nil, ErrNoStaticAddress
301425
}
302426

303427
return params[0], nil

staticaddr/address/manager_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,110 @@ func TestManager(t *testing.T) {
128128
require.EqualValues(t, defaultExpiry, expiry)
129129
}
130130

131+
// TestRestoreAddress verifies that restoring an address recreates the same
132+
// static address locally without requiring a server call.
133+
func TestRestoreAddress(t *testing.T) {
134+
ctxb := t.Context()
135+
136+
testContext := NewAddressManagerTestContext(t)
137+
138+
keyDesc, err := testContext.mockLnd.WalletKit.DeriveKey(
139+
ctxb, &keychain.KeyLocator{
140+
Family: keychain.KeyFamily(swap.StaticAddressKeyFamily),
141+
Index: 7,
142+
},
143+
)
144+
require.NoError(t, err)
145+
146+
staticAddress, err := script.NewStaticAddress(
147+
input.MuSig2Version100RC2, int64(defaultExpiry),
148+
keyDesc.PubKey, defaultServerPubkey,
149+
)
150+
require.NoError(t, err)
151+
152+
pkScript, err := staticAddress.StaticAddressScript()
153+
require.NoError(t, err)
154+
155+
addressParams := &Parameters{
156+
ClientPubkey: keyDesc.PubKey,
157+
ServerPubkey: defaultServerPubkey,
158+
Expiry: defaultExpiry,
159+
PkScript: pkScript,
160+
KeyLocator: keyDesc.KeyLocator,
161+
ProtocolVersion: 0,
162+
InitiationHeight: 123,
163+
}
164+
165+
taprootAddress, restored, err := testContext.manager.RestoreAddress(
166+
ctxb, addressParams,
167+
)
168+
require.NoError(t, err)
169+
require.True(t, restored)
170+
171+
expectedAddress, err := btcutil.NewAddressTaproot(
172+
schnorr.SerializePubKey(staticAddress.TaprootKey),
173+
testContext.manager.cfg.ChainParams,
174+
)
175+
require.NoError(t, err)
176+
require.Equal(t, expectedAddress.String(), taprootAddress.String())
177+
178+
storedParams, err := testContext.manager.GetStaticAddressParameters(ctxb)
179+
require.NoError(t, err)
180+
require.True(t, sameAddressParameters(storedParams, addressParams))
181+
182+
taprootAddress, restored, err = testContext.manager.RestoreAddress(
183+
ctxb, addressParams,
184+
)
185+
require.NoError(t, err)
186+
require.False(t, restored)
187+
require.Equal(t, expectedAddress.String(), taprootAddress.String())
188+
}
189+
190+
// TestRestoreAddressRejectsDifferentInitiationHeight verifies that a restore
191+
// request with the same address material but a different initiation height is
192+
// rejected instead of being treated as idempotent.
193+
func TestRestoreAddressRejectsDifferentInitiationHeight(t *testing.T) {
194+
ctxb := t.Context()
195+
196+
testContext := NewAddressManagerTestContext(t)
197+
198+
keyDesc, err := testContext.mockLnd.WalletKit.DeriveKey(
199+
ctxb, &keychain.KeyLocator{
200+
Family: keychain.KeyFamily(swap.StaticAddressKeyFamily),
201+
Index: 7,
202+
},
203+
)
204+
require.NoError(t, err)
205+
206+
staticAddress, err := script.NewStaticAddress(
207+
input.MuSig2Version100RC2, int64(defaultExpiry),
208+
keyDesc.PubKey, defaultServerPubkey,
209+
)
210+
require.NoError(t, err)
211+
212+
pkScript, err := staticAddress.StaticAddressScript()
213+
require.NoError(t, err)
214+
215+
addressParams := &Parameters{
216+
ClientPubkey: keyDesc.PubKey,
217+
ServerPubkey: defaultServerPubkey,
218+
Expiry: defaultExpiry,
219+
PkScript: pkScript,
220+
KeyLocator: keyDesc.KeyLocator,
221+
ProtocolVersion: 0,
222+
InitiationHeight: 123,
223+
}
224+
225+
_, _, err = testContext.manager.RestoreAddress(ctxb, addressParams)
226+
require.NoError(t, err)
227+
228+
differentHeight := *addressParams
229+
differentHeight.InitiationHeight = 456
230+
231+
_, _, err = testContext.manager.RestoreAddress(ctxb, &differentHeight)
232+
require.ErrorContains(t, err, "existing static address differs from backup")
233+
}
234+
131235
// GenerateExpectedTaprootAddress generates the expected taproot address that
132236
// the predefined parameters are supposed to generate.
133237
func GenerateExpectedTaprootAddress(t *ManagerTestContext) (

staticaddr/deposit/manager.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ type Manager struct {
6464
// mu guards access to the activeDeposits map.
6565
mu sync.Mutex
6666

67+
// reconcileMu serializes deposit recovery and reconciliation so restore
68+
// requests can't race the background polling loop.
69+
reconcileMu sync.Mutex
70+
6771
// activeDeposits contains all the active static address outputs.
6872
activeDeposits map[wire.OutPoint]*FSM
6973

@@ -108,7 +112,7 @@ func (m *Manager) Run(ctx context.Context, initChan chan struct{}) error {
108112

109113
// Reconcile immediately on startup so deposits are available
110114
// before the first ticker fires.
111-
err = m.reconcileDeposits(ctx)
115+
_, err = m.ReconcileDeposits(ctx)
112116
if err != nil {
113117
log.Errorf("unable to reconcile deposits: %v", err)
114118
}
@@ -162,6 +166,9 @@ func (m *Manager) Run(ctx context.Context, initChan chan struct{}) error {
162166
// recoverDeposits recovers static address parameters, previous deposits and
163167
// state machines from the database and starts the deposit notifier.
164168
func (m *Manager) recoverDeposits(ctx context.Context) error {
169+
m.reconcileMu.Lock()
170+
defer m.reconcileMu.Unlock()
171+
165172
log.Infof("Recovering static address parameters and deposits...")
166173

167174
// Recover deposits.
@@ -218,7 +225,7 @@ func (m *Manager) pollDeposits(ctx context.Context) {
218225
for {
219226
select {
220227
case <-ticker.C:
221-
err := m.reconcileDeposits(ctx)
228+
_, err := m.ReconcileDeposits(ctx)
222229
if err != nil {
223230
log.Errorf("unable to reconcile "+
224231
"deposits: %v", err)
@@ -235,38 +242,47 @@ func (m *Manager) pollDeposits(ctx context.Context) {
235242
// wallet and matches it against the deposits in our memory that we've seen so
236243
// far. It picks the newly identified deposits and starts a state machine per
237244
// deposit to track its progress.
238-
func (m *Manager) reconcileDeposits(ctx context.Context) error {
245+
func (m *Manager) reconcileDeposits(ctx context.Context) (int, error) {
239246
log.Tracef("Reconciling new deposits...")
240247

241248
utxos, err := m.cfg.AddressManager.ListUnspent(
242249
ctx, MinConfs, MaxConfs,
243250
)
244251
if err != nil {
245-
return fmt.Errorf("unable to list new deposits: %w", err)
252+
return 0, fmt.Errorf("unable to list new deposits: %w", err)
246253
}
247254

248255
newDeposits := m.filterNewDeposits(utxos)
249256
if len(newDeposits) == 0 {
250257
log.Tracef("No new deposits...")
251-
return nil
258+
return 0, nil
252259
}
253260

254261
for _, utxo := range newDeposits {
255262
deposit, err := m.createNewDeposit(ctx, utxo)
256263
if err != nil {
257-
return fmt.Errorf("unable to retain new deposit: %w",
264+
return 0, fmt.Errorf("unable to retain new deposit: %w",
258265
err)
259266
}
260267

261268
log.Debugf("Received deposit: %v", deposit)
262269
err = m.startDepositFsm(ctx, deposit)
263270
if err != nil {
264-
return fmt.Errorf("unable to start new deposit FSM: %w",
271+
return 0, fmt.Errorf("unable to start new deposit FSM: %w",
265272
err)
266273
}
267274
}
268275

269-
return nil
276+
return len(newDeposits), nil
277+
}
278+
279+
// ReconcileDeposits triggers a best-effort reconciliation pass and returns the
280+
// number of newly discovered deposits.
281+
func (m *Manager) ReconcileDeposits(ctx context.Context) (int, error) {
282+
m.reconcileMu.Lock()
283+
defer m.reconcileMu.Unlock()
284+
285+
return m.reconcileDeposits(ctx)
270286
}
271287

272288
// createNewDeposit transforms the wallet utxo into a deposit struct and stores

0 commit comments

Comments
 (0)