@@ -135,11 +135,61 @@ func TestSign_NilMessage(t *testing.T) {
135135
136136 signer , _ := newTestSigner (t , FormatOpenPGP , "ABC" , nil )
137137
138- sig , err := signer .Sign (nil )
138+ sig , err := signer .Sign (t . Context (), nil )
139139 require .ErrorIs (t , err , ErrNilMessage )
140140 require .Nil (t , sig )
141141}
142142
143+ // TestSign_ThreadsContext asserts Sign threads the caller's context into the
144+ // command invocation for every format, rather than building the command with a
145+ // fresh context.Background(). The proof is propagation: cancelling the context
146+ // passed to Sign is observable through the context the command was created
147+ // with, which is only possible if it is the same context.
148+ func TestSign_ThreadsContext (t * testing.T ) {
149+ t .Parallel ()
150+
151+ tests := []struct {
152+ name string
153+ format Format
154+ signingKey string
155+ }{
156+ {name : "openpgp" , format : FormatOpenPGP , signingKey : "KEYID" },
157+ {name : "x509" , format : FormatX509 , signingKey : "KEYID" },
158+ {name : "ssh" , format : FormatSSH , signingKey : "/path/to/key" },
159+ }
160+
161+ for _ , test := range tests {
162+ t .Run (test .name , func (t * testing.T ) {
163+ t .Parallel ()
164+
165+ signer , calls := newTestSigner (t , test .format , test .signingKey , func (cmd * mockCommand ) error {
166+ if test .format == FormatSSH {
167+ return writeSignatureFile (cmd .args [len (cmd .args )- 1 ])
168+ }
169+
170+ _ , err := io .WriteString (cmd .stdout , "SIG\n " )
171+ require .NoError (t , err )
172+
173+ return nil
174+ })
175+
176+ ctx , cancel := context .WithCancel (t .Context ())
177+ t .Cleanup (cancel )
178+
179+ _ , err := signer .Sign (ctx , strings .NewReader ("body" ))
180+ require .NoError (t , err )
181+
182+ require .Len (t , calls (), 1 )
183+ got := calls ()[0 ].ctx
184+ require .NotNil (t , got )
185+
186+ require .NoError (t , got .Err ())
187+ cancel ()
188+ assert .ErrorIs (t , got .Err (), context .Canceled )
189+ })
190+ }
191+ }
192+
143193func TestSign_StdioFormats (t * testing.T ) {
144194 t .Parallel ()
145195
@@ -163,7 +213,7 @@ func TestSign_StdioFormats(t *testing.T) {
163213 return nil
164214 })
165215
166- sig , err := signer .Sign (strings .NewReader ("commit body\n " ))
216+ sig , err := signer .Sign (t . Context (), strings .NewReader ("commit body\n " ))
167217 require .NoError (t , err )
168218 assert .Equal (t , "STDIO-SIG\n " , string (sig ))
169219 assert .Equal (t , "commit body\n " , stdin )
@@ -184,7 +234,7 @@ func TestSign_StdioFailure(t *testing.T) {
184234 return errStdioExit
185235 })
186236
187- sig , err := signer .Sign (strings .NewReader ("body" ))
237+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
188238 require .Error (t , err )
189239 assert .Contains (t , err .Error (), "stdio failed" )
190240 require .Nil (t , sig )
@@ -206,7 +256,7 @@ func TestSign_SSH(t *testing.T) {
206256 return writeSignatureFile (bufferFile )
207257 })
208258
209- sig , err := signer .Sign (strings .NewReader ("commit body\n " ))
259+ sig , err := signer .Sign (t . Context (), strings .NewReader ("commit body\n " ))
210260 require .NoError (t , err )
211261 assert .Equal (t , "SSH-SIG\n " , string (sig ))
212262 assert .Equal (t , "commit body\n " , buffer )
@@ -234,7 +284,7 @@ func TestSign_SSHExpandsHomePath(t *testing.T) {
234284 return writeSignatureFile (bufferFile )
235285 })
236286
237- sig , err := signer .Sign (strings .NewReader ("commit body\n " ))
287+ sig , err := signer .Sign (t . Context (), strings .NewReader ("commit body\n " ))
238288 require .NoError (t , err )
239289 assert .Equal (t , "SSH-SIG\n " , string (sig ))
240290
@@ -290,7 +340,7 @@ func TestSign_SSHLiteralKey(t *testing.T) {
290340 return writeSignatureFile (bufferFile )
291341 })
292342
293- sig , err := signer .Sign (strings .NewReader ("commit body\n " ))
343+ sig , err := signer .Sign (t . Context (), strings .NewReader ("commit body\n " ))
294344 require .NoError (t , err )
295345 assert .Equal (t , "SSH-SIG\n " , string (sig ))
296346 assert .Equal (t , "commit body\n " , buffer )
@@ -320,7 +370,7 @@ func TestSign_SSHFailure(t *testing.T) {
320370 return errSSHExit
321371 })
322372
323- sig , err := signer .Sign (strings .NewReader ("body" ))
373+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
324374 require .Error (t , err )
325375 assert .Contains (t , err .Error (), "ssh failed" )
326376 require .Nil (t , sig )
@@ -335,7 +385,7 @@ func TestSign_SSHPathPrefixedSshDash(t *testing.T) {
335385 return writeSignatureFile (bufferFile )
336386 })
337387
338- sig , err := signer .Sign (strings .NewReader ("body" ))
388+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
339389 require .NoError (t , err )
340390 require .NotNil (t , sig )
341391
@@ -359,7 +409,7 @@ func TestSign_StdioOutputTooLarge(t *testing.T) {
359409 return nil
360410 })
361411
362- sig , err := signer .Sign (strings .NewReader ("body" ))
412+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
363413 require .ErrorIs (t , err , ErrOutputLimitExceeded )
364414 assert .Contains (t , err .Error (), "stdout" )
365415 require .Nil (t , sig )
@@ -379,7 +429,7 @@ func TestSign_StderrTooLarge(t *testing.T) {
379429 return nil
380430 })
381431
382- sig , err := signer .Sign (strings .NewReader ("body" ))
432+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
383433 require .ErrorIs (t , err , ErrOutputLimitExceeded )
384434 assert .Contains (t , err .Error (), "stderr" )
385435 require .Nil (t , sig )
@@ -401,12 +451,13 @@ func TestSign_SSHSignatureTooLarge(t *testing.T) {
401451 return nil
402452 })
403453
404- sig , err := signer .Sign (strings .NewReader ("body" ))
454+ sig , err := signer .Sign (t . Context (), strings .NewReader ("body" ))
405455 require .ErrorIs (t , err , ErrSignatureTooLarge )
406456 require .Nil (t , sig )
407457}
408458
409459type mockCommand struct {
460+ ctx context.Context //nolint:containedctx // captured to assert Sign threads its context into the command.
410461 run func (* mockCommand ) error
411462 stdin io.Reader
412463 stdout io.Writer
@@ -457,8 +508,9 @@ func stubCommand(run func(*mockCommand) error) (
457508) {
458509 calls := make ([]* mockCommand , 0 , 1 )
459510
460- commandContext := func (_ context.Context , binary string , args ... string ) command {
511+ commandContext := func (ctx context.Context , binary string , args ... string ) command {
461512 cmd := & mockCommand {
513+ ctx : ctx ,
462514 run : run ,
463515 stdin : nil ,
464516 stdout : nil ,
0 commit comments