diff --git a/internal/sshproxy/server.go b/internal/sshproxy/server.go index 9cde3f5..e52ce19 100644 --- a/internal/sshproxy/server.go +++ b/internal/sshproxy/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "slices" "strconv" "strings" "sync" @@ -28,6 +29,20 @@ const ( upstreamDialTimeout = time.Second * 10 ) +var blacklistedGlobalRequests = []string{ + // Host key update mechanism for SSH: https://www.ietf.org/archive/id/draft-miller-sshm-hostkey-update-02.html + // Reasons to blacklist: + // 1. Signature check always fail as the signed data contains session identifier, which is not the same on client + // and upstream side, since they don't talk directly but through sshproxy (there are two SSH transport sessions + // with their own unique identifiers). + // 2. Even if it worked somehow, we don't want to inflate user's known_hosts file with garbage records, + // since container host keys are ephemeral -- they are generated on dstack-runner startup (= unique for each job). + "hostkeys", + "hostkeys-00@openssh.com", + "hostkeys-prove", + "hostkeys-prove-00@openssh.com", +} + type direction string var ( @@ -386,15 +401,22 @@ func bridgeGlobalRequests(ctx context.Context, dir direction, inReqs <-chan *ssh logger := log.GetLogger(ctx).WithField("dir", dir) for req := range inReqs { logger := logger.WithField("type", req.Type) - logger.Trace("global request") - reply, payload, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload) - if req.WantReply { - _ = req.Reply(reply, payload) - } + if slices.Contains(blacklistedGlobalRequests, req.Type) { + logger.Trace("blacklisted global request, ignoring") + if req.WantReply { + _ = req.Reply(false, nil) + } + } else { + logger.Trace("global request") + ok, payload, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload) + if req.WantReply { + _ = req.Reply(ok, payload) + } - if err != nil && !isClosedError(err) { - logger.WithError(err).Error("failed to forward global request") + if err != nil && !isClosedError(err) { + logger.WithError(err).Error("failed to forward global request") + } } } } @@ -488,9 +510,9 @@ func bridgeChannelRequests(ctx context.Context, dir direction, inReqs <-chan *ss logger := logger.WithField("type", req.Type) logger.Trace("request") - reply, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload) + ok, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload) if req.WantReply { - _ = req.Reply(reply, nil) + _ = req.Reply(ok, nil) } if err != nil && !isClosedError(err) {