@@ -15,18 +15,21 @@ import (
1515 "github.com/stretchr/testify/require"
1616)
1717
18+ const awsTestKeyID = "arn:aws:kms:us-east-1:123456789012:key/test-key-id"
19+
1820// mockKMSClient is a test double implementing KMSClient.
1921type mockKMSClient struct {
2022 pubKeyDER []byte
2123 signFn func (ctx context.Context , params * kms.SignInput ) (* kms.SignOutput , error )
2224 getPubFn func (ctx context.Context , params * kms.GetPublicKeyInput ) (* kms.GetPublicKeyOutput , error )
25+ keyID string
2326}
2427
2528func (m * mockKMSClient ) Sign (ctx context.Context , params * kms.SignInput , _ ... func (* kms.Options )) (* kms.SignOutput , error ) {
2629 if m .signFn != nil {
2730 return m .signFn (ctx , params )
2831 }
29- return & kms.SignOutput {Signature : []byte ("mock-signature" )}, nil
32+ return & kms.SignOutput {Signature : []byte ("mock-signature" ), KeyId : & m . keyID }, nil
3033}
3134
3235func (m * mockKMSClient ) GetPublicKey (ctx context.Context , params * kms.GetPublicKeyInput , _ ... func (* kms.Options )) (* kms.GetPublicKeyOutput , error ) {
@@ -35,26 +38,15 @@ func (m *mockKMSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicK
3538 }
3639 return & kms.GetPublicKeyOutput {
3740 PublicKey : m .pubKeyDER ,
41+ KeyId : & m .keyID ,
3842 }, nil
3943}
4044
41- // generateTestEd25519DER generates an Ed25519 key pair and returns
42- // the public key in DER (X.509 SubjectPublicKeyInfo) format.
43- func generateTestEd25519DER (t * testing.T ) (ed25519.PublicKey , []byte ) {
44- t .Helper ()
45- pub , _ , err := ed25519 .GenerateKey (nil )
46- require .NoError (t , err )
47-
48- der , err := x509 .MarshalPKIXPublicKey (pub )
49- require .NoError (t , err )
50- return pub , der
51- }
52-
5345func TestNewKmsSignerFromClient_Success (t * testing.T ) {
5446 _ , der := generateTestEd25519DER (t )
5547
56- mock := & mockKMSClient {pubKeyDER : der }
57- s , err := kmsSignerFromClient (context . Background (), mock , "arn:aws:kms:us-east-1:123456789012:key/test-key-id" , nil )
48+ mock := & mockKMSClient {pubKeyDER : der , keyID : awsTestKeyID }
49+ s , err := kmsSignerFromClient (t . Context (), mock , awsTestKeyID , nil )
5850 require .NoError (t , err )
5951 require .NotNil (t , s )
6052
@@ -69,16 +61,31 @@ func TestNewKmsSignerFromClient_Success(t *testing.T) {
6961 assert .Len (t , addr , 32 ) // sha256 output
7062}
7163
72- func TestNewKmsSignerFromClient_EmptyKeyID (t * testing.T ) {
73- _ , err := kmsSignerFromClient (context .Background (), & mockKMSClient {}, "" , nil )
74- require .Error (t , err )
75- assert .Contains (t , err .Error (), "key ID is required" )
76- }
64+ func TestNewKmsSignerFromClient_Validation (t * testing.T ) {
65+ specs := map [string ]struct {
66+ client KMSClient
67+ keyID string
68+ errSubstr string
69+ }{
70+ "empty key id" : {
71+ client : & mockKMSClient {},
72+ keyID : "" ,
73+ errSubstr : "key ID is required" ,
74+ },
75+ "nil client" : {
76+ client : nil ,
77+ keyID : awsTestKeyID ,
78+ errSubstr : "client is required" ,
79+ },
80+ }
7781
78- func TestNewKmsSignerFromClient_NilClient (t * testing.T ) {
79- _ , err := kmsSignerFromClient (context .Background (), nil , "test-key" , nil )
80- require .Error (t , err )
81- assert .Contains (t , err .Error (), "client is required" )
82+ for name , spec := range specs {
83+ t .Run (name , func (t * testing.T ) {
84+ _ , err := kmsSignerFromClient (t .Context (), spec .client , spec .keyID , nil )
85+ require .Error (t , err )
86+ assert .Contains (t , err .Error (), spec .errSubstr )
87+ })
88+ }
8289}
8390
8491func TestNewKmsSignerFromClient_GetPublicKeyFails (t * testing.T ) {
@@ -88,7 +95,7 @@ func TestNewKmsSignerFromClient_GetPublicKeyFails(t *testing.T) {
8895 },
8996 }
9097
91- _ , err := kmsSignerFromClient (context . Background (), mock , "test-key" , nil )
98+ _ , err := kmsSignerFromClient (t . Context (), mock , "test-key" , nil )
9299 require .Error (t , err )
93100 assert .Contains (t , err .Error (), "access denied" )
94101}
@@ -98,90 +105,100 @@ func TestSign_Success(t *testing.T) {
98105
99106 expectedSig := []byte ("test-signature-bytes" )
100107 mock := & mockKMSClient {
108+ keyID : awsTestKeyID ,
101109 pubKeyDER : der ,
102110 signFn : func (_ context.Context , params * kms.SignInput ) (* kms.SignOutput , error ) {
111+ keyID := awsTestKeyID
103112 assert .Equal (t , types .MessageTypeRaw , params .MessageType )
104113 assert .Equal (t , types .SigningAlgorithmSpecEd25519Sha512 , params .SigningAlgorithm )
105- return & kms.SignOutput {Signature : expectedSig }, nil
114+ return & kms.SignOutput {Signature : expectedSig , KeyId : & keyID }, nil
106115 },
107116 }
108117
109- s , err := kmsSignerFromClient (context . Background (), mock , "test-key" , nil )
118+ s , err := kmsSignerFromClient (t . Context (), mock , awsTestKeyID , nil )
110119 require .NoError (t , err )
111120
112- sig , err := s .Sign (context . Background (), []byte ("hello world" ))
121+ sig , err := s .Sign (t . Context (), []byte ("hello world" ))
113122 require .NoError (t , err )
114123 assert .Equal (t , expectedSig , sig )
115124}
116125
117- func TestSign_KMSFailure (t * testing.T ) {
118- _ , der := generateTestEd25519DER (t )
119-
120- var calls int32
121- mock := & mockKMSClient {
122- pubKeyDER : der ,
123- signFn : func (_ context.Context , _ * kms.SignInput ) (* kms.SignOutput , error ) {
124- atomic .AddInt32 (& calls , 1 )
125- return nil , & smithy.GenericAPIError {Code : "ThrottlingException" , Message : "rate limit" }
126+ func TestSign_RetryBehavior (t * testing.T ) {
127+ specs := map [string ]struct {
128+ opts * Options
129+ signErr error
130+ errSubstr string
131+ expectedCall int32
132+ }{
133+ "retryable uses default retries" : {
134+ opts : nil ,
135+ signErr : & smithy.GenericAPIError {Code : "ThrottlingException" , Message : "rate limit" },
136+ errSubstr : "AWS KMS sign failed" ,
137+ expectedCall : 4 ,
126138 },
127- }
128-
129- s , err := kmsSignerFromClient (context .Background (), mock , "test-key" , nil )
130- require .NoError (t , err )
131-
132- _ , err = s .Sign (context .Background (), []byte ("hello world" ))
133- require .Error (t , err )
134- assert .Contains (t , err .Error (), "AWS KMS sign failed" )
135- assert .Equal (t , int32 (4 ), atomic .LoadInt32 (& calls ), "default retries should make 4 attempts" )
136- }
137-
138- func TestSign_MaxRetriesZero_DisablesRetries (t * testing.T ) {
139- _ , der := generateTestEd25519DER (t )
140-
141- var calls int32
142- mock := & mockKMSClient {
143- pubKeyDER : der ,
144- signFn : func (_ context.Context , _ * kms.SignInput ) (* kms.SignOutput , error ) {
145- atomic .AddInt32 (& calls , 1 )
146- return nil , & smithy.GenericAPIError {Code : "ThrottlingException" , Message : "rate limit" }
139+ "max retries zero disables retries" : {
140+ opts : & Options {MaxRetries : 0 },
141+ signErr : & smithy.GenericAPIError {Code : "ThrottlingException" , Message : "rate limit" },
142+ errSubstr : "AWS KMS sign failed" ,
143+ expectedCall : 1 ,
144+ },
145+ "non retryable fails fast" : {
146+ opts : & Options {MaxRetries : 3 },
147+ signErr : fmt .Errorf ("access denied" ),
148+ errSubstr : "non-retryable" ,
149+ expectedCall : 1 ,
147150 },
148151 }
149152
150- s , err := kmsSignerFromClient (context .Background (), mock , "test-key" , & Options {MaxRetries : 0 })
151- require .NoError (t , err )
152-
153- _ , err = s .Sign (context .Background (), []byte ("hello world" ))
154- require .Error (t , err )
155- assert .Equal (t , int32 (1 ), atomic .LoadInt32 (& calls ), "max retries 0 should only make one attempt" )
153+ for name , spec := range specs {
154+ t .Run (name , func (t * testing.T ) {
155+ _ , der := generateTestEd25519DER (t )
156+
157+ var calls int32
158+ signer := newTestSigner (t , & mockKMSClient {
159+ keyID : awsTestKeyID ,
160+ pubKeyDER : der ,
161+ signFn : func (_ context.Context , _ * kms.SignInput ) (* kms.SignOutput , error ) {
162+ atomic .AddInt32 (& calls , 1 )
163+ return nil , spec .signErr
164+ },
165+ }, spec .opts )
166+
167+ _ , err := signer .Sign (t .Context (), []byte ("hello world" ))
168+ require .Error (t , err )
169+ assert .Contains (t , err .Error (), spec .errSubstr )
170+ assert .Equal (t , spec .expectedCall , atomic .LoadInt32 (& calls ))
171+ })
172+ }
156173}
157174
158- func TestSign_NonRetryableError_NoRetries (t * testing.T ) {
175+ func TestSign_KeyIDMismatch_ReturnsError (t * testing.T ) {
159176 _ , der := generateTestEd25519DER (t )
160177
161- var calls int32
178+ unexpectedKeyID := "other-key"
162179 mock := & mockKMSClient {
180+ keyID : awsTestKeyID ,
163181 pubKeyDER : der ,
164182 signFn : func (_ context.Context , _ * kms.SignInput ) (* kms.SignOutput , error ) {
165- atomic .AddInt32 (& calls , 1 )
166- return nil , fmt .Errorf ("access denied" )
183+ return & kms.SignOutput {
184+ Signature : []byte ("test-signature-bytes" ),
185+ KeyId : & unexpectedKeyID ,
186+ }, nil
167187 },
168188 }
169189
170- s , err := kmsSignerFromClient (context .Background (), mock , "test-key" , & Options {MaxRetries : 3 })
171- require .NoError (t , err )
190+ s := newTestSigner (t , mock , nil )
172191
173- _ , err = s .Sign (context . Background (), []byte ("hello world" ))
192+ _ , err : = s .Sign (t . Context (), []byte ("hello world" ))
174193 require .Error (t , err )
175- assert .Contains (t , err .Error (), "non-retryable" )
176- assert .Equal (t , int32 (1 ), atomic .LoadInt32 (& calls ), "non-retryable errors should fail fast" )
194+ assert .Contains (t , err .Error (), "unexpected key ID" )
177195}
178196
179197func TestGetPublic_Cached (t * testing.T ) {
180198 pub , der := generateTestEd25519DER (t )
181199
182- mock := & mockKMSClient {pubKeyDER : der }
183- s , err := kmsSignerFromClient (context .Background (), mock , "test-key" , nil )
184- require .NoError (t , err )
200+ mock := & mockKMSClient {pubKeyDER : der , keyID : awsTestKeyID }
201+ s := newTestSigner (t , mock , nil )
185202
186203 cryptoPub , err := s .GetPublic ()
187204 require .NoError (t , err )
@@ -194,9 +211,8 @@ func TestGetPublic_Cached(t *testing.T) {
194211func TestGetAddress_Deterministic (t * testing.T ) {
195212 _ , der := generateTestEd25519DER (t )
196213
197- mock := & mockKMSClient {pubKeyDER : der }
198- s , err := kmsSignerFromClient (context .Background (), mock , "test-key" , nil )
199- require .NoError (t , err )
214+ mock := & mockKMSClient {pubKeyDER : der , keyID : awsTestKeyID }
215+ s := newTestSigner (t , mock , nil )
200216
201217 addr1 , err := s .GetAddress ()
202218 require .NoError (t , err )
@@ -206,3 +222,22 @@ func TestGetAddress_Deterministic(t *testing.T) {
206222
207223 assert .Equal (t , addr1 , addr2 , "address should be deterministic" )
208224}
225+
226+ // generateTestEd25519DER generates an Ed25519 key pair and returns
227+ // the public key in DER (X.509 SubjectPublicKeyInfo) format.
228+ func generateTestEd25519DER (t * testing.T ) (ed25519.PublicKey , []byte ) {
229+ t .Helper ()
230+ pub , _ , err := ed25519 .GenerateKey (nil )
231+ require .NoError (t , err )
232+
233+ der , err := x509 .MarshalPKIXPublicKey (pub )
234+ require .NoError (t , err )
235+ return pub , der
236+ }
237+
238+ func newTestSigner (t * testing.T , mock * mockKMSClient , opts * Options ) * KmsSigner {
239+ t .Helper ()
240+ s , err := kmsSignerFromClient (t .Context (), mock , awsTestKeyID , opts )
241+ require .NoError (t , err )
242+ return s
243+ }
0 commit comments