Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions types/validator_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,29 +309,40 @@ func (vals *ValidatorSet) Size() int {
return len(vals.Validators)
}

// Forces recalculation of the set's total voting power.
// Panics if total voting power is bigger than MaxTotalVotingPower.
func (vals *ValidatorSet) updateTotalVotingPower() {
// updateTotalVotingPower forces recalculation of the set's total voting power.
// Returns an error if total voting power exceeds MaxTotalVotingPower.
func (vals *ValidatorSet) updateTotalVotingPower() error {
sum := int64(0)
for _, val := range vals.Validators {
// mind overflow
sum = safeAddClip(sum, val.VotingPower)
if sum > MaxTotalVotingPower {
panic(fmt.Sprintf(
"Total voting power should be guarded to not exceed %v; got: %v",
MaxTotalVotingPower,
sum))
return fmt.Errorf("total voting power %d exceeds maximum %d", sum, MaxTotalVotingPower)
}
}

vals.totalVotingPower = sum
return nil
}

// TotalVotingPowerSafe returns the sum of the voting powers of all validators,
// or an error if the total exceeds MaxTotalVotingPower.
func (vals *ValidatorSet) TotalVotingPowerSafe() (int64, error) {
if vals.totalVotingPower == 0 {
if err := vals.updateTotalVotingPower(); err != nil {
return 0, err
}
}
return vals.totalVotingPower, nil
}

// TotalVotingPower returns the sum of the voting powers of all validators.
// It recomputes the total voting power if required.
func (vals *ValidatorSet) TotalVotingPower() int64 {
if vals.totalVotingPower == 0 {
vals.updateTotalVotingPower()
if err := vals.updateTotalVotingPower(); err != nil {
panic(err)
}
}
return vals.totalVotingPower
}
Expand Down Expand Up @@ -665,7 +676,9 @@ func (vals *ValidatorSet) updateWithChangeSet(changes []*Validator, allowDeletes
// Should go after additions.
vals.checkAllKeysHaveSameType()

vals.updateTotalVotingPower() // will panic if total voting power > MaxTotalVotingPower
if err = vals.updateTotalVotingPower(); err != nil {
panic(err)
}

// Scale and center.
vals.RescalePriorities(PriorityWindowSizeFactor * vals.TotalVotingPower())
Expand Down Expand Up @@ -935,8 +948,10 @@ func ValidatorSetFromProto(vp *cmtproto.ValidatorSet) (*ValidatorSet, error) {
// power hence we need to recompute it.
// FIXME: We should look to remove TotalVotingPower from proto or add it in the validators hash
// so we don't have to do this
vals.TotalVotingPower()

// NOTE: Use TotalVotingPowerSafe to return error instead of panicking on invalid input.
if _, err := vals.TotalVotingPowerSafe(); err != nil {
return nil, err
}
return vals, vals.ValidateBasic()
}

Expand All @@ -960,7 +975,9 @@ func ValidatorSetFromExistingValidators(valz []*Validator) (*ValidatorSet, error
}
vals.checkAllKeysHaveSameType()
vals.Proposer = vals.findPreviousProposer()
vals.updateTotalVotingPower()
if err := vals.updateTotalVotingPower(); err != nil {
return nil, err
}
sort.Sort(ValidatorsByVotingPower(vals.Validators))
return vals, nil
}
Expand Down
28 changes: 24 additions & 4 deletions types/validator_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"testing"
"testing/quick"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/cometbft/cometbft/crypto"
"github.com/cometbft/cometbft/crypto/ed25519"
cryptoenc "github.com/cometbft/cometbft/crypto/encoding"
"github.com/cometbft/cometbft/crypto/sr25519"
cmtmath "github.com/cometbft/cometbft/libs/math"
cmtrand "github.com/cometbft/cometbft/libs/rand"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidatorSetBasic(t *testing.T) {
Expand Down Expand Up @@ -471,6 +471,25 @@ func TestValidatorSetTotalVotingPowerPanicsOnOverflow(t *testing.T) {
assert.Panics(t, shouldPanic)
}

func TestValidatorSetFromProtoReturnsErrorOnOverflow(t *testing.T) {
// ValidatorSetFromProto should return an error instead of panicking when total voting power exceeds MaxTotalVotingPower.
pubKey := ed25519.GenPrivKey().PubKey()
pkProto, err := cryptoenc.PubKeyToProto(pubKey)
require.NoError(t, err)

protoVals := &cmtproto.ValidatorSet{
Validators: []*cmtproto.Validator{
{Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0},
{Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0},
},
Proposer: &cmtproto.Validator{Address: pubKey.Address(), PubKey: pkProto, VotingPower: math.MaxInt64, ProposerPriority: 0},
}

_, err = ValidatorSetFromProto(protoVals)
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeds maximum")
}

func TestAvgProposerPriority(t *testing.T) {
// Create Validator set without calling IncrementProposerPriority:
tcs := []struct {
Expand Down Expand Up @@ -832,7 +851,8 @@ func verifyValidatorSet(t *testing.T, valSet *ValidatorSet) {

// verify that the set's total voting power has been updated
tvp := valSet.totalVotingPower
valSet.updateTotalVotingPower()
err := valSet.updateTotalVotingPower()
require.NoError(t, err)
expectedTvp := valSet.TotalVotingPower()
assert.Equal(t, expectedTvp, tvp,
"expected TVP %d. Got %d, valSet=%s", expectedTvp, tvp, valSet)
Expand Down
Loading