Skip to content

Commit 472f245

Browse files
Validate exact expected error in signed header verification tests (#1165)
<!-- Please read and fill out this form before submitting your PR. Please make sure you have reviewed our contributors guide before submitting your first PR. --> ## Overview Closes: #1049 Stacked on top of #1162 <!-- Please provide an explanation of the PR, including the appropriate context, background, goal, and rationale. If there is an issue with this information, please provide a tl;dr and link the issue. --> ## Checklist <!-- Please complete the checklist to ensure that the PR is ready to be reviewed. IMPORTANT: PRs should be left in Draft until the below checklist is completed. --> - [x] New and updated code has appropriate documentation - [x] New and updated code has new and/or updated testing - [x] Required CI checks are passing - [ ] Visual proof for any user facing features like CLI or documentation updates - [x] Linked issues closed with keywords --------- Co-authored-by: Matthew Sevey <mjsevey@gmail.com>
1 parent bd7664c commit 472f245

3 files changed

Lines changed: 77 additions & 30 deletions

File tree

types/header.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package types
22

33
import (
44
"encoding"
5-
"errors"
65
"fmt"
76
"time"
87

@@ -85,7 +84,7 @@ func (h *Header) Verify(untrst header.Header) error {
8584
}
8685
// sanity check fields
8786
if err := verifyNewHeaderAndVals(h, untrstH); err != nil {
88-
return &header.VerifyError{Reason: err}
87+
return err
8988
}
9089

9190
// Check the validator hashes are the same in the case headers are adjacent
@@ -132,16 +131,22 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error {
132131
}
133132

134133
if !untrusted.Time().After(trusted.Time()) {
135-
return fmt.Errorf("expected new header time %v to be after old header time %v",
136-
untrusted.Time(),
137-
trusted.Time())
134+
return fmt.Errorf("%w: %w",
135+
ErrNewHeaderTimeBeforeOldHeaderTime,
136+
fmt.Errorf("expected new header time %v to be after %v",
137+
untrusted.Time(),
138+
trusted.Time(),
139+
),
140+
)
138141
}
139142

140143
if !untrusted.Time().Before(time.Now().Add(maxClockDrift)) {
141-
return fmt.Errorf("new header has a time from the future %v (now: %v; max clock drift: %v)",
144+
return fmt.Errorf("%w: new header time %v (now: %v; max clock drift: %v)",
145+
ErrNewHeaderTimeFromFuture,
142146
untrusted.Time(),
143147
time.Now(),
144-
maxClockDrift)
148+
maxClockDrift,
149+
)
145150
}
146151

147152
return nil
@@ -150,7 +155,7 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error {
150155
// ValidateBasic performs basic validation of a header.
151156
func (h *Header) ValidateBasic() error {
152157
if len(h.ProposerAddress) == 0 {
153-
return errors.New("no proposer address")
158+
return ErrNoProposerAddress
154159
}
155160

156161
return nil

types/signed_header.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ func (sH *SignedHeader) IsZero() bool {
1717
return sH == nil
1818
}
1919

20+
var (
21+
ErrAggregatorSetHashMismatch = errors.New("aggregator set hash in signed header and hash of validator set do not match")
22+
ErrSignatureVerificationFailed = errors.New("signature verification failed")
23+
ErrNoProposerAddress = errors.New("no proposer address")
24+
ErrLastHeaderHashMismatch = errors.New("last header hash mismatch")
25+
ErrLastCommitHashMismatch = errors.New("last commit hash mismatch")
26+
ErrNewHeaderTimeBeforeOldHeaderTime = errors.New("new header has time before old header time")
27+
ErrNewHeaderTimeFromFuture = errors.New("new header has time from future")
28+
)
29+
2030
func (sH *SignedHeader) Verify(untrst header.Header) error {
2131
// Explicit type checks are required due to embedded Header which also does the explicit type check
2232
untrstH, ok := untrst.(*SignedHeader)
@@ -44,13 +54,19 @@ func (sH *SignedHeader) Verify(untrst header.Header) error {
4454
sHHash := sH.Header.Hash()
4555
if !bytes.Equal(untrstH.LastHeaderHash[:], sHHash) {
4656
return &header.VerifyError{
47-
Reason: fmt.Errorf("last header hash %v does not match hash of previous header %v", untrstH.LastHeaderHash[:], sHHash),
57+
Reason: fmt.Errorf("%w: expected %v, but got %v",
58+
ErrLastHeaderHashMismatch,
59+
untrstH.LastHeaderHash[:], sHHash,
60+
),
4861
}
4962
}
5063
sHLastCommitHash := sH.Commit.GetCommitHash(&untrstH.Header, sH.ProposerAddress)
5164
if !bytes.Equal(untrstH.LastCommitHash[:], sHLastCommitHash) {
5265
return &header.VerifyError{
53-
Reason: fmt.Errorf("last commit hash %v does not match hash of previous header %v", untrstH.LastCommitHash[:], sHHash),
66+
Reason: fmt.Errorf("%w: expected %v, but got %v",
67+
ErrLastCommitHashMismatch,
68+
untrstH.LastCommitHash[:], sHHash,
69+
),
5470
}
5571
}
5672
return nil
@@ -78,7 +94,7 @@ func (h *SignedHeader) ValidateBasic() error {
7894
}
7995

8096
if !bytes.Equal(h.Validators.Hash(), h.AggregatorsHash[:]) {
81-
return errors.New("aggregator set hash in signed header and hash of validator set do not match")
97+
return ErrAggregatorSetHashMismatch
8298
}
8399

84100
// Make sure there is exactly one signature
@@ -94,7 +110,7 @@ func (h *SignedHeader) ValidateBasic() error {
94110
return errors.New("signature verification failed, unable to marshal header")
95111
}
96112
if !pubKey.VerifySignature(msg, signature) {
97-
return errors.New("signature verification failed")
113+
return ErrSignatureVerificationFailed
98114
}
99115

100116
return nil

types/signed_header_test.go

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/celestiaorg/go-header"
89
"github.com/stretchr/testify/assert"
910
"github.com/stretchr/testify/require"
1011
)
@@ -15,37 +16,46 @@ func TestVerify(t *testing.T) {
1516
time.Sleep(time.Second)
1617
untrustedAdj, err := GetNextRandomHeader(trusted, privKey)
1718
require.NoError(t, err)
19+
fakeAggregatorsHash := header.Hash(GetRandomBytes(32))
20+
fakeLastHeaderHash := header.Hash(GetRandomBytes(32))
21+
fakeLastCommitHash := header.Hash(GetRandomBytes(32))
1822
tests := []struct {
1923
prepare func() (*SignedHeader, bool)
20-
err bool
24+
err error
2125
}{
2226
{
2327
prepare: func() (*SignedHeader, bool) { return untrustedAdj, false },
24-
err: false,
28+
err: nil,
2529
},
2630
{
2731
prepare: func() (*SignedHeader, bool) {
2832
untrusted := *untrustedAdj
29-
untrusted.AggregatorsHash = GetRandomBytes(32)
30-
return &untrusted, true
33+
untrusted.AggregatorsHash = fakeAggregatorsHash
34+
return &untrusted, false
35+
},
36+
err: &header.VerifyError{
37+
Reason: ErrAggregatorSetHashMismatch,
3138
},
32-
err: true,
3339
},
3440
{
3541
prepare: func() (*SignedHeader, bool) {
3642
untrusted := *untrustedAdj
37-
untrusted.LastHeaderHash = GetRandomBytes(32)
43+
untrusted.LastHeaderHash = fakeLastHeaderHash
3844
return &untrusted, true
3945
},
40-
err: true,
46+
err: &header.VerifyError{
47+
Reason: ErrLastHeaderHashMismatch,
48+
},
4149
},
4250
{
4351
prepare: func() (*SignedHeader, bool) {
4452
untrusted := *untrustedAdj
45-
untrusted.LastCommitHash = GetRandomBytes(32)
53+
untrusted.LastCommitHash = fakeLastCommitHash
4654
return &untrusted, true
4755
},
48-
err: true,
56+
err: &header.VerifyError{
57+
Reason: ErrLastCommitHashMismatch,
58+
},
4959
},
5060
{
5161
prepare: func() (*SignedHeader, bool) {
@@ -54,47 +64,57 @@ func TestVerify(t *testing.T) {
5464
untrusted.Header.BaseHeader.Height++
5565
return &untrusted, true
5666
},
57-
err: false, // Accepts non-adjacent headers
67+
err: nil, // Accepts non-adjacent headers
5868
},
5969
{
6070
prepare: func() (*SignedHeader, bool) {
6171
untrusted := *untrustedAdj
6272
untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Truncate(time.Hour).UnixNano())
6373
return &untrusted, true
6474
},
65-
err: true,
75+
err: &header.VerifyError{
76+
Reason: ErrNewHeaderTimeBeforeOldHeaderTime,
77+
},
6678
},
6779
{
6880
prepare: func() (*SignedHeader, bool) {
6981
untrusted := *untrustedAdj
7082
untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Add(time.Minute).UnixNano())
7183
return &untrusted, true
7284
},
73-
err: true,
85+
err: &header.VerifyError{
86+
Reason: ErrNewHeaderTimeFromFuture,
87+
},
7488
},
7589
{
7690
prepare: func() (*SignedHeader, bool) {
7791
untrusted := *untrustedAdj
7892
untrusted.BaseHeader.ChainID = "toaster"
7993
return &untrusted, false // Signature verification should fail
8094
},
81-
err: true,
95+
err: &header.VerifyError{
96+
Reason: ErrSignatureVerificationFailed,
97+
},
8298
},
8399
{
84100
prepare: func() (*SignedHeader, bool) {
85101
untrusted := *untrustedAdj
86102
untrusted.Version.App = untrusted.Version.App + 1
87103
return &untrusted, false // Signature verification should fail
88104
},
89-
err: true,
105+
err: &header.VerifyError{
106+
Reason: ErrSignatureVerificationFailed,
107+
},
90108
},
91109
{
92110
prepare: func() (*SignedHeader, bool) {
93111
untrusted := *untrustedAdj
94112
untrusted.ProposerAddress = nil
95113
return &untrusted, true
96114
},
97-
err: true,
115+
err: &header.VerifyError{
116+
Reason: ErrNoProposerAddress,
117+
},
98118
},
99119
}
100120

@@ -107,11 +127,17 @@ func TestVerify(t *testing.T) {
107127
preparedHeader.Commit = *commit
108128
}
109129
err = trusted.Verify(preparedHeader)
110-
if test.err {
111-
assert.Error(t, err)
112-
} else {
130+
if test.err == nil {
113131
assert.NoError(t, err)
132+
return
133+
}
134+
if err == nil {
135+
t.Errorf("expected err: %v, got nil", test.err)
136+
return
114137
}
138+
reason := err.(*header.VerifyError).Reason
139+
testReason := test.err.(*header.VerifyError).Reason
140+
assert.ErrorIs(t, reason, testReason)
115141
})
116142
}
117143
}

0 commit comments

Comments
 (0)