diff --git a/types/validator_set.go b/types/validator_set.go index 62b53dc8945..8856324b864 100644 --- a/types/validator_set.go +++ b/types/validator_set.go @@ -309,29 +309,40 @@ func (vals *ValidatorSet) Size() int { return len(vals.Validators) } -// Forces recalculation of the set's total voting power. -// Panics if total voting power is bigger than MaxTotalVotingPower. -func (vals *ValidatorSet) updateTotalVotingPower() { +// updateTotalVotingPower forces recalculation of the set's total voting power. +// Returns an error if total voting power exceeds MaxTotalVotingPower. +func (vals *ValidatorSet) updateTotalVotingPower() error { sum := int64(0) for _, val := range vals.Validators { // mind overflow sum = safeAddClip(sum, val.VotingPower) if sum > MaxTotalVotingPower { - panic(fmt.Sprintf( - "Total voting power should be guarded to not exceed %v; got: %v", - MaxTotalVotingPower, - sum)) + return fmt.Errorf("total voting power %d exceeds maximum %d", sum, MaxTotalVotingPower) } } vals.totalVotingPower = sum + return nil +} + +// TotalVotingPowerSafe returns the sum of the voting powers of all validators, +// or an error if the total exceeds MaxTotalVotingPower. +func (vals *ValidatorSet) TotalVotingPowerSafe() (int64, error) { + if vals.totalVotingPower == 0 { + if err := vals.updateTotalVotingPower(); err != nil { + return 0, err + } + } + return vals.totalVotingPower, nil } // TotalVotingPower returns the sum of the voting powers of all validators. // It recomputes the total voting power if required. func (vals *ValidatorSet) TotalVotingPower() int64 { if vals.totalVotingPower == 0 { - vals.updateTotalVotingPower() + if err := vals.updateTotalVotingPower(); err != nil { + panic(err) + } } return vals.totalVotingPower } @@ -665,7 +676,9 @@ func (vals *ValidatorSet) updateWithChangeSet(changes []*Validator, allowDeletes // Should go after additions. vals.checkAllKeysHaveSameType() - vals.updateTotalVotingPower() // will panic if total voting power > MaxTotalVotingPower + if err = vals.updateTotalVotingPower(); err != nil { + panic(err) + } // Scale and center. vals.RescalePriorities(PriorityWindowSizeFactor * vals.TotalVotingPower()) @@ -935,8 +948,10 @@ func ValidatorSetFromProto(vp *cmtproto.ValidatorSet) (*ValidatorSet, error) { // power hence we need to recompute it. // FIXME: We should look to remove TotalVotingPower from proto or add it in the validators hash // so we don't have to do this - vals.TotalVotingPower() - + // NOTE: Use TotalVotingPowerSafe to return error instead of panicking on invalid input. + if _, err := vals.TotalVotingPowerSafe(); err != nil { + return nil, err + } return vals, vals.ValidateBasic() } @@ -960,7 +975,9 @@ func ValidatorSetFromExistingValidators(valz []*Validator) (*ValidatorSet, error } vals.checkAllKeysHaveSameType() vals.Proposer = vals.findPreviousProposer() - vals.updateTotalVotingPower() + if err := vals.updateTotalVotingPower(); err != nil { + return nil, err + } sort.Sort(ValidatorsByVotingPower(vals.Validators)) return vals, nil } diff --git a/types/validator_set_test.go b/types/validator_set_test.go index 2a281f12198..e931a147b22 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -9,15 +9,15 @@ import ( "testing" "testing/quick" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/cometbft/cometbft/crypto" "github.com/cometbft/cometbft/crypto/ed25519" + cryptoenc "github.com/cometbft/cometbft/crypto/encoding" "github.com/cometbft/cometbft/crypto/sr25519" cmtmath "github.com/cometbft/cometbft/libs/math" cmtrand "github.com/cometbft/cometbft/libs/rand" cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestValidatorSetBasic(t *testing.T) { @@ -471,6 +471,25 @@ func TestValidatorSetTotalVotingPowerPanicsOnOverflow(t *testing.T) { assert.Panics(t, shouldPanic) } +func TestValidatorSetFromProtoReturnsErrorOnOverflow(t *testing.T) { + // ValidatorSetFromProto should return an error instead of panicking when total voting power exceeds MaxTotalVotingPower. + pubKey := ed25519.GenPrivKey().PubKey() + pkProto, err := cryptoenc.PubKeyToProto(pubKey) + require.NoError(t, err) + + protoVals := &cmtproto.ValidatorSet{ + Validators: []*cmtproto.Validator{ + {Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0}, + {Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0}, + }, + Proposer: &cmtproto.Validator{Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0}, + } + + _, err = ValidatorSetFromProto(protoVals) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") +} + func TestAvgProposerPriority(t *testing.T) { // Create Validator set without calling IncrementProposerPriority: tcs := []struct { @@ -832,7 +851,8 @@ func verifyValidatorSet(t *testing.T, valSet *ValidatorSet) { // verify that the set's total voting power has been updated tvp := valSet.totalVotingPower - valSet.updateTotalVotingPower() + err := valSet.updateTotalVotingPower() + require.NoError(t, err) expectedTvp := valSet.TotalVotingPower() assert.Equal(t, expectedTvp, tvp, "expected TVP %d. Got %d, valSet=%s", expectedTvp, tvp, valSet)