Skip to content

Commit 1311364

Browse files
authored
[client] Increase ssh detection timeout (netbirdio#4827)
1 parent 68f56b7 commit 1311364

6 files changed

Lines changed: 46 additions & 35 deletions

File tree

client/cmd/ssh.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
749749
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
750750
logOutput = firstLogFile
751751
}
752-
if err := util.InitLog(logLevel, logOutput); err != nil {
752+
753+
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
754+
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
753755
return fmt.Errorf("init log: %w", err)
754756
}
755757

@@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{
788790
}
789791

790792
func sshDetectFn(cmd *cobra.Command, args []string) error {
791-
if err := util.InitLog(logLevel, "console"); err != nil {
793+
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
794+
if err := util.InitLog(detectLogLevel, "console"); err != nil {
792795
os.Exit(detection.ServerTypeRegular.ExitCode())
793796
}
794797

@@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
797800

798801
port, err := strconv.Atoi(portStr)
799802
if err != nil {
803+
log.Debugf("invalid port %q: %v", portStr, err)
800804
os.Exit(detection.ServerTypeRegular.ExitCode())
801805
}
802806

803-
dialer := &net.Dialer{Timeout: detection.Timeout}
804-
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
807+
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
808+
809+
dialer := &net.Dialer{}
810+
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
805811
if err != nil {
812+
log.Debugf("SSH server detection failed: %v", err)
813+
cancel()
806814
os.Exit(detection.ServerTypeRegular.ExitCode())
807815
}
808816

817+
cancel()
809818
os.Exit(serverType.ExitCode())
810819
return nil
811820
}

client/ssh/client/client.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
343343
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
344344
}
345345

346-
dialer := &net.Dialer{Timeout: detection.Timeout}
347-
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
346+
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
347+
defer cancel()
348+
349+
dialer := &net.Dialer{}
350+
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
348351
if err != nil {
349-
return nil, fmt.Errorf("SSH server detection failed: %w", err)
352+
return nil, fmt.Errorf("SSH server detection: %w", err)
350353
}
351354

352355
if !serverType.RequiresJWT() {

client/ssh/config/manager.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
189189

190190
hostLine := strings.Join(deduplicatedPatterns, " ")
191191
config := fmt.Sprintf("Host %s\n", hostLine)
192-
193-
if runtime.GOOS == "windows" {
194-
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
195-
} else {
196-
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
197-
}
192+
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
198193
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
199194
config += " PasswordAuthentication yes\n"
200195
config += " PubkeyAuthentication yes\n"

client/ssh/detection/detection.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package detection
33
import (
44
"bufio"
55
"context"
6+
"fmt"
67
"net"
78
"strconv"
89
"strings"
@@ -19,8 +20,8 @@ const (
1920
// JWTRequiredMarker is appended to responses when JWT is required
2021
JWTRequiredMarker = "NetBird-JWT-Required"
2122

22-
// Timeout is the timeout for SSH server detection
23-
Timeout = 5 * time.Second
23+
// DefaultTimeout is the default timeout for SSH server detection
24+
DefaultTimeout = 5 * time.Second
2425
)
2526

2627
type ServerType string
@@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i
6162

6263
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
6364
if err != nil {
64-
log.Debugf("SSH connection failed for detection: %v", err)
65-
return ServerTypeRegular, nil
65+
return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
6666
}
6767
defer conn.Close()
6868

69-
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
70-
log.Debugf("set read deadline: %v", err)
71-
return ServerTypeRegular, nil
69+
if deadline, ok := ctx.Deadline(); ok {
70+
if err := conn.SetReadDeadline(deadline); err != nil {
71+
return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
72+
}
7273
}
7374

7475
reader := bufio.NewReader(conn)
7576
serverBanner, err := reader.ReadString('\n')
7677
if err != nil {
77-
log.Debugf("read SSH banner: %v", err)
78-
return ServerTypeRegular, nil
78+
return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
7979
}
8080

8181
serverBanner = strings.TrimSpace(serverBanner)

client/ssh/server/jwt_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) {
5858
require.NoError(t, err)
5959
port, err := strconv.Atoi(portStr)
6060
require.NoError(t, err)
61-
dialer := &net.Dialer{Timeout: detection.Timeout}
61+
dialer := &net.Dialer{}
6262
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
6363
if err != nil {
6464
t.Logf("Detection failed: %v", err)
@@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) {
9393
portNoJWT, err := strconv.Atoi(portStrNoJWT)
9494
require.NoError(t, err)
9595

96-
dialer := &net.Dialer{Timeout: detection.Timeout}
96+
dialer := &net.Dialer{}
9797
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
9898
require.NoError(t, err)
9999
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
@@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) {
218218
port, err := strconv.Atoi(portStr)
219219
require.NoError(t, err)
220220

221-
dialer := &net.Dialer{Timeout: detection.Timeout}
221+
dialer := &net.Dialer{}
222222
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
223223
require.NoError(t, err)
224224
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)

client/wasm/cmd/main.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ import (
1919
)
2020

2121
const (
22-
clientStartTimeout = 30 * time.Second
23-
clientStopTimeout = 10 * time.Second
24-
defaultLogLevel = "warn"
22+
clientStartTimeout = 30 * time.Second
23+
clientStopTimeout = 10 * time.Second
24+
defaultLogLevel = "warn"
25+
defaultSSHDetectionTimeout = 20 * time.Second
2526
)
2627

2728
func main() {
@@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
207208
host := args[0].String()
208209
port := args[1].Int()
209210

211+
timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds())
212+
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
213+
timeoutMs = args[2].Int()
214+
if timeoutMs <= 0 {
215+
return js.ValueOf("error: timeout must be positive")
216+
}
217+
}
218+
210219
return createPromise(func(resolve, reject js.Value) {
211-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
220+
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
212221
defer cancel()
213222

214-
serverType, err := detectSSHServerType(ctx, client, host, port)
223+
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
215224
if err != nil {
216225
reject.Invoke(err.Error())
217226
return
@@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
222231
})
223232
}
224233

225-
// detectSSHServerType detects SSH server type using NetBird network connection
226-
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
227-
return sshdetection.DetectSSHServerType(ctx, client, host, port)
228-
}
229-
230234
// createClientObject wraps the NetBird client in a JavaScript object
231235
func createClientObject(client *netbird.Client) js.Value {
232236
obj := make(map[string]interface{})

0 commit comments

Comments
 (0)