diff --git a/Makefile b/Makefile index 5f1b10a9..272f09ae 100644 --- a/Makefile +++ b/Makefile @@ -384,4 +384,4 @@ stop-ssh-server: ## Stop the SSH server and clean up temporary files .PHONY: ssh-test ssh-test: ## Run SSH client tests @echo "[*] $@" - @go test -run TestSSHClient -v -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh + @go test -run TestSSHClient -v -race -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh diff --git a/serve.go b/serve.go index c9bfae3b..ef043f44 100644 --- a/serve.go +++ b/serve.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "path" + "syscall" "time" "github.com/creativeprojects/clog" @@ -101,6 +102,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { KnownHostsPath: remoteConfig.KnownHostsPath, SSHConfigPath: remoteConfig.SSHConfig, Handler: handler, + ConnectTimeout: 20 * time.Second, } var cnx ssh.Client switch remoteConfig.Connection { @@ -112,8 +114,11 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { return fmt.Errorf("unsupported connection type %q for remote %q", remoteConfig.Connection, remoteName) } - err = cnx.Connect(context.Background()) - defer cnx.Close(context.Background()) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGABRT) + defer cancel() + + err = cnx.Connect(ctx) + defer cnx.Close(context.WithoutCancel(ctx)) if err != nil { return err } @@ -126,7 +131,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { "-v", "-r", fmt.Sprintf("http://localhost:%d/configuration/%s", cnx.TunnelPeerPort(), remoteName), } - err = cnx.Run(context.Background(), binaryPath, arguments...) + err = cnx.Run(ctx, binaryPath, arguments...) if err != nil { return fmt.Errorf("failed to run resticprofile on peer: %w", err) } diff --git a/ssh/client_test.go b/ssh/client_test.go index abf0a8ed..0ba6467a 100644 --- a/ssh/client_test.go +++ b/ssh/client_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/creativeprojects/clog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -103,6 +104,26 @@ func TestSSHClient(t *testing.T) { }, connectErr: false, }, + { + name: "successful connection using any of the provided key", + config: Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "file-not-found"), // Next key should be used + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Write([]byte("Connection successful any of the provided key\n")) + }), + ConnectTimeout: 10 * time.Second, + }, + connectErr: false, + }, } for _, fixture := range fixtures { @@ -127,3 +148,83 @@ func TestSSHClient(t *testing.T) { } } } + +func TestSSHClientRunCommandWithCancelledContext(t *testing.T) { + clog.SetTestLog(t) + defer clog.CloseTestLog() + + tmpDir := os.Getenv("SSH_TESTS_TMPDIR") + if tmpDir == "" { + tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests") + } + + config := Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + t.Error("should not have been called") + }), + } + + for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} { + t.Run(client.Name(), func(t *testing.T) { + defer client.Close(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + + err := client.Connect(ctx) + require.NoError(t, err) + + cancel() + + err = client.Run(ctx, "curl", fmt.Sprintf("http://localhost:%d/", client.TunnelPeerPort())) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) + } +} + +func TestSSHClientRunCommandThenCancelContext(t *testing.T) { + clog.SetTestLog(t) + defer clog.CloseTestLog() + + tmpDir := os.Getenv("SSH_TESTS_TMPDIR") + if tmpDir == "" { + tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests") + } + + config := Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + } + + for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} { + t.Run(client.Name(), func(t *testing.T) { + defer client.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.Connect(ctx) + require.NoError(t, err) + + err = client.Run(ctx, "sleep", "10") + require.Error(t, err) + t.Log(err) + }) + } +} diff --git a/ssh/config.go b/ssh/config.go index 36ffba06..b3c4a098 100644 --- a/ssh/config.go +++ b/ssh/config.go @@ -5,6 +5,7 @@ import ( "net/http" "os" "path/filepath" + "time" ) // Config holds the configuration to connect to the SSH server @@ -16,6 +17,7 @@ type Config struct { KnownHostsPath string SSHConfigPath string // Path to the OpenSSH config file, if any Handler http.Handler + ConnectTimeout time.Duration } func (c *Config) ValidateOpenSSH() error { diff --git a/ssh/internal_client.go b/ssh/internal_client.go index 4f3f5da0..fcdf3455 100644 --- a/ssh/internal_client.go +++ b/ssh/internal_client.go @@ -91,6 +91,7 @@ func (s *InternalClient) Connect(_ context.Context) error { Ciphers: algorithms.Ciphers, MACs: algorithms.MACs, }, + Timeout: s.config.ConnectTimeout, } host := s.config.Host @@ -115,11 +116,11 @@ func (s *InternalClient) Connect(_ context.Context) error { } s.tunnelPort = int(addrWithPort.AddrPort().Port()) + s.server = &http.Server{ + Handler: s.config.Handler, + ReadHeaderTimeout: 5 * time.Second, + } s.wg.Go(func() { - s.server = &http.Server{ - Handler: s.config.Handler, - ReadHeaderTimeout: 5 * time.Second, - } // Serve HTTP with your SSH server acting as a reverse proxy. err := s.server.Serve(s.tunnel) if err != nil && err != http.ErrServerClosed && !errors.Is(err, io.EOF) { @@ -134,7 +135,10 @@ func (s *InternalClient) TunnelPeerPort() int { return s.tunnelPort } -func (s *InternalClient) Run(_ context.Context, command string, arguments ...string) error { +func (s *InternalClient) Run(ctx context.Context, command string, arguments ...string) error { + if ctx.Err() != nil { + return ctx.Err() + } // Each ClientConn can support multiple interactive sessions, // represented by a Session. session, err := s.client.NewSession() @@ -143,6 +147,27 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str } defer session.Close() + done := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Go(func() { + select { + case <-ctx.Done(): + if session != nil { + err := session.Signal(ssh.SIGINT) + if err != nil { + clog.Warningf("unable to send interrupt signal to ssh session: %s", err) + } + } + return + case <-done: + return + } + }) + defer func() { + close(done) + wg.Wait() + }() + // request a pseudo terminal to display colors if termType := os.Getenv("TERM"); termType != "" { modes := ssh.TerminalModes{ @@ -162,7 +187,7 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str if err := session.Run(cmdline); err != nil { return fmt.Errorf("failed to run: %w", err) } - return nil + return ctx.Err() // in case the context was cancelled } func (s *InternalClient) Close(ctx context.Context) { diff --git a/ssh/openssh_client.go b/ssh/openssh_client.go index a46f64b5..e24ccb29 100644 --- a/ssh/openssh_client.go +++ b/ssh/openssh_client.go @@ -113,6 +113,12 @@ func (c *OpenSSHClient) startSSH(ctx context.Context) error { if c.config.KnownHostsPath != "" { args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath)) } + if c.config.ConnectTimeout > 0 { + timeout := int(c.config.ConnectTimeout.Seconds()) + if timeout > 0 { + args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", timeout)) + } + } for _, privateKeyPath := range c.config.PrivateKeyPaths { args = append(args, "-i", privateKeyPath) } @@ -215,6 +221,14 @@ func (c *OpenSSHClient) Run(ctx context.Context, command string, arguments ...st cmd := exec.CommandContext(ctx, "ssh", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.Cancel = func() error { + if cmd.Process == nil { + return os.ErrProcessDone + } + return cmd.Process.Signal(os.Interrupt) + } + cmd.WaitDelay = 10 * time.Second + clog.Debugf("running command: %s", cmd.String()) err := cmd.Run() if err != nil {