Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,4 @@ stop-ssh-server: ## Stop the SSH server and clean up temporary files
.PHONY: ssh-test
ssh-test: ## Run SSH client tests
@echo "[*] $@"
@go test -run TestSSHClient -v -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh
@go test -run TestSSHClient -v -race -tags ssh -coverprofile='$(COVERAGE_SSH_FILE)' ./ssh
11 changes: 8 additions & 3 deletions serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/signal"
"path"
"syscall"
"time"

"github.com/creativeprojects/clog"
Expand Down Expand Up @@ -101,6 +102,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
KnownHostsPath: remoteConfig.KnownHostsPath,
SSHConfigPath: remoteConfig.SSHConfig,
Handler: handler,
ConnectTimeout: 20 * time.Second,
}
var cnx ssh.Client
switch remoteConfig.Connection {
Expand All @@ -112,8 +114,11 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
return fmt.Errorf("unsupported connection type %q for remote %q", remoteConfig.Connection, remoteName)
}

err = cnx.Connect(context.Background())
defer cnx.Close(context.Background())
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGABRT)
defer cancel()

err = cnx.Connect(ctx)
defer cnx.Close(context.WithoutCancel(ctx))
Comment thread
creativeprojects marked this conversation as resolved.
if err != nil {
return err
}
Expand All @@ -126,7 +131,7 @@ func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
"-v",
"-r", fmt.Sprintf("http://localhost:%d/configuration/%s", cnx.TunnelPeerPort(), remoteName),
}
err = cnx.Run(context.Background(), binaryPath, arguments...)
err = cnx.Run(ctx, binaryPath, arguments...)
if err != nil {
return fmt.Errorf("failed to run resticprofile on peer: %w", err)
}
Expand Down
101 changes: 101 additions & 0 deletions ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/creativeprojects/clog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -103,6 +104,26 @@ func TestSSHClient(t *testing.T) {
},
connectErr: false,
},
{
name: "successful connection using any of the provided key",
config: Config{
Host: "localhost",
Port: 2222,
Username: "resticprofile",
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
PrivateKeyPaths: []string{
filepath.Join(tmpDir, "file-not-found"), // Next key should be used
filepath.Join(tmpDir, "id_ed25519"),
filepath.Join(tmpDir, "id_ecdsa"),
filepath.Join(tmpDir, "id_rsa"),
},
Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
resp.Write([]byte("Connection successful any of the provided key\n"))
}),
ConnectTimeout: 10 * time.Second,
},
connectErr: false,
},
}

for _, fixture := range fixtures {
Expand All @@ -127,3 +148,83 @@ func TestSSHClient(t *testing.T) {
}
}
}

func TestSSHClientRunCommandWithCancelledContext(t *testing.T) {
clog.SetTestLog(t)
defer clog.CloseTestLog()

tmpDir := os.Getenv("SSH_TESTS_TMPDIR")
if tmpDir == "" {
tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests")
}

config := Config{
Host: "localhost",
Port: 2222,
Username: "resticprofile",
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
PrivateKeyPaths: []string{
filepath.Join(tmpDir, "id_ed25519"),
filepath.Join(tmpDir, "id_ecdsa"),
filepath.Join(tmpDir, "id_rsa"),
},
Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
t.Error("should not have been called")
}),
}

for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} {
t.Run(client.Name(), func(t *testing.T) {
defer client.Close(context.Background())

ctx, cancel := context.WithCancel(context.Background())

err := client.Connect(ctx)
require.NoError(t, err)

cancel()

err = client.Run(ctx, "curl", fmt.Sprintf("http://localhost:%d/", client.TunnelPeerPort()))
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
})
}
}

func TestSSHClientRunCommandThenCancelContext(t *testing.T) {
clog.SetTestLog(t)
defer clog.CloseTestLog()

tmpDir := os.Getenv("SSH_TESTS_TMPDIR")
if tmpDir == "" {
tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests")
}

config := Config{
Host: "localhost",
Port: 2222,
Username: "resticprofile",
KnownHostsPath: filepath.Join(tmpDir, "known_hosts"),
PrivateKeyPaths: []string{
filepath.Join(tmpDir, "id_ed25519"),
filepath.Join(tmpDir, "id_ecdsa"),
filepath.Join(tmpDir, "id_rsa"),
},
}

for _, client := range []Client{NewOpenSSHClient(config), NewInternalClient(config)} {
t.Run(client.Name(), func(t *testing.T) {
defer client.Close(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

err := client.Connect(ctx)
require.NoError(t, err)

err = client.Run(ctx, "sleep", "10")
require.Error(t, err)
Comment thread
creativeprojects marked this conversation as resolved.
t.Log(err)
})
}
}
2 changes: 2 additions & 0 deletions ssh/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"os"
"path/filepath"
"time"
)

// Config holds the configuration to connect to the SSH server
Expand All @@ -16,6 +17,7 @@ type Config struct {
KnownHostsPath string
SSHConfigPath string // Path to the OpenSSH config file, if any
Handler http.Handler
ConnectTimeout time.Duration
}

func (c *Config) ValidateOpenSSH() error {
Expand Down
37 changes: 31 additions & 6 deletions ssh/internal_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func (s *InternalClient) Connect(_ context.Context) error {
Ciphers: algorithms.Ciphers,
MACs: algorithms.MACs,
},
Timeout: s.config.ConnectTimeout,
}

host := s.config.Host
Expand All @@ -115,11 +116,11 @@ func (s *InternalClient) Connect(_ context.Context) error {
}
s.tunnelPort = int(addrWithPort.AddrPort().Port())

s.server = &http.Server{
Handler: s.config.Handler,
ReadHeaderTimeout: 5 * time.Second,
}
s.wg.Go(func() {
s.server = &http.Server{
Handler: s.config.Handler,
ReadHeaderTimeout: 5 * time.Second,
}
// Serve HTTP with your SSH server acting as a reverse proxy.
err := s.server.Serve(s.tunnel)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, io.EOF) {
Expand All @@ -134,7 +135,10 @@ func (s *InternalClient) TunnelPeerPort() int {
return s.tunnelPort
}

func (s *InternalClient) Run(_ context.Context, command string, arguments ...string) error {
func (s *InternalClient) Run(ctx context.Context, command string, arguments ...string) error {
if ctx.Err() != nil {
return ctx.Err()
}
// Each ClientConn can support multiple interactive sessions,
// represented by a Session.
session, err := s.client.NewSession()
Expand All @@ -143,6 +147,27 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str
}
defer session.Close()

done := make(chan struct{})
wg := sync.WaitGroup{}
wg.Go(func() {
select {
case <-ctx.Done():
if session != nil {
err := session.Signal(ssh.SIGINT)
if err != nil {
clog.Warningf("unable to send interrupt signal to ssh session: %s", err)
}
}
return
case <-done:
return
}
})
defer func() {
close(done)
wg.Wait()
}()

// request a pseudo terminal to display colors
if termType := os.Getenv("TERM"); termType != "" {
modes := ssh.TerminalModes{
Expand All @@ -162,7 +187,7 @@ func (s *InternalClient) Run(_ context.Context, command string, arguments ...str
if err := session.Run(cmdline); err != nil {
return fmt.Errorf("failed to run: %w", err)
}
return nil
return ctx.Err() // in case the context was cancelled
Comment thread
creativeprojects marked this conversation as resolved.
}

func (s *InternalClient) Close(ctx context.Context) {
Expand Down
14 changes: 14 additions & 0 deletions ssh/openssh_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ func (c *OpenSSHClient) startSSH(ctx context.Context) error {
if c.config.KnownHostsPath != "" {
args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath))
}
if c.config.ConnectTimeout > 0 {
timeout := int(c.config.ConnectTimeout.Seconds())
if timeout > 0 {
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", timeout))
}
}
for _, privateKeyPath := range c.config.PrivateKeyPaths {
args = append(args, "-i", privateKeyPath)
}
Expand Down Expand Up @@ -215,6 +221,14 @@ func (c *OpenSSHClient) Run(ctx context.Context, command string, arguments ...st
cmd := exec.CommandContext(ctx, "ssh", args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Cancel = func() error {
if cmd.Process == nil {
return os.ErrProcessDone
}
return cmd.Process.Signal(os.Interrupt)
Comment thread
creativeprojects marked this conversation as resolved.
Comment thread
creativeprojects marked this conversation as resolved.
}
cmd.WaitDelay = 10 * time.Second

clog.Debugf("running command: %s", cmd.String())
err := cmd.Run()
if err != nil {
Expand Down
Loading