@@ -126,6 +126,21 @@ var (
126126 errUnknownPublicKey = errors .New ("unknown public key" )
127127)
128128
129+ // upstreamAuthFailureError represents an SSH client auth failure (that is, SSH_MSG_USERAUTH_FAILURE) when connecting
130+ // to any host in the Upstream.hosts chain (either a jump host or a target host)
131+ type upstreamAuthFailureError struct {
132+ sshErr error
133+ isTargetHost bool
134+ }
135+
136+ func (e * upstreamAuthFailureError ) Error () string {
137+ return fmt .Sprintf ("auth failure: %s" , e .sshErr .Error ())
138+ }
139+
140+ func (e * upstreamAuthFailureError ) Unwrap () error {
141+ return e .sshErr
142+ }
143+
129144type Server struct {
130145 address string
131146
@@ -246,7 +261,7 @@ func (s *Server) Close(ctx context.Context) error {
246261}
247262
248263func (s * Server ) publicKeyCallback (conn ssh.ConnMetadata , publicKey ssh.PublicKey ) (* ssh.Permissions , error ) {
249- upstreamID := conn .User ()
264+ upstreamID , _ := parseAuthUser ( conn .User () )
250265 logger := log .GetLogger (s .serveCtx ).WithField ("id" , upstreamID )
251266
252267 upstream , found := s .upstreamCache .Get (upstreamID )
@@ -312,9 +327,22 @@ func handleConnection(ctx context.Context, conn net.Conn, config *ssh.ServerConf
312327 logger .Debug ("client logged in" )
313328
314329 upstream := clientConn .Permissions .ExtraData [upstreamExtraDataKey ].(Upstream )
315- upstreamConn , upstreamNewChans , upstreamReqs , err := connectToUpstream (ctx , upstream )
330+ _ , user := parseAuthUser (clientConn .User ())
331+ upstreamConn , upstreamNewChans , upstreamReqs , err := connectToUpstream (ctx , upstream , user )
316332 if err != nil {
317- logger .WithError (err ).Error ("failed to connect to upstream" )
333+ logger = logger .WithError (err )
334+ if user != "" {
335+ logger = logger .WithField ("user" , user )
336+ }
337+
338+ const msg = "failed to connect to upstream"
339+ // Don't log as an error if it is a client auth error on the last host in the chain and the user is overridden
340+ // to avoid log noise in case a non-existent user is requested
341+ if authErr , ok := errors.AsType [* upstreamAuthFailureError ](err ); ok && authErr .isTargetHost && user != "" {
342+ logger .Debug (msg )
343+ } else {
344+ logger .Error (msg )
345+ }
318346
319347 return
320348 }
@@ -420,22 +448,40 @@ func handleConnectionError(ctx context.Context, err error) {
420448 logger .WithError (err ).Error ("failed to handshake client" )
421449}
422450
451+ // parseAuthUser extracts upstreamID and optional upstreamUser (overrides the default upstream user)
452+ // from the "user name" field of the SSH_MSG_USERAUTH_REQUEST request (the `user` in the `ssh user@hostname` command)
453+ // The optional user is appended to the upstreamID after the `_` delimiter:
454+ // 3b07781fc52d4427b3f4e83f16abb104@ssh.dstack.example.com - log in as the default job user
455+ // 3b07781fc52d4427b3f4e83f16abb104_root@ssh.dstack.example.com - log in as `root`
456+ func parseAuthUser (user string ) (upstreamID string , upstreamUser string ) {
457+ upstreamID , upstreamUser , _ = strings .Cut (user , "_" )
458+ return upstreamID , upstreamUser
459+ }
460+
423461func connectToUpstream (
424- ctx context.Context ,
425- upstream Upstream ,
462+ ctx context.Context , upstream Upstream , user string ,
426463) (ssh.Conn , <- chan ssh.NewChannel , <- chan * ssh.Request , error ) {
464+ logger := log .GetLogger (ctx )
465+
427466 var conn ssh.Conn
428467 var chans <- chan ssh.NewChannel
429468 var reqs <- chan * ssh.Request
430469
431- for i , host := range upstream .hosts {
470+ // A target host is the last host in the Upstream.hosts chanin. All other hosts are jump hosts.
471+ targetHostIdx := len (upstream .hosts ) - 1
472+ for hostIdx , host := range upstream .hosts {
473+ isTargetHost := hostIdx == targetHostIdx
474+ hostUser := host .user
475+ if isTargetHost && user != "" {
476+ hostUser = user
477+ }
432478 config := & ssh.ClientConfig {
433479 Config : ssh.Config {
434480 KeyExchanges : allowedKeyExchanges ,
435481 Ciphers : allowedCiphers ,
436482 MACs : allowedMACs ,
437483 },
438- User : host . user ,
484+ User : hostUser ,
439485 Auth : []ssh.AuthMethod {
440486 ssh .PublicKeys (host .privateKey ),
441487 },
@@ -446,7 +492,7 @@ func connectToUpstream(
446492 var netConn net.Conn
447493 var err error
448494
449- if i == 0 {
495+ if hostIdx == 0 {
450496 d := net.Dialer {
451497 Timeout : upstreamDialTimeout ,
452498 }
@@ -457,14 +503,31 @@ func connectToUpstream(
457503 netConn , err = client .Dial ("tcp" , host .address )
458504 }
459505
506+ var hostType string
507+ if isTargetHost {
508+ hostType = "target"
509+ } else {
510+ hostType = "jump"
511+ }
512+
460513 if err != nil {
461- return nil , nil , nil , fmt .Errorf ("dial upstream %d %s: %w" , i , host .address , err )
514+ return nil , nil , nil , fmt .Errorf ("dial %s host # %d %s: %w" , hostType , hostIdx , host .address , err )
462515 }
463516
464517 conn , chans , reqs , err = ssh .NewClientConn (netConn , host .address , config )
465518 if err != nil {
466- return nil , nil , nil , fmt .Errorf ("create SSH connection %d %s: %w" , i , host .address , err )
519+ if isClientAuthFailureError (err ) {
520+ err = & upstreamAuthFailureError {
521+ sshErr : err ,
522+ isTargetHost : isTargetHost ,
523+ }
524+ }
525+
526+ return nil , nil , nil , fmt .Errorf (
527+ "create SSH connection to %s host #%d %s: %w" , hostType , hostIdx , host .address , err )
467528 }
529+
530+ logger .Tracef ("connected to %s host #%d %s" , hostType , hostIdx , host .address )
468531 }
469532
470533 return conn , chans , reqs , nil
@@ -611,3 +674,12 @@ func getSSHError(err error) error {
611674
612675 return nil
613676}
677+
678+ func isClientAuthFailureError (err error ) bool {
679+ sshErr := getSSHError (err )
680+ if sshErr == nil {
681+ return false
682+ }
683+ // https://github.com/golang/crypto/blob/982eaa62dfb7273603b97fc1835561450096f3bd/ssh/client_auth.go#L118
684+ return strings .Contains (sshErr .Error (), "unable to authenticate" )
685+ }
0 commit comments