Skip to content

Commit aeba4b3

Browse files
committed
fix: allow ssh local forwarding in okdev-sshd
1 parent 0ab0daa commit aeba4b3

2 files changed

Lines changed: 46 additions & 22 deletions

File tree

cmd/okdev-sshd/main.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,29 @@ func main() {
3333
log.Fatalf("failed to load authorized keys: %v", err)
3434
}
3535

36+
srv := newServer(fmt.Sprintf(":%d", *port), *shell, keys)
37+
38+
log.Printf("okdev-sshd listening on :%d", *port)
39+
log.Fatal(srv.ListenAndServe())
40+
}
41+
42+
func newServer(addr, shell string, keys []ssh.PublicKey) *ssh.Server {
43+
channelHandlers := map[string]ssh.ChannelHandler{}
44+
for name, handler := range ssh.DefaultChannelHandlers {
45+
channelHandlers[name] = handler
46+
}
47+
channelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
48+
3649
srv := &ssh.Server{
37-
Addr: fmt.Sprintf(":%d", *port),
38-
Handler: sessionHandler(*shell),
50+
Addr: addr,
51+
Handler: sessionHandler(shell),
3952
SubsystemHandlers: map[string]ssh.SubsystemHandler{
4053
"sftp": sftpHandler,
4154
},
55+
ChannelHandlers: channelHandlers,
56+
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
57+
return true
58+
},
4259
}
4360

4461
if keys != nil {
@@ -52,8 +69,7 @@ func main() {
5269
}
5370
}
5471

55-
log.Printf("okdev-sshd listening on :%d", *port)
56-
log.Fatal(srv.ListenAndServe())
72+
return srv
5773
}
5874

5975
func detectShell() string {

cmd/okdev-sshd/main_test.go

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
11
package main
22

33
import (
4-
"strings"
54
"testing"
5+
6+
"github.com/gliderlabs/ssh"
67
)
78

8-
func TestBuildInteractiveLoginScriptIncludesDevTmuxBootstrap(t *testing.T) {
9-
script := buildInteractiveLoginScript(map[string]string{}, "/bin/bash", "/workspace", "1")
9+
func TestNewServerEnablesLocalPortForwarding(t *testing.T) {
10+
srv := newServer(":2222", "/bin/sh", nil)
1011

11-
for _, want := range []string{
12-
"cd '/workspace'",
13-
"/workspace/.okdev/post-attach.sh",
14-
"/var/okdev/dev.tmux.conf",
15-
"exec tmux new-session -A -s okdev",
16-
} {
17-
if !strings.Contains(script, want) {
18-
t.Fatalf("expected script to contain %q, got:\n%s", want, script)
19-
}
12+
if srv.ChannelHandlers == nil {
13+
t.Fatal("expected channel handlers to be configured")
14+
}
15+
if _, ok := srv.ChannelHandlers["session"]; !ok {
16+
t.Fatal("expected default session channel handler")
17+
}
18+
if _, ok := srv.ChannelHandlers["direct-tcpip"]; !ok {
19+
t.Fatal("expected direct-tcpip channel handler for ssh forwarding")
20+
}
21+
if srv.LocalPortForwardingCallback == nil {
22+
t.Fatal("expected local port forwarding callback")
23+
}
24+
if !srv.LocalPortForwardingCallback(nil, "127.0.0.1", 8080) {
25+
t.Fatal("expected local port forwarding to be allowed")
2026
}
2127
}
2228

23-
func TestBuildInteractiveLoginScriptSkipsTmuxWhenDisabledForSession(t *testing.T) {
24-
script := buildInteractiveLoginScript(map[string]string{"OKDEV_NO_TMUX": "1"}, "/bin/sh", "/workspace", "1")
25-
if strings.Contains(script, "tmux") {
26-
t.Fatalf("expected tmux bootstrap to be skipped, got:\n%s", script)
29+
func TestNewServerAddsPublicKeyHandlerWhenKeysProvided(t *testing.T) {
30+
pub, _, _, _, err := ssh.ParseAuthorizedKey([]byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIE9mN6e2Q2x8tQz4pT2r8j04YfGLwRoTSesFiNUFDXL9 test\n"))
31+
if err != nil {
32+
t.Fatalf("parse authorized key: %v", err)
2733
}
28-
if !strings.Contains(script, "exec '/bin/sh' -l") {
29-
t.Fatalf("expected shell fallback, got:\n%s", script)
34+
35+
srv := newServer(":2222", "/bin/sh", []ssh.PublicKey{pub})
36+
if srv.PublicKeyHandler == nil {
37+
t.Fatal("expected public key handler to be configured")
3038
}
3139
}

0 commit comments

Comments
 (0)