Skip to content

Commit 0a2f486

Browse files
authored
Merge pull request #10684 from ziggie1984/postgres-network-separation
sqldb: add network-mismatch safeguard for native-SQL backends
2 parents 9f77b52 + e744e19 commit 0a2f486

16 files changed

Lines changed: 450 additions & 2 deletions

chainparams/store.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package chainparams
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
9+
"github.com/btcsuite/btcd/chaincfg"
10+
"github.com/lightningnetwork/lnd/lncfg"
11+
"github.com/lightningnetwork/lnd/sqldb"
12+
"github.com/lightningnetwork/lnd/sqldb/sqlc"
13+
)
14+
15+
// ErrNetworkMismatch is returned by ValidateNetwork when the network stored in
16+
// the database does not match the network lnd is configured to use.
17+
var ErrNetworkMismatch = errors.New("database network mismatch")
18+
19+
// SQLChainParamQueries defines the SQL queries required by Store.
20+
type SQLChainParamQueries interface {
21+
InsertChainNetwork(ctx context.Context, network string) error
22+
GetChainNetwork(ctx context.Context) (string, error)
23+
}
24+
25+
// BatchedChainParamQueries is a version of SQLChainParamQueries that is
26+
// capable of batched database operations.
27+
type BatchedChainParamQueries interface {
28+
SQLChainParamQueries
29+
30+
sqldb.BatchedTx[SQLChainParamQueries]
31+
}
32+
33+
// Store is a database-backed store that persists and retrieves chain-level
34+
// parameters such as the network the database was initialised for.
35+
type Store struct {
36+
db BatchedChainParamQueries
37+
}
38+
39+
// NewStore creates a new chain params Store backed by the given BaseDB.
40+
func NewStore(db *sqldb.BaseDB) *Store {
41+
executor := sqldb.NewTransactionExecutor(
42+
db, func(tx *sql.Tx) SQLChainParamQueries {
43+
return db.WithTx(tx)
44+
},
45+
)
46+
47+
return &Store{db: executor}
48+
}
49+
50+
// ValidateNetwork checks that the network stored in the chain_params table
51+
// matches the provided network. On the first call the network is persisted so
52+
// that subsequent restarts can detect an accidental network switch.
53+
func (s *Store) ValidateNetwork(ctx context.Context,
54+
net *chaincfg.Params) error {
55+
56+
network, err := normalizeNetworkName(net)
57+
if err != nil {
58+
return err
59+
}
60+
61+
return s.db.ExecTx(
62+
ctx, sqldb.WriteTxOpt(),
63+
func(tx SQLChainParamQueries) error {
64+
// Insert the network only if the chain_params table is
65+
// still empty. This is a no-op on every startup after
66+
// the first.
67+
err := tx.InsertChainNetwork(ctx, network)
68+
if err != nil {
69+
return fmt.Errorf("unable to set network in "+
70+
"chain_params: %w", err)
71+
}
72+
73+
// Read back whatever is stored. This is either the
74+
// value we just inserted (first startup) or a value
75+
// from a previous run.
76+
storedNetwork, err := tx.GetChainNetwork(ctx)
77+
if err != nil {
78+
return fmt.Errorf("unable to read network "+
79+
"from chain_params: %w", err)
80+
}
81+
82+
if storedNetwork != network {
83+
return fmt.Errorf("%w: the database was "+
84+
"previously used with network '%s', "+
85+
"but lnd is now configured for "+
86+
"network '%s'. To fix this, either "+
87+
"point lnd at a different database "+
88+
"or reconfigure lnd to use "+
89+
"network '%s'", ErrNetworkMismatch,
90+
storedNetwork, network, storedNetwork)
91+
}
92+
93+
return nil
94+
}, sqldb.NoOpReset,
95+
)
96+
}
97+
98+
// normalizeNetworkName returns the stable network identifier persisted in the
99+
// chain_params table.
100+
func normalizeNetworkName(net *chaincfg.Params) (string, error) {
101+
if net == nil {
102+
return "", fmt.Errorf("chain parameters must not be nil")
103+
}
104+
105+
network := lncfg.NormalizeNetwork(net.Name)
106+
if network == "" {
107+
return "", fmt.Errorf("chain parameters must define a network")
108+
}
109+
110+
return network, nil
111+
}
112+
113+
// Compile-time check that *sqlc.Queries implements SQLChainParamQueries.
114+
var _ SQLChainParamQueries = (*sqlc.Queries)(nil)

chainparams/store_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//go:build test_db_postgres || test_db_sqlite
2+
3+
package chainparams
4+
5+
import (
6+
"testing"
7+
8+
"github.com/btcsuite/btcd/chaincfg"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
// TestValidateNetworkMismatch verifies that ValidateNetwork persists the first
13+
// network and fails when a different network is used later.
14+
func TestValidateNetworkMismatch(t *testing.T) {
15+
t.Parallel()
16+
17+
store := NewStore(newTestDB(t))
18+
19+
// First call: persists regtest into the store.
20+
err := store.ValidateNetwork(t.Context(), &chaincfg.RegressionNetParams)
21+
require.NoError(t, err)
22+
23+
// Second call: a different network — must fail with ErrNetworkMismatch.
24+
err = store.ValidateNetwork(t.Context(), &chaincfg.SimNetParams)
25+
require.ErrorIs(t, err, ErrNetworkMismatch)
26+
}
27+
28+
// TestValidateNetworkSameNetwork verifies that ValidateNetwork succeeds when
29+
// called repeatedly with the same network (idempotent-match path).
30+
func TestValidateNetworkSameNetwork(t *testing.T) {
31+
t.Parallel()
32+
33+
store := NewStore(newTestDB(t))
34+
35+
// First call: persists the network.
36+
err := store.ValidateNetwork(t.Context(), &chaincfg.RegressionNetParams)
37+
require.NoError(t, err)
38+
39+
// Second call: same network again — reads the stored value and must
40+
// succeed (idempotent match).
41+
err = store.ValidateNetwork(t.Context(), &chaincfg.RegressionNetParams)
42+
require.NoError(t, err)
43+
}
44+
45+
// TestValidateNetworkNormalizesTestnet verifies that network aliases collapse
46+
// to the same persisted value.
47+
func TestValidateNetworkNormalizesTestnet(t *testing.T) {
48+
t.Parallel()
49+
50+
store := NewStore(newTestDB(t))
51+
52+
// First call: persists canonical testnet3 parameters.
53+
err := store.ValidateNetwork(t.Context(), &chaincfg.TestNet3Params)
54+
require.NoError(t, err)
55+
56+
// Second call: same logical network, different Name field — still
57+
// matches after normalization (not ErrNetworkMismatch).
58+
testnetAlias := chaincfg.TestNet3Params
59+
testnetAlias.Name = "testnet"
60+
61+
err = store.ValidateNetwork(t.Context(), &testnetAlias)
62+
require.NoError(t, err)
63+
}
64+
65+
// TestValidateNetworkRejectsEmptyName verifies that malformed network params
66+
// are rejected before touching the database.
67+
func TestValidateNetworkRejectsEmptyName(t *testing.T) {
68+
t.Parallel()
69+
70+
store := NewStore(newTestDB(t))
71+
72+
// Empty Params.Name — rejected before any database read or write.
73+
err := store.ValidateNetwork(t.Context(), &chaincfg.Params{})
74+
require.ErrorContains(t, err, "must define a network")
75+
}
76+
77+
// TestValidateNetworkRejectsNilParams verifies that callers provide network
78+
// parameters.
79+
func TestValidateNetworkRejectsNilParams(t *testing.T) {
80+
t.Parallel()
81+
82+
store := NewStore(newTestDB(t))
83+
84+
// Nil params — rejected before any database read or write.
85+
err := store.ValidateNetwork(t.Context(), nil)
86+
require.ErrorContains(t, err, "must not be nil")
87+
}

chainparams/test_postgres.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//go:build test_db_postgres && !test_db_sqlite
2+
3+
package chainparams
4+
5+
import (
6+
"testing"
7+
8+
"github.com/lightningnetwork/lnd/sqldb"
9+
)
10+
11+
// newTestDB creates a Postgres-backed BaseDB for use in unit tests.
12+
func newTestDB(t testing.TB) *sqldb.BaseDB {
13+
pgFixture := sqldb.NewTestPgFixture(
14+
t, sqldb.DefaultPostgresFixtureLifetime,
15+
)
16+
t.Cleanup(func() {
17+
pgFixture.TearDown(t)
18+
})
19+
20+
return sqldb.NewTestPostgresDB(t, pgFixture).GetBaseDB()
21+
}

chainparams/test_sqlite.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//go:build !test_db_postgres && test_db_sqlite
2+
3+
package chainparams
4+
5+
import (
6+
"testing"
7+
8+
"github.com/lightningnetwork/lnd/sqldb"
9+
)
10+
11+
// newTestDB creates a SQLite-backed BaseDB for use in unit tests.
12+
func newTestDB(t testing.TB) *sqldb.BaseDB {
13+
return sqldb.NewTestSqliteDB(t).GetBaseDB()
14+
}

config_builder.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/lightninglabs/neutrino/pushtx"
3131
"github.com/lightningnetwork/lnd/blockcache"
3232
"github.com/lightningnetwork/lnd/chainntnfs"
33+
"github.com/lightningnetwork/lnd/chainparams"
3334
"github.com/lightningnetwork/lnd/chainreg"
3435
"github.com/lightningnetwork/lnd/channeldb"
3536
"github.com/lightningnetwork/lnd/clock"
@@ -1237,8 +1238,48 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
12371238

12381239
// With the DB ready and migrations applied, we can now create
12391240
// the base DB and transaction executor for the native SQL
1240-
// invoice store.
1241+
// stores.
12411242
baseDB := dbs.NativeSQLStore.GetBaseDB()
1243+
1244+
// Validate that the database was initialised for the same
1245+
// network as the currently active network. This catches cases
1246+
// where a user accidentally reuses a database (e.g. via a
1247+
// postgres DSN or by copying a file) across different networks
1248+
// (e.g. mainnet → testnet), which would otherwise lead to
1249+
// silent data corruption. This check applies to all native SQL
1250+
// backends.
1251+
//
1252+
// If migrations are explicitly skipped, we also skip this check
1253+
// because the chain_params table may not exist yet. We check
1254+
// only the active backend's flag since only one backend is
1255+
// used at a time.
1256+
var skipMigrations bool
1257+
switch d.cfg.DB.Backend {
1258+
case lncfg.SqliteBackend:
1259+
skipMigrations = d.cfg.DB.Sqlite.SkipMigrations
1260+
case lncfg.PostgresBackend:
1261+
skipMigrations = d.cfg.DB.Postgres.SkipMigrations
1262+
}
1263+
1264+
if !skipMigrations {
1265+
chainParamsStore := chainparams.NewStore(baseDB)
1266+
err = chainParamsStore.ValidateNetwork(
1267+
ctx, d.cfg.ActiveNetParams.Params,
1268+
)
1269+
if err != nil {
1270+
cleanUp()
1271+
d.logger.Error(err)
1272+
1273+
return nil, nil, err
1274+
}
1275+
} else {
1276+
d.logger.Warnf("Database network validation skipped " +
1277+
"because SkipMigrations is enabled; " +
1278+
"cross-network database reuse would not be " +
1279+
"detected.")
1280+
}
1281+
1282+
// Create the invoice store.
12421283
invoiceExecutor := sqldb.NewTransactionExecutor(
12431284
baseDB, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
12441285
return baseDB.WithTx(tx)
@@ -1251,6 +1292,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
12511292

12521293
dbs.InvoiceDB = sqlInvoiceDB
12531294

1295+
// Create the graph store.
12541296
graphExecutor := sqldb.NewTransactionExecutor(
12551297
baseDB, func(tx *sql.Tx) graphdb.SQLQueries {
12561298
return baseDB.WithTx(tx)
@@ -1272,6 +1314,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
12721314
return nil, nil, err
12731315
}
12741316

1317+
// Create the payments store.
12751318
paymentsExecutor := sqldb.NewTransactionExecutor(
12761319
baseDB, func(tx *sql.Tx) paymentsdb.SQLQueries {
12771320
return baseDB.WithTx(tx)

docs/release-notes/release-notes-0.21.0.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,15 @@
307307

308308
## Database
309309

310-
* Freeze the [graph SQL migration
310+
* [Prevent silent data corruption](https://github.com/lightningnetwork/lnd/pull/10684)
311+
when reusing the same database across different Bitcoin networks. On first
312+
startup the active network is persisted in a new `chain_params` table; on
313+
every subsequent restart lnd compares the stored value against the configured
314+
network and refuses to start if they differ, printing a clear error message
315+
with remediation steps. This safeguard applies to both the PostgreSQL and
316+
SQLite native-SQL backends when running with `--db.use-native-sql`.
317+
318+
* Freeze the [graph SQL migration
311319
code](https://github.com/lightningnetwork/lnd/pull/10338) to prevent the
312320
need for maintenance as the sqlc code evolves.
313321
* Prepare the graph DB for handling gossip V2

itest/list_on_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,10 @@ var allTestCases = []*lntest.TestCase{
791791
Name: "estimate on chain fee auto selected inputs",
792792
TestFunc: testEstimateOnChainFeeAutoSelectedInputs,
793793
},
794+
{
795+
Name: "postgres network separation",
796+
TestFunc: testPostgresNetworkSeparation,
797+
},
794798
}
795799

796800
// appendPrefixed is used to add a prefix to each test name in the subtests

0 commit comments

Comments
 (0)