Skip to content

Commit b79aae8

Browse files
authored
Allow to override upstream user (#7)
The syntax is <upstream_id>[_<user>], e.g., * log in as the default job user: 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com * log in as `root`: 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com
1 parent 0b5e6fb commit b79aae8

File tree

1 file changed

+82
-10
lines changed

1 file changed

+82
-10
lines changed

internal/sshproxy/server.go

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ var (
126126
errUnknownPublicKey = errors.New("unknown public key")
127127
)
128128

129+
// upstreamAuthFailureError represents an SSH client auth failure (that is, SSH_MSG_USERAUTH_FAILURE) when connecting
130+
// to any host in the Upstream.hosts chain (either a jump host or a target host)
131+
type upstreamAuthFailureError struct {
132+
sshErr error
133+
isTargetHost bool
134+
}
135+
136+
func (e *upstreamAuthFailureError) Error() string {
137+
return fmt.Sprintf("auth failure: %s", e.sshErr.Error())
138+
}
139+
140+
func (e *upstreamAuthFailureError) Unwrap() error {
141+
return e.sshErr
142+
}
143+
129144
type Server struct {
130145
address string
131146

@@ -246,7 +261,7 @@ func (s *Server) Close(ctx context.Context) error {
246261
}
247262

248263
func (s *Server) publicKeyCallback(conn ssh.ConnMetadata, publicKey ssh.PublicKey) (*ssh.Permissions, error) {
249-
upstreamID := conn.User()
264+
upstreamID, _ := parseAuthUser(conn.User())
250265
logger := log.GetLogger(s.serveCtx).WithField("id", upstreamID)
251266

252267
upstream, found := s.upstreamCache.Get(upstreamID)
@@ -312,9 +327,22 @@ func handleConnection(ctx context.Context, conn net.Conn, config *ssh.ServerConf
312327
logger.Debug("client logged in")
313328

314329
upstream := clientConn.Permissions.ExtraData[upstreamExtraDataKey].(Upstream)
315-
upstreamConn, upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream)
330+
_, user := parseAuthUser(clientConn.User())
331+
upstreamConn, upstreamNewChans, upstreamReqs, err := connectToUpstream(ctx, upstream, user)
316332
if err != nil {
317-
logger.WithError(err).Error("failed to connect to upstream")
333+
logger = logger.WithError(err)
334+
if user != "" {
335+
logger = logger.WithField("user", user)
336+
}
337+
338+
const msg = "failed to connect to upstream"
339+
// Don't log as an error if it is a client auth error on the last host in the chain and the user is overridden
340+
// to avoid log noise in case a non-existent user is requested
341+
if authErr, ok := errors.AsType[*upstreamAuthFailureError](err); ok && authErr.isTargetHost && user != "" {
342+
logger.Debug(msg)
343+
} else {
344+
logger.Error(msg)
345+
}
318346

319347
return
320348
}
@@ -420,22 +448,40 @@ func handleConnectionError(ctx context.Context, err error) {
420448
logger.WithError(err).Error("failed to handshake client")
421449
}
422450

451+
// parseAuthUser extracts upstreamID and optional upstreamUser (overrides the default upstream user)
452+
// from the "user name" field of the SSH_MSG_USERAUTH_REQUEST request (the `user` in the `ssh user@hostname` command)
453+
// The optional user is appended to the upstreamID after the `_` delimiter:
454+
// 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com - log in as the default job user
455+
// 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com - log in as `root`
456+
func parseAuthUser(user string) (upstreamID string, upstreamUser string) {
457+
upstreamID, upstreamUser, _ = strings.Cut(user, "_")
458+
return upstreamID, upstreamUser
459+
}
460+
423461
func connectToUpstream(
424-
ctx context.Context,
425-
upstream Upstream,
462+
ctx context.Context, upstream Upstream, user string,
426463
) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) {
464+
logger := log.GetLogger(ctx)
465+
427466
var conn ssh.Conn
428467
var chans <-chan ssh.NewChannel
429468
var reqs <-chan *ssh.Request
430469

431-
for i, host := range upstream.hosts {
470+
// A target host is the last host in the Upstream.hosts chanin. All other hosts are jump hosts.
471+
targetHostIdx := len(upstream.hosts) - 1
472+
for hostIdx, host := range upstream.hosts {
473+
isTargetHost := hostIdx == targetHostIdx
474+
hostUser := host.user
475+
if isTargetHost && user != "" {
476+
hostUser = user
477+
}
432478
config := &ssh.ClientConfig{
433479
Config: ssh.Config{
434480
KeyExchanges: allowedKeyExchanges,
435481
Ciphers: allowedCiphers,
436482
MACs: allowedMACs,
437483
},
438-
User: host.user,
484+
User: hostUser,
439485
Auth: []ssh.AuthMethod{
440486
ssh.PublicKeys(host.privateKey),
441487
},
@@ -446,7 +492,7 @@ func connectToUpstream(
446492
var netConn net.Conn
447493
var err error
448494

449-
if i == 0 {
495+
if hostIdx == 0 {
450496
d := net.Dialer{
451497
Timeout: upstreamDialTimeout,
452498
}
@@ -457,14 +503,31 @@ func connectToUpstream(
457503
netConn, err = client.Dial("tcp", host.address)
458504
}
459505

506+
var hostType string
507+
if isTargetHost {
508+
hostType = "target"
509+
} else {
510+
hostType = "jump"
511+
}
512+
460513
if err != nil {
461-
return nil, nil, nil, fmt.Errorf("dial upstream %d %s: %w", i, host.address, err)
514+
return nil, nil, nil, fmt.Errorf("dial %s host #%d %s: %w", hostType, hostIdx, host.address, err)
462515
}
463516

464517
conn, chans, reqs, err = ssh.NewClientConn(netConn, host.address, config)
465518
if err != nil {
466-
return nil, nil, nil, fmt.Errorf("create SSH connection %d %s: %w", i, host.address, err)
519+
if isClientAuthFailureError(err) {
520+
err = &upstreamAuthFailureError{
521+
sshErr: err,
522+
isTargetHost: isTargetHost,
523+
}
524+
}
525+
526+
return nil, nil, nil, fmt.Errorf(
527+
"create SSH connection to %s host #%d %s: %w", hostType, hostIdx, host.address, err)
467528
}
529+
530+
logger.Tracef("connected to %s host #%d %s", hostType, hostIdx, host.address)
468531
}
469532

470533
return conn, chans, reqs, nil
@@ -611,3 +674,12 @@ func getSSHError(err error) error {
611674

612675
return nil
613676
}
677+
678+
func isClientAuthFailureError(err error) bool {
679+
sshErr := getSSHError(err)
680+
if sshErr == nil {
681+
return false
682+
}
683+
// https://github.com/golang/crypto/blob/982eaa62dfb7273603b97fc1835561450096f3bd/ssh/client_auth.go#L118
684+
return strings.Contains(sshErr.Error(), "unable to authenticate")
685+
}

0 commit comments

Comments
 (0)