@@ -12,6 +12,7 @@ import (
1212 "time"
1313
1414 "github.com/creativeprojects/clog"
15+ "github.com/stretchr/testify/assert"
1516 "github.com/stretchr/testify/require"
1617)
1718
@@ -103,6 +104,26 @@ func TestSSHClient(t *testing.T) {
103104 },
104105 connectErr : false ,
105106 },
107+ {
108+ name : "successful connection using any of the provided key" ,
109+ config : Config {
110+ Host : "localhost" ,
111+ Port : 2222 ,
112+ Username : "resticprofile" ,
113+ KnownHostsPath : filepath .Join (tmpDir , "known_hosts" ),
114+ PrivateKeyPaths : []string {
115+ filepath .Join (tmpDir , "file-not-found" ), // Next key should be used
116+ filepath .Join (tmpDir , "id_ed25519" ),
117+ filepath .Join (tmpDir , "id_ecdsa" ),
118+ filepath .Join (tmpDir , "id_rsa" ),
119+ },
120+ Handler : http .HandlerFunc (func (resp http.ResponseWriter , req * http.Request ) {
121+ resp .Write ([]byte ("Connection successful any of the provided key\n " ))
122+ }),
123+ ConnectTimeout : 10 * time .Second ,
124+ },
125+ connectErr : false ,
126+ },
106127 }
107128
108129 for _ , fixture := range fixtures {
@@ -127,3 +148,83 @@ func TestSSHClient(t *testing.T) {
127148 }
128149 }
129150}
151+
152+ func TestSSHClientRunCommandWithCancelledContext (t * testing.T ) {
153+ clog .SetTestLog (t )
154+ defer clog .CloseTestLog ()
155+
156+ tmpDir := os .Getenv ("SSH_TESTS_TMPDIR" )
157+ if tmpDir == "" {
158+ tmpDir = filepath .Join (os .TempDir (), "resticprofile-ssh-tests" )
159+ }
160+
161+ config := Config {
162+ Host : "localhost" ,
163+ Port : 2222 ,
164+ Username : "resticprofile" ,
165+ KnownHostsPath : filepath .Join (tmpDir , "known_hosts" ),
166+ PrivateKeyPaths : []string {
167+ filepath .Join (tmpDir , "id_ed25519" ),
168+ filepath .Join (tmpDir , "id_ecdsa" ),
169+ filepath .Join (tmpDir , "id_rsa" ),
170+ },
171+ Handler : http .HandlerFunc (func (resp http.ResponseWriter , req * http.Request ) {
172+ t .Error ("should not have been called" )
173+ }),
174+ }
175+
176+ for _ , client := range []Client {NewOpenSSHClient (config ), NewInternalClient (config )} {
177+ t .Run (client .Name (), func (t * testing.T ) {
178+ defer client .Close (context .Background ())
179+
180+ ctx , cancel := context .WithCancel (context .Background ())
181+
182+ err := client .Connect (ctx )
183+ require .NoError (t , err )
184+
185+ cancel ()
186+
187+ err = client .Run (ctx , "curl" , fmt .Sprintf ("http://localhost:%d/" , client .TunnelPeerPort ()))
188+ require .Error (t , err )
189+ assert .ErrorIs (t , err , context .Canceled )
190+ })
191+ }
192+ }
193+
194+ func TestSSHClientRunCommandThenCancelContext (t * testing.T ) {
195+ clog .SetTestLog (t )
196+ defer clog .CloseTestLog ()
197+
198+ tmpDir := os .Getenv ("SSH_TESTS_TMPDIR" )
199+ if tmpDir == "" {
200+ tmpDir = filepath .Join (os .TempDir (), "resticprofile-ssh-tests" )
201+ }
202+
203+ config := Config {
204+ Host : "localhost" ,
205+ Port : 2222 ,
206+ Username : "resticprofile" ,
207+ KnownHostsPath : filepath .Join (tmpDir , "known_hosts" ),
208+ PrivateKeyPaths : []string {
209+ filepath .Join (tmpDir , "id_ed25519" ),
210+ filepath .Join (tmpDir , "id_ecdsa" ),
211+ filepath .Join (tmpDir , "id_rsa" ),
212+ },
213+ }
214+
215+ for _ , client := range []Client {NewOpenSSHClient (config ), NewInternalClient (config )} {
216+ t .Run (client .Name (), func (t * testing.T ) {
217+ defer client .Close (context .Background ())
218+
219+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
220+ defer cancel ()
221+
222+ err := client .Connect (ctx )
223+ require .NoError (t , err )
224+
225+ err = client .Run (ctx , "sleep" , "10" )
226+ require .Error (t , err )
227+ t .Log (err )
228+ })
229+ }
230+ }
0 commit comments