|
| 1 | +package chainparams |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "errors" |
| 7 | + "fmt" |
| 8 | + |
| 9 | + "github.com/btcsuite/btcd/chaincfg" |
| 10 | + "github.com/lightningnetwork/lnd/lncfg" |
| 11 | + "github.com/lightningnetwork/lnd/sqldb" |
| 12 | + "github.com/lightningnetwork/lnd/sqldb/sqlc" |
| 13 | +) |
| 14 | + |
| 15 | +// ErrNetworkMismatch is returned by ValidateNetwork when the network stored in |
| 16 | +// the database does not match the network lnd is configured to use. |
| 17 | +var ErrNetworkMismatch = errors.New("database network mismatch") |
| 18 | + |
| 19 | +// SQLChainParamQueries defines the SQL queries required by Store. |
| 20 | +type SQLChainParamQueries interface { |
| 21 | + InsertChainNetwork(ctx context.Context, network string) error |
| 22 | + GetChainNetwork(ctx context.Context) (string, error) |
| 23 | +} |
| 24 | + |
| 25 | +// BatchedChainParamQueries is a version of SQLChainParamQueries that is |
| 26 | +// capable of batched database operations. |
| 27 | +type BatchedChainParamQueries interface { |
| 28 | + SQLChainParamQueries |
| 29 | + |
| 30 | + sqldb.BatchedTx[SQLChainParamQueries] |
| 31 | +} |
| 32 | + |
| 33 | +// Store is a database-backed store that persists and retrieves chain-level |
| 34 | +// parameters such as the network the database was initialised for. |
| 35 | +type Store struct { |
| 36 | + db BatchedChainParamQueries |
| 37 | +} |
| 38 | + |
| 39 | +// NewStore creates a new chain params Store backed by the given BaseDB. |
| 40 | +func NewStore(db *sqldb.BaseDB) *Store { |
| 41 | + executor := sqldb.NewTransactionExecutor( |
| 42 | + db, func(tx *sql.Tx) SQLChainParamQueries { |
| 43 | + return db.WithTx(tx) |
| 44 | + }, |
| 45 | + ) |
| 46 | + |
| 47 | + return &Store{db: executor} |
| 48 | +} |
| 49 | + |
| 50 | +// ValidateNetwork checks that the network stored in the chain_params table |
| 51 | +// matches the provided network. On the first call the network is persisted so |
| 52 | +// that subsequent restarts can detect an accidental network switch. |
| 53 | +func (s *Store) ValidateNetwork(ctx context.Context, |
| 54 | + net *chaincfg.Params) error { |
| 55 | + |
| 56 | + network, err := normalizeNetworkName(net) |
| 57 | + if err != nil { |
| 58 | + return err |
| 59 | + } |
| 60 | + |
| 61 | + return s.db.ExecTx( |
| 62 | + ctx, sqldb.WriteTxOpt(), |
| 63 | + func(tx SQLChainParamQueries) error { |
| 64 | + // Insert the network only if the chain_params table is |
| 65 | + // still empty. This is a no-op on every startup after |
| 66 | + // the first. |
| 67 | + err := tx.InsertChainNetwork(ctx, network) |
| 68 | + if err != nil { |
| 69 | + return fmt.Errorf("unable to set network in "+ |
| 70 | + "chain_params: %w", err) |
| 71 | + } |
| 72 | + |
| 73 | + // Read back whatever is stored. This is either the |
| 74 | + // value we just inserted (first startup) or a value |
| 75 | + // from a previous run. |
| 76 | + storedNetwork, err := tx.GetChainNetwork(ctx) |
| 77 | + if err != nil { |
| 78 | + return fmt.Errorf("unable to read network "+ |
| 79 | + "from chain_params: %w", err) |
| 80 | + } |
| 81 | + |
| 82 | + if storedNetwork != network { |
| 83 | + return fmt.Errorf("%w: the database was "+ |
| 84 | + "previously used with network '%s', "+ |
| 85 | + "but lnd is now configured for "+ |
| 86 | + "network '%s'. To fix this, either "+ |
| 87 | + "point lnd at a different database "+ |
| 88 | + "or reconfigure lnd to use "+ |
| 89 | + "network '%s'", ErrNetworkMismatch, |
| 90 | + storedNetwork, network, storedNetwork) |
| 91 | + } |
| 92 | + |
| 93 | + return nil |
| 94 | + }, sqldb.NoOpReset, |
| 95 | + ) |
| 96 | +} |
| 97 | + |
| 98 | +// normalizeNetworkName returns the stable network identifier persisted in the |
| 99 | +// chain_params table. |
| 100 | +func normalizeNetworkName(net *chaincfg.Params) (string, error) { |
| 101 | + if net == nil { |
| 102 | + return "", fmt.Errorf("chain parameters must not be nil") |
| 103 | + } |
| 104 | + |
| 105 | + network := lncfg.NormalizeNetwork(net.Name) |
| 106 | + if network == "" { |
| 107 | + return "", fmt.Errorf("chain parameters must define a network") |
| 108 | + } |
| 109 | + |
| 110 | + return network, nil |
| 111 | +} |
| 112 | + |
| 113 | +// Compile-time check that *sqlc.Queries implements SQLChainParamQueries. |
| 114 | +var _ SQLChainParamQueries = (*sqlc.Queries)(nil) |
0 commit comments