Skip to content

Commit b6693a3

Browse files
feat: enhance SSH command execution with cancellation and wait delay (#617)
* feat: enhance SSH command execution with cancellation and wait delay * feat: add connection timeout and enhance SSH client command execution * feat: improve SSH command execution by ensuring proper cleanup and conditional timeout setting
1 parent fd44cf6 commit b6693a3

6 files changed

Lines changed: 157 additions & 10 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,4 +384,4 @@ stop-ssh-server: ## Stop the SSH server and clean up temporary files
384384
.PHONY: ssh-test
385385
ssh-test: ## Run SSH client tests
386386
@echo "[*] $@"
387-
@go test -run TestSSHClient -v -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh
387+
@go test -run TestSSHClient -v -race -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh

serve.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"os/signal"
1111
"path"
12+
"syscall"
1213
"time"
1314

1415
"github.com/creativeprojects/clog"
@@ -101,6 +102,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
101102
KnownHostsPath: remoteConfig.KnownHostsPath,
102103
SSHConfigPath: remoteConfig.SSHConfig,
103104
Handler: handler,
105+
ConnectTimeout: 20 * time.Second,
104106
}
105107
var cnx ssh.Client
106108
switch remoteConfig.Connection {
@@ -112,8 +114,11 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
112114
return fmt.Errorf("unsupported connection type %q for remote %q", remoteConfig.Connection, remoteName)
113115
}
114116

115-
err = cnx.Connect(context.Background())
116-
defer cnx.Close(context.Background())
117+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGABRT)
118+
defer cancel()
119+
120+
err = cnx.Connect(ctx)
121+
defer cnx.Close(context.WithoutCancel(ctx))
117122
if err != nil {
118123
return err
119124
}
@@ -126,7 +131,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
126131
"-v",
127132
"-r", fmt.Sprintf("http://localhost:%d/configuration/%s", cnx.TunnelPeerPort(), remoteName),
128133
}
129-
err = cnx.Run(context.Background(), binaryPath, arguments...)
134+
err = cnx.Run(ctx, binaryPath, arguments...)
130135
if err != nil {
131136
return fmt.Errorf("failed to run resticprofile on peer: %w", err)
132137
}

ssh/client_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/creativeprojects/clog"
15+
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
)
1718

@@ -103,6 +104,26 @@ func TestSSHClient(t *testing.T) {
103104
},
104105
connectErr: false,
105106
},
107+
{
108+
name: "successful connection using any of the provided key",
109+
config: Config{
110+
Host: "localhost",
111+
Port: 2222,
112+
Username: "resticprofile",
113+
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
114+
PrivateKeyPaths: []string{
115+
filepath.Join(tmpDir, "file-not-found"), // Next key should be used
116+
filepath.Join(tmpDir, "id_ed25519"),
117+
filepath.Join(tmpDir, "id_ecdsa"),
118+
filepath.Join(tmpDir, "id_rsa"),
119+
},
120+
Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
121+
resp.Write([]byte("Connection successful any of the provided key\n"))
122+
}),
123+
ConnectTimeout: 10 * time.Second,
124+
},
125+
connectErr: false,
126+
},
106127
}
107128

108129
for _, fixture := range fixtures {
@@ -127,3 +148,83 @@ func TestSSHClient(t *testing.T) {
127148
}
128149
}
129150
}
151+
152+
func TestSSHClientRunCommandWithCancelledContext(t *testing.T) {
153+
clog.SetTestLog(t)
154+
defer clog.CloseTestLog()
155+
156+
tmpDir := os.Getenv("SSH_TESTS_TMPDIR")
157+
if tmpDir == "" {
158+
tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests")
159+
}
160+
161+
config := Config{
162+
Host: "localhost",
163+
Port: 2222,
164+
Username: "resticprofile",
165+
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
166+
PrivateKeyPaths: []string{
167+
filepath.Join(tmpDir, "id_ed25519"),
168+
filepath.Join(tmpDir, "id_ecdsa"),
169+
filepath.Join(tmpDir, "id_rsa"),
170+
},
171+
Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
172+
t.Error("should not have been called")
173+
}),
174+
}
175+
176+
for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} {
177+
t.Run(client.Name(), func(t *testing.T) {
178+
defer client.Close(context.Background())
179+
180+
ctx, cancel := context.WithCancel(context.Background())
181+
182+
err := client.Connect(ctx)
183+
require.NoError(t, err)
184+
185+
cancel()
186+
187+
err = client.Run(ctx, "curl", fmt.Sprintf("http://localhost:%d/", client.TunnelPeerPort()))
188+
require.Error(t, err)
189+
assert.ErrorIs(t, err, context.Canceled)
190+
})
191+
}
192+
}
193+
194+
func TestSSHClientRunCommandThenCancelContext(t *testing.T) {
195+
clog.SetTestLog(t)
196+
defer clog.CloseTestLog()
197+
198+
tmpDir := os.Getenv("SSH_TESTS_TMPDIR")
199+
if tmpDir == "" {
200+
tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests")
201+
}
202+
203+
config := Config{
204+
Host: "localhost",
205+
Port: 2222,
206+
Username: "resticprofile",
207+
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
208+
PrivateKeyPaths: []string{
209+
filepath.Join(tmpDir, "id_ed25519"),
210+
filepath.Join(tmpDir, "id_ecdsa"),
211+
filepath.Join(tmpDir, "id_rsa"),
212+
},
213+
}
214+
215+
for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} {
216+
t.Run(client.Name(), func(t *testing.T) {
217+
defer client.Close(context.Background())
218+
219+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
220+
defer cancel()
221+
222+
err := client.Connect(ctx)
223+
require.NoError(t, err)
224+
225+
err = client.Run(ctx, "sleep", "10")
226+
require.Error(t, err)
227+
t.Log(err)
228+
})
229+
}
230+
}

ssh/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"os"
77
"path/filepath"
8+
"time"
89
)
910

1011
// Config holds the configuration to connect to the SSH server
@@ -16,6 +17,7 @@ type Config struct {
1617
KnownHostsPath string
1718
SSHConfigPath string // Path to the OpenSSH config file, if any
1819
Handler http.Handler
20+
ConnectTimeout time.Duration
1921
}
2022

2123
func (c *Config) ValidateOpenSSH() error {

ssh/internal_client.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func (s *InternalClient) Connect(_ context.Context) error {
9191
Ciphers: algorithms.Ciphers,
9292
MACs: algorithms.MACs,
9393
},
94+
Timeout: s.config.ConnectTimeout,
9495
}
9596

9697
host := s.config.Host
@@ -115,11 +116,11 @@ func (s *InternalClient) Connect(_ context.Context) error {
115116
}
116117
s.tunnelPort = int(addrWithPort.AddrPort().Port())
117118

119+
s.server = &http.Server{
120+
Handler: s.config.Handler,
121+
ReadHeaderTimeout: 5 * time.Second,
122+
}
118123
s.wg.Go(func() {
119-
s.server = &http.Server{
120-
Handler: s.config.Handler,
121-
ReadHeaderTimeout: 5 * time.Second,
122-
}
123124
// Serve HTTP with your SSH server acting as a reverse proxy.
124125
err := s.server.Serve(s.tunnel)
125126
if err != nil && err != http.ErrServerClosed && !errors.Is(err, io.EOF) {
@@ -134,7 +135,10 @@ func (s *InternalClient) TunnelPeerPort() int {
134135
return s.tunnelPort
135136
}
136137

137-
func (s *InternalClient) Run(_ context.Context, command string, arguments ...string) error {
138+
func (s *InternalClient) Run(ctx context.Context, command string, arguments ...string) error {
139+
if ctx.Err() != nil {
140+
return ctx.Err()
141+
}
138142
// Each ClientConn can support multiple interactive sessions,
139143
// represented by a Session.
140144
session, err := s.client.NewSession()
@@ -143,6 +147,27 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str
143147
}
144148
defer session.Close()
145149

150+
done := make(chan struct{})
151+
wg := sync.WaitGroup{}
152+
wg.Go(func() {
153+
select {
154+
case <-ctx.Done():
155+
if session != nil {
156+
err := session.Signal(ssh.SIGINT)
157+
if err != nil {
158+
clog.Warningf("unable to send interrupt signal to ssh session: %s", err)
159+
}
160+
}
161+
return
162+
case <-done:
163+
return
164+
}
165+
})
166+
defer func() {
167+
close(done)
168+
wg.Wait()
169+
}()
170+
146171
// request a pseudo terminal to display colors
147172
if termType := os.Getenv("TERM"); termType != "" {
148173
modes := ssh.TerminalModes{
@@ -162,7 +187,7 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str
162187
if err := session.Run(cmdline); err != nil {
163188
return fmt.Errorf("failed to run: %w", err)
164189
}
165-
return nil
190+
return ctx.Err() // in case the context was cancelled
166191
}
167192

168193
func (s *InternalClient) Close(ctx context.Context) {

ssh/openssh_client.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ func (c *OpenSSHClient) startSSH(ctx context.Context) error {
113113
if c.config.KnownHostsPath != "" {
114114
args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath))
115115
}
116+
if c.config.ConnectTimeout > 0 {
117+
timeout := int(c.config.ConnectTimeout.Seconds())
118+
if timeout > 0 {
119+
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", timeout))
120+
}
121+
}
116122
for _, privateKeyPath := range c.config.PrivateKeyPaths {
117123
args = append(args, "-i", privateKeyPath)
118124
}
@@ -215,6 +221,14 @@ func (c *OpenSSHClient) Run(ctx context.Context, command string, arguments ...st
215221
cmd := exec.CommandContext(ctx, "ssh", args...)
216222
cmd.Stdout = os.Stdout
217223
cmd.Stderr = os.Stderr
224+
cmd.Cancel = func() error {
225+
if cmd.Process == nil {
226+
return os.ErrProcessDone
227+
}
228+
return cmd.Process.Signal(os.Interrupt)
229+
}
230+
cmd.WaitDelay = 10 * time.Second
231+
218232
clog.Debugf("running command: %s", cmd.String())
219233
err := cmd.Run()
220234
if err != nil {

0 commit comments

Comments
 (0)