Skip to content

Commit 5317f24

Browse files
committed
Review feedback
1 parent 2a7a93f commit 5317f24

4 files changed

Lines changed: 302 additions & 237 deletions

File tree

pkg/signer/aws/signer.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ func (s *KmsSigner) fetchPublicKey(ctx context.Context) error {
116116
if err != nil {
117117
return fmt.Errorf("KMS GetPublicKey failed: %w", err)
118118
}
119-
119+
if out.KeyId == nil || *out.KeyId != s.keyID {
120+
return fmt.Errorf("KMS returned unexpected key ID: %v", out.KeyId)
121+
}
120122
// AWS returns the public key as a DER-encoded X.509 SubjectPublicKeyInfo.
121123
pub, err := x509.ParsePKIXPublicKey(out.PublicKey)
122124
if err != nil {
@@ -187,6 +189,9 @@ func (s *KmsSigner) Sign(ctx context.Context, message []byte) ([]byte, error) {
187189
}
188190
continue
189191
}
192+
if out.KeyId == nil || *out.KeyId != s.keyID {
193+
return nil, fmt.Errorf("KMS returned unexpected key ID: %v", out.KeyId)
194+
}
190195

191196
return out.Signature, nil
192197
}

pkg/signer/aws/signer_test.go

Lines changed: 113 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,21 @@ import (
1515
"github.com/stretchr/testify/require"
1616
)
1717

18+
const awsTestKeyID = "arn:aws:kms:us-east-1:123456789012:key/test-key-id"
19+
1820
// mockKMSClient is a test double implementing KMSClient.
1921
type mockKMSClient struct {
2022
pubKeyDER []byte
2123
signFn func(ctx context.Context, params *kms.SignInput) (*kms.SignOutput, error)
2224
getPubFn func(ctx context.Context, params *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
25+
keyID string
2326
}
2427

2528
func (m *mockKMSClient) Sign(ctx context.Context, params *kms.SignInput, _ ...func(*kms.Options)) (*kms.SignOutput, error) {
2629
if m.signFn != nil {
2730
return m.signFn(ctx, params)
2831
}
29-
return &kms.SignOutput{Signature: []byte("mock-signature")}, nil
32+
return &kms.SignOutput{Signature: []byte("mock-signature"), KeyId: &m.keyID}, nil
3033
}
3134

3235
func (m *mockKMSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, _ ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) {
@@ -35,26 +38,15 @@ func (m *mockKMSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicK
3538
}
3639
return &kms.GetPublicKeyOutput{
3740
PublicKey: m.pubKeyDER,
41+
KeyId: &m.keyID,
3842
}, nil
3943
}
4044

41-
// generateTestEd25519DER generates an Ed25519 key pair and returns
42-
// the public key in DER (X.509 SubjectPublicKeyInfo) format.
43-
func generateTestEd25519DER(t *testing.T) (ed25519.PublicKey, []byte) {
44-
t.Helper()
45-
pub, _, err := ed25519.GenerateKey(nil)
46-
require.NoError(t, err)
47-
48-
der, err := x509.MarshalPKIXPublicKey(pub)
49-
require.NoError(t, err)
50-
return pub, der
51-
}
52-
5345
func TestNewKmsSignerFromClient_Success(t *testing.T) {
5446
_, der := generateTestEd25519DER(t)
5547

56-
mock := &mockKMSClient{pubKeyDER: der}
57-
s, err := kmsSignerFromClient(context.Background(), mock, "arn:aws:kms:us-east-1:123456789012:key/test-key-id", nil)
48+
mock := &mockKMSClient{pubKeyDER: der, keyID: awsTestKeyID}
49+
s, err := kmsSignerFromClient(t.Context(), mock, awsTestKeyID, nil)
5850
require.NoError(t, err)
5951
require.NotNil(t, s)
6052

@@ -69,16 +61,31 @@ func TestNewKmsSignerFromClient_Success(t *testing.T) {
6961
assert.Len(t, addr, 32) // sha256 output
7062
}
7163

72-
func TestNewKmsSignerFromClient_EmptyKeyID(t *testing.T) {
73-
_, err := kmsSignerFromClient(context.Background(), &mockKMSClient{}, "", nil)
74-
require.Error(t, err)
75-
assert.Contains(t, err.Error(), "key ID is required")
76-
}
64+
func TestNewKmsSignerFromClient_Validation(t *testing.T) {
65+
specs := map[string]struct {
66+
client KMSClient
67+
keyID string
68+
errSubstr string
69+
}{
70+
"empty key id": {
71+
client: &mockKMSClient{},
72+
keyID: "",
73+
errSubstr: "key ID is required",
74+
},
75+
"nil client": {
76+
client: nil,
77+
keyID: awsTestKeyID,
78+
errSubstr: "client is required",
79+
},
80+
}
7781

78-
func TestNewKmsSignerFromClient_NilClient(t *testing.T) {
79-
_, err := kmsSignerFromClient(context.Background(), nil, "test-key", nil)
80-
require.Error(t, err)
81-
assert.Contains(t, err.Error(), "client is required")
82+
for name, spec := range specs {
83+
t.Run(name, func(t *testing.T) {
84+
_, err := kmsSignerFromClient(t.Context(), spec.client, spec.keyID, nil)
85+
require.Error(t, err)
86+
assert.Contains(t, err.Error(), spec.errSubstr)
87+
})
88+
}
8289
}
8390

8491
func TestNewKmsSignerFromClient_GetPublicKeyFails(t *testing.T) {
@@ -88,7 +95,7 @@ func TestNewKmsSignerFromClient_GetPublicKeyFails(t *testing.T) {
8895
},
8996
}
9097

91-
_, err := kmsSignerFromClient(context.Background(), mock, "test-key", nil)
98+
_, err := kmsSignerFromClient(t.Context(), mock, "test-key", nil)
9299
require.Error(t, err)
93100
assert.Contains(t, err.Error(), "access denied")
94101
}
@@ -98,90 +105,100 @@ func TestSign_Success(t *testing.T) {
98105

99106
expectedSig := []byte("test-signature-bytes")
100107
mock := &mockKMSClient{
108+
keyID: awsTestKeyID,
101109
pubKeyDER: der,
102110
signFn: func(_ context.Context, params *kms.SignInput) (*kms.SignOutput, error) {
111+
keyID := awsTestKeyID
103112
assert.Equal(t, types.MessageTypeRaw, params.MessageType)
104113
assert.Equal(t, types.SigningAlgorithmSpecEd25519Sha512, params.SigningAlgorithm)
105-
return &kms.SignOutput{Signature: expectedSig}, nil
114+
return &kms.SignOutput{Signature: expectedSig, KeyId: &keyID}, nil
106115
},
107116
}
108117

109-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", nil)
118+
s, err := kmsSignerFromClient(t.Context(), mock, awsTestKeyID, nil)
110119
require.NoError(t, err)
111120

112-
sig, err := s.Sign(context.Background(), []byte("hello world"))
121+
sig, err := s.Sign(t.Context(), []byte("hello world"))
113122
require.NoError(t, err)
114123
assert.Equal(t, expectedSig, sig)
115124
}
116125

117-
func TestSign_KMSFailure(t *testing.T) {
118-
_, der := generateTestEd25519DER(t)
119-
120-
var calls int32
121-
mock := &mockKMSClient{
122-
pubKeyDER: der,
123-
signFn: func(_ context.Context, _ *kms.SignInput) (*kms.SignOutput, error) {
124-
atomic.AddInt32(&calls, 1)
125-
return nil, &smithy.GenericAPIError{Code: "ThrottlingException", Message: "rate limit"}
126+
func TestSign_RetryBehavior(t *testing.T) {
127+
specs := map[string]struct {
128+
opts *Options
129+
signErr error
130+
errSubstr string
131+
expectedCall int32
132+
}{
133+
"retryable uses default retries": {
134+
opts: nil,
135+
signErr: &smithy.GenericAPIError{Code: "ThrottlingException", Message: "rate limit"},
136+
errSubstr: "AWS KMS sign failed",
137+
expectedCall: 4,
126138
},
127-
}
128-
129-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", nil)
130-
require.NoError(t, err)
131-
132-
_, err = s.Sign(context.Background(), []byte("hello world"))
133-
require.Error(t, err)
134-
assert.Contains(t, err.Error(), "AWS KMS sign failed")
135-
assert.Equal(t, int32(4), atomic.LoadInt32(&calls), "default retries should make 4 attempts")
136-
}
137-
138-
func TestSign_MaxRetriesZero_DisablesRetries(t *testing.T) {
139-
_, der := generateTestEd25519DER(t)
140-
141-
var calls int32
142-
mock := &mockKMSClient{
143-
pubKeyDER: der,
144-
signFn: func(_ context.Context, _ *kms.SignInput) (*kms.SignOutput, error) {
145-
atomic.AddInt32(&calls, 1)
146-
return nil, &smithy.GenericAPIError{Code: "ThrottlingException", Message: "rate limit"}
139+
"max retries zero disables retries": {
140+
opts: &Options{MaxRetries: 0},
141+
signErr: &smithy.GenericAPIError{Code: "ThrottlingException", Message: "rate limit"},
142+
errSubstr: "AWS KMS sign failed",
143+
expectedCall: 1,
144+
},
145+
"non retryable fails fast": {
146+
opts: &Options{MaxRetries: 3},
147+
signErr: fmt.Errorf("access denied"),
148+
errSubstr: "non-retryable",
149+
expectedCall: 1,
147150
},
148151
}
149152

150-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", &Options{MaxRetries: 0})
151-
require.NoError(t, err)
152-
153-
_, err = s.Sign(context.Background(), []byte("hello world"))
154-
require.Error(t, err)
155-
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "max retries 0 should only make one attempt")
153+
for name, spec := range specs {
154+
t.Run(name, func(t *testing.T) {
155+
_, der := generateTestEd25519DER(t)
156+
157+
var calls int32
158+
signer := newTestSigner(t, &mockKMSClient{
159+
keyID: awsTestKeyID,
160+
pubKeyDER: der,
161+
signFn: func(_ context.Context, _ *kms.SignInput) (*kms.SignOutput, error) {
162+
atomic.AddInt32(&calls, 1)
163+
return nil, spec.signErr
164+
},
165+
}, spec.opts)
166+
167+
_, err := signer.Sign(t.Context(), []byte("hello world"))
168+
require.Error(t, err)
169+
assert.Contains(t, err.Error(), spec.errSubstr)
170+
assert.Equal(t, spec.expectedCall, atomic.LoadInt32(&calls))
171+
})
172+
}
156173
}
157174

158-
func TestSign_NonRetryableError_NoRetries(t *testing.T) {
175+
func TestSign_KeyIDMismatch_ReturnsError(t *testing.T) {
159176
_, der := generateTestEd25519DER(t)
160177

161-
var calls int32
178+
unexpectedKeyID := "other-key"
162179
mock := &mockKMSClient{
180+
keyID: awsTestKeyID,
163181
pubKeyDER: der,
164182
signFn: func(_ context.Context, _ *kms.SignInput) (*kms.SignOutput, error) {
165-
atomic.AddInt32(&calls, 1)
166-
return nil, fmt.Errorf("access denied")
183+
return &kms.SignOutput{
184+
Signature: []byte("test-signature-bytes"),
185+
KeyId: &unexpectedKeyID,
186+
}, nil
167187
},
168188
}
169189

170-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", &Options{MaxRetries: 3})
171-
require.NoError(t, err)
190+
s := newTestSigner(t, mock, nil)
172191

173-
_, err = s.Sign(context.Background(), []byte("hello world"))
192+
_, err := s.Sign(t.Context(), []byte("hello world"))
174193
require.Error(t, err)
175-
assert.Contains(t, err.Error(), "non-retryable")
176-
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "non-retryable errors should fail fast")
194+
assert.Contains(t, err.Error(), "unexpected key ID")
177195
}
178196

179197
func TestGetPublic_Cached(t *testing.T) {
180198
pub, der := generateTestEd25519DER(t)
181199

182-
mock := &mockKMSClient{pubKeyDER: der}
183-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", nil)
184-
require.NoError(t, err)
200+
mock := &mockKMSClient{pubKeyDER: der, keyID: awsTestKeyID}
201+
s := newTestSigner(t, mock, nil)
185202

186203
cryptoPub, err := s.GetPublic()
187204
require.NoError(t, err)
@@ -194,9 +211,8 @@ func TestGetPublic_Cached(t *testing.T) {
194211
func TestGetAddress_Deterministic(t *testing.T) {
195212
_, der := generateTestEd25519DER(t)
196213

197-
mock := &mockKMSClient{pubKeyDER: der}
198-
s, err := kmsSignerFromClient(context.Background(), mock, "test-key", nil)
199-
require.NoError(t, err)
214+
mock := &mockKMSClient{pubKeyDER: der, keyID: awsTestKeyID}
215+
s := newTestSigner(t, mock, nil)
200216

201217
addr1, err := s.GetAddress()
202218
require.NoError(t, err)
@@ -206,3 +222,22 @@ func TestGetAddress_Deterministic(t *testing.T) {
206222

207223
assert.Equal(t, addr1, addr2, "address should be deterministic")
208224
}
225+
226+
// generateTestEd25519DER generates an Ed25519 key pair and returns
227+
// the public key in DER (X.509 SubjectPublicKeyInfo) format.
228+
func generateTestEd25519DER(t *testing.T) (ed25519.PublicKey, []byte) {
229+
t.Helper()
230+
pub, _, err := ed25519.GenerateKey(nil)
231+
require.NoError(t, err)
232+
233+
der, err := x509.MarshalPKIXPublicKey(pub)
234+
require.NoError(t, err)
235+
return pub, der
236+
}
237+
238+
func newTestSigner(t *testing.T, mock *mockKMSClient, opts *Options) *KmsSigner {
239+
t.Helper()
240+
s, err := kmsSignerFromClient(t.Context(), mock, awsTestKeyID, opts)
241+
require.NoError(t, err)
242+
return s
243+
}

pkg/signer/gcp/signer.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ func (s *KmsSigner) fetchPublicKey(ctx context.Context) error {
136136
return fmt.Errorf("KMS GetPublicKey failed: %w", err)
137137
}
138138

139+
if out.GetName() != s.keyName {
140+
return fmt.Errorf("KMS GetPublicKey integrity check failed: unexpected key name %q", out.GetName())
141+
}
142+
if out.GetPemCrc32C() == nil {
143+
return fmt.Errorf("KMS GetPublicKey integrity check failed: pem_crc32c is missing")
144+
}
145+
if got, want := out.GetPemCrc32C().GetValue(), int64(crc32.Checksum([]byte(out.GetPem()), castagnoliTable)); got != want {
146+
return fmt.Errorf("KMS GetPublicKey integrity check failed: pem_crc32c mismatch")
147+
}
139148
block, _ := pem.Decode([]byte(out.GetPem()))
140149
if block == nil {
141150
return fmt.Errorf("failed to decode PEM public key")
@@ -211,7 +220,7 @@ func (s *KmsSigner) Sign(ctx context.Context, message []byte) ([]byte, error) {
211220
continue
212221
}
213222

214-
if err := verifySignResponse(out); err != nil {
223+
if err := verifySignResponse(out, s.keyName); err != nil {
215224
lastErr = err
216225
continue
217226
}
@@ -242,11 +251,13 @@ func retryBackoff(attempt int) time.Duration {
242251
return backoff
243252
}
244253

245-
func verifySignResponse(out *kmspb.AsymmetricSignResponse) error {
254+
func verifySignResponse(out *kmspb.AsymmetricSignResponse, expectedName string) error {
246255
if !out.GetVerifiedDataCrc32C() {
247256
return fmt.Errorf("KMS Sign integrity check failed: verified_data_crc32c is false")
248257
}
249-
258+
if out.GetName() != expectedName {
259+
return fmt.Errorf("KMS Sign integrity check failed: unexpected key name %q", out.GetName())
260+
}
250261
signatureCRC32C := out.GetSignatureCrc32C()
251262
if signatureCRC32C == nil {
252263
return fmt.Errorf("KMS Sign integrity check failed: signature_crc32c is missing")

0 commit comments

Comments
 (0)