From 5a21c127daeeae41a29d16a4dab8a7ffdd40181b Mon Sep 17 00:00:00 2001 From: Fred Date: Tue, 31 Mar 2026 22:56:59 +0100 Subject: [PATCH 1/3] feat: enhance SSH command execution with cancellation and wait delay --- serve.go | 10 +++++++--- ssh/openssh_client.go | 5 +++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/serve.go b/serve.go index c9bfae3b..a5b89e83 100644 --- a/serve.go +++ b/serve.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "path" + "syscall" "time" "github.com/creativeprojects/clog" @@ -112,8 +113,11 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { return fmt.Errorf("unsupported connection type %q for remote %q", remoteConfig.Connection, remoteName) } - err = cnx.Connect(context.Background()) - defer cnx.Close(context.Background()) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGABRT) + defer cancel() + + err = cnx.Connect(ctx) + defer cnx.Close(context.WithoutCancel(ctx)) if err != nil { return err } @@ -126,7 +130,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { "-v", "-r", fmt.Sprintf("http://localhost:%d/configuration/%s", cnx.TunnelPeerPort(), remoteName), } - err = cnx.Run(context.Background(), binaryPath, arguments...) + err = cnx.Run(ctx, binaryPath, arguments...) if err != nil { return fmt.Errorf("failed to run resticprofile on peer: %w", err) } diff --git a/ssh/openssh_client.go b/ssh/openssh_client.go index a46f64b5..e56857c9 100644 --- a/ssh/openssh_client.go +++ b/ssh/openssh_client.go @@ -215,6 +215,11 @@ func (c *OpenSSHClient) Run(ctx context.Context, command string, arguments ...st cmd := exec.CommandContext(ctx, "ssh", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.Cancel = func() error { + return cmd.Process.Signal(os.Interrupt) + } + cmd.WaitDelay = 10 * time.Second + clog.Debugf("running command: %s", cmd.String()) err := cmd.Run() if err != nil { From 95a9d813ffa13859050a8413af9f16fb0fc4b3b8 Mon Sep 17 00:00:00 2001 From: Fred Date: Thu, 2 Apr 2026 12:52:57 +0100 Subject: [PATCH 2/3] feat: add connection timeout and enhance SSH client command execution --- Makefile | 2 +- serve.go | 1 + ssh/client_test.go | 101 +++++++++++++++++++++++++++++++++++++++++ ssh/config.go | 2 + ssh/internal_client.go | 35 +++++++++++--- ssh/openssh_client.go | 6 +++ 6 files changed, 140 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 5f1b10a9..272f09ae 100644 --- a/Makefile +++ b/Makefile @@ -384,4 +384,4 @@ stop-ssh-server: ## Stop the SSH server and clean up temporary files .PHONY: ssh-test ssh-test: ## Run SSH client tests @echo "[*] $@" - @go test -run TestSSHClient -v -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh + @go test -run TestSSHClient -v -race -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh diff --git a/serve.go b/serve.go index a5b89e83..ef043f44 100644 --- a/serve.go +++ b/serve.go @@ -102,6 +102,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { KnownHostsPath: remoteConfig.KnownHostsPath, SSHConfigPath: remoteConfig.SSHConfig, Handler: handler, + ConnectTimeout: 20 * time.Second, } var cnx ssh.Client switch remoteConfig.Connection { diff --git a/ssh/client_test.go b/ssh/client_test.go index abf0a8ed..0ba6467a 100644 --- a/ssh/client_test.go +++ b/ssh/client_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/creativeprojects/clog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -103,6 +104,26 @@ func TestSSHClient(t *testing.T) { }, connectErr: false, }, + { + name: "successful connection using any of the provided key", + config: Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "file-not-found"), // Next key should be used + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Write([]byte("Connection successful any of the provided key\n")) + }), + ConnectTimeout: 10 * time.Second, + }, + connectErr: false, + }, } for _, fixture := range fixtures { @@ -127,3 +148,83 @@ func TestSSHClient(t *testing.T) { } } } + +func TestSSHClientRunCommandWithCancelledContext(t *testing.T) { + clog.SetTestLog(t) + defer clog.CloseTestLog() + + tmpDir := os.Getenv("SSH_TESTS_TMPDIR") + if tmpDir == "" { + tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests") + } + + config := Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + t.Error("should not have been called") + }), + } + + for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} { + t.Run(client.Name(), func(t *testing.T) { + defer client.Close(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + + err := client.Connect(ctx) + require.NoError(t, err) + + cancel() + + err = client.Run(ctx, "curl", fmt.Sprintf("http://localhost:%d/", client.TunnelPeerPort())) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) + } +} + +func TestSSHClientRunCommandThenCancelContext(t *testing.T) { + clog.SetTestLog(t) + defer clog.CloseTestLog() + + tmpDir := os.Getenv("SSH_TESTS_TMPDIR") + if tmpDir == "" { + tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests") + } + + config := Config{ + Host: "localhost", + Port: 2222, + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPaths: []string{ + filepath.Join(tmpDir, "id_ed25519"), + filepath.Join(tmpDir, "id_ecdsa"), + filepath.Join(tmpDir, "id_rsa"), + }, + } + + for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} { + t.Run(client.Name(), func(t *testing.T) { + defer client.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.Connect(ctx) + require.NoError(t, err) + + err = client.Run(ctx, "sleep", "10") + require.Error(t, err) + t.Log(err) + }) + } +} diff --git a/ssh/config.go b/ssh/config.go index 36ffba06..b3c4a098 100644 --- a/ssh/config.go +++ b/ssh/config.go @@ -5,6 +5,7 @@ import ( "net/http" "os" "path/filepath" + "time" ) // Config holds the configuration to connect to the SSH server @@ -16,6 +17,7 @@ type Config struct { KnownHostsPath string SSHConfigPath string // Path to the OpenSSH config file, if any Handler http.Handler + ConnectTimeout time.Duration } func (c *Config) ValidateOpenSSH() error { diff --git a/ssh/internal_client.go b/ssh/internal_client.go index 4f3f5da0..fd6a4843 100644 --- a/ssh/internal_client.go +++ b/ssh/internal_client.go @@ -91,6 +91,7 @@ func (s *InternalClient) Connect(_ context.Context) error { Ciphers: algorithms.Ciphers, MACs: algorithms.MACs, }, + Timeout: s.config.ConnectTimeout, } host := s.config.Host @@ -115,11 +116,11 @@ func (s *InternalClient) Connect(_ context.Context) error { } s.tunnelPort = int(addrWithPort.AddrPort().Port()) + s.server = &http.Server{ + Handler: s.config.Handler, + ReadHeaderTimeout: 5 * time.Second, + } s.wg.Go(func() { - s.server = &http.Server{ - Handler: s.config.Handler, - ReadHeaderTimeout: 5 * time.Second, - } // Serve HTTP with your SSH server acting as a reverse proxy. err := s.server.Serve(s.tunnel) if err != nil && err != http.ErrServerClosed && !errors.Is(err, io.EOF) { @@ -134,7 +135,10 @@ func (s *InternalClient) TunnelPeerPort() int { return s.tunnelPort } -func (s *InternalClient) Run(_ context.Context, command string, arguments ...string) error { +func (s *InternalClient) Run(ctx context.Context, command string, arguments ...string) error { + if ctx.Err() != nil { + return ctx.Err() + } // Each ClientConn can support multiple interactive sessions, // represented by a Session. session, err := s.client.NewSession() @@ -143,6 +147,23 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str } defer session.Close() + done := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Go(func() { + select { + case <-ctx.Done(): + if session != nil { + err := session.Signal(ssh.SIGINT) + if err != nil { + clog.Warningf("unable to send interrupt signal to ssh session: %s", err) + } + } + return + case <-done: + return + } + }) + // request a pseudo terminal to display colors if termType := os.Getenv("TERM"); termType != "" { modes := ssh.TerminalModes{ @@ -162,7 +183,9 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str if err := session.Run(cmdline); err != nil { return fmt.Errorf("failed to run: %w", err) } - return nil + close(done) + wg.Wait() + return ctx.Err() // in case the context was cancelled } func (s *InternalClient) Close(ctx context.Context) { diff --git a/ssh/openssh_client.go b/ssh/openssh_client.go index e56857c9..021d51c2 100644 --- a/ssh/openssh_client.go +++ b/ssh/openssh_client.go @@ -113,6 +113,9 @@ func (c *OpenSSHClient) startSSH(ctx context.Context) error { if c.config.KnownHostsPath != "" { args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath)) } + if c.config.ConnectTimeout > 0 { + args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", int(c.config.ConnectTimeout.Seconds()))) + } for _, privateKeyPath := range c.config.PrivateKeyPaths { args = append(args, "-i", privateKeyPath) } @@ -216,6 +219,9 @@ func (c *OpenSSHClient) Run(ctx context.Context, command string, arguments ...st cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Cancel = func() error { + if cmd.Process == nil { + return os.ErrProcessDone + } return cmd.Process.Signal(os.Interrupt) } cmd.WaitDelay = 10 * time.Second From 0bdca975ea1a68547fa2c264cdd7d0c7979b6c59 Mon Sep 17 00:00:00 2001 From: Fred Date: Thu, 2 Apr 2026 13:06:36 +0100 Subject: [PATCH 3/3] feat: improve SSH command execution by ensuring proper cleanup and conditional timeout setting --- ssh/internal_client.go | 6 ++++-- ssh/openssh_client.go | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ssh/internal_client.go b/ssh/internal_client.go index fd6a4843..fcdf3455 100644 --- a/ssh/internal_client.go +++ b/ssh/internal_client.go @@ -163,6 +163,10 @@ func (s *InternalClient) Run(ctx context.Context, command string, arguments ...s return } }) + defer func() { + close(done) + wg.Wait() + }() // request a pseudo terminal to display colors if termType := os.Getenv("TERM"); termType != "" { @@ -183,8 +187,6 @@ func (s *InternalClient) Run(ctx context.Context, command string, arguments ...s if err := session.Run(cmdline); err != nil { return fmt.Errorf("failed to run: %w", err) } - close(done) - wg.Wait() return ctx.Err() // in case the context was cancelled } diff --git a/ssh/openssh_client.go b/ssh/openssh_client.go index 021d51c2..e24ccb29 100644 --- a/ssh/openssh_client.go +++ b/ssh/openssh_client.go @@ -114,7 +114,10 @@ func (c *OpenSSHClient) startSSH(ctx context.Context) error { args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath)) } if c.config.ConnectTimeout > 0 { - args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", int(c.config.ConnectTimeout.Seconds()))) + timeout := int(c.config.ConnectTimeout.Seconds()) + if timeout > 0 { + args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", timeout)) + } } for _, privateKeyPath := range c.config.PrivateKeyPaths { args = append(args, "-i", privateKeyPath)