Skip to content

Commit dde66b5

Browse files
authored
[runner] Streamline authorized_keys management (#3435)
* Add keys by the runner, not the shell script
* Don't touch user's ~/.ssh/authorized_keys, use our own file
* Share that file between all users

Part-of: #3419
1 parent d472448 commit dde66b5

File tree

17 files changed

+191
-214
lines changed

17 files changed

+191
-214
lines changed

runner/cmd/runner/main.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"os"
@@ -30,6 +31,7 @@ func mainInner() int {
3031
var homeDir string
3132
var httpPort int
3233
var sshPort int
34+
var sshAuthorizedKeys []string
3335
var logLevel int
3436

3537
cmd := &cli.Command{
@@ -76,9 +78,14 @@ func mainInner() int {
7678
Value: consts.RunnerSSHPort,
7779
Destination: &sshPort,
7880
},
81+
&cli.StringSliceFlag{
82+
Name: "ssh-authorized-key",
83+
Usage: "dstack server or user authorized key. May be specified multiple times",
84+
Destination: &sshAuthorizedKeys,
85+
},
7986
},
80-
Action: func(cxt context.Context, cmd *cli.Command) error {
81-
return start(cxt, tempDir, homeDir, httpPort, sshPort, logLevel, Version)
87+
Action: func(ctx context.Context, cmd *cli.Command) error {
88+
return start(ctx, tempDir, homeDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version)
8289
},
8390
},
8491
},
@@ -95,7 +102,7 @@ func mainInner() int {
95102
return 0
96103
}
97104

98-
func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, logLevel int, version string) error {
105+
func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error {
99106
if err := os.MkdirAll(tempDir, 0o755); err != nil {
100107
return fmt.Errorf("create temp directory: %w", err)
101108
}
@@ -114,15 +121,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
114121
log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
115122
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))
116123

117-
server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshPort, version)
118-
if err != nil {
119-
return fmt.Errorf("create server: %w", err)
124+
// To ensure that all components of the authorized_keys path are owned by root and no directories
125+
// are group or world writable, as required by sshd with "StrictModes yes" (the default value),
126+
// we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created
127+
// in Sshd.Prepare())
128+
// See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272
129+
if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) {
130+
if err := os.Chown("/dstack", 0, 0); err != nil {
131+
return fmt.Errorf("chown dstack dir: %w", err)
132+
}
133+
if err := os.Chmod("/dstack", 0o755); err != nil {
134+
return fmt.Errorf("chmod dstack dir: %w", err)
135+
}
136+
} else if err != nil {
137+
return fmt.Errorf("create dstack dir: %w", err)
138+
}
139+
if err := os.RemoveAll("/dstack/ssh"); err != nil {
140+
return fmt.Errorf("remove dstack ssh dir: %w", err)
120141
}
121142

122143
sshd := ssh.NewSshd("/usr/sbin/sshd")
123-
if err := sshd.Prepare(ctx, "/dstack/ssh/conf", "/dstack/ssh/log", sshPort); err != nil {
144+
if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil {
124145
return fmt.Errorf("prepare sshd: %w", err)
125146
}
147+
if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil {
148+
return fmt.Errorf("add authorized keys: %w", err)
149+
}
126150
if err := sshd.Start(ctx); err != nil {
127151
return fmt.Errorf("start sshd: %w", err)
128152
}
@@ -132,6 +156,10 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
132156
}
133157
}()
134158

159+
server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version)
160+
if err != nil {
161+
return fmt.Errorf("create server: %w", err)
162+
}
135163
log.Trace(ctx, "Starting API server", "port", httpPort)
136164
if err := server.Run(ctx); err != nil {
137165
return fmt.Errorf("server failed: %w", err)

runner/cmd/shim/main.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,6 @@ func mainInner() int {
147147
Destination: &args.Docker.Privileged,
148148
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
149149
},
150-
&cli.StringFlag{
151-
Name: "ssh-key",
152-
Usage: "Public SSH key",
153-
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
154-
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
155-
},
156150
&cli.StringFlag{
157151
Name: "pjrt-device",
158152
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",

runner/internal/executor/executor.go

Lines changed: 15 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/dstackai/dstack/runner/internal/connections"
3030
"github.com/dstackai/dstack/runner/internal/log"
3131
"github.com/dstackai/dstack/runner/internal/schemas"
32+
"github.com/dstackai/dstack/runner/internal/ssh"
3233
"github.com/dstackai/dstack/runner/internal/types"
3334
)
3435

@@ -54,7 +55,8 @@ type RunExecutor struct {
5455
tempDir string
5556
homeDir string
5657
archiveDir string
57-
sshPort int
58+
sshd ssh.SshdManager
59+
5860
currentUid uint32
5961

6062
run schemas.Run
@@ -89,7 +91,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
8991
func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {}
9092
func (s *stubConnectionTracker) Stop() {}
9193

92-
func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor, error) {
94+
func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
9395
mu := &sync.RWMutex{}
9496
timestamp := NewMonotonicTimestamp()
9597
user, err := osuser.Current()
@@ -110,7 +112,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
110112
return nil, fmt.Errorf("initialize procfs: %w", err)
111113
}
112114
connectionTracker = connections.NewConnectionTracker(connections.ConnectionTrackerConfig{
113-
Port: uint64(sshPort),
115+
Port: uint64(sshd.Port()),
114116
MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server
115117
Procfs: proc,
116118
})
@@ -123,7 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
123125
tempDir: tempDir,
124126
homeDir: homeDir,
125127
archiveDir: filepath.Join(tempDir, "file_archives"),
126-
sshPort: sshPort,
128+
sshd: sshd,
127129
currentUid: uid,
128130
jobUid: -1,
129131
jobGid: -1,
@@ -466,8 +468,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
466468
}
467469

468470
// As of 2024-11-29, ex.homeDir is always set to /root
469-
rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir)
470-
if err != nil {
471+
if _, err := prepareSSHDir(-1, -1, ex.homeDir); err != nil {
471472
log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err)
472473
}
473474
userSSHDir := ""
@@ -484,14 +485,6 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
484485
userSSHDir, err = prepareSSHDir(uid, gid, homeDir)
485486
if err != nil {
486487
log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
487-
} else {
488-
rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys")
489-
userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys")
490-
restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath)
491-
defer restoreUserSSHKeys(ctx)
492-
if err := copyAuthorizedKeys(rootSSHKeysPath, uid, gid, userSSHKeysPath); err != nil {
493-
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
494-
}
495488
}
496489
} else {
497490
log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir)
@@ -504,9 +497,12 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
504497

505498
if ex.jobSpec.SSHKey != nil && userSSHDir != "" {
506499
err := configureSSH(
507-
ex.jobSpec.SSHKey.Private, ex.jobSpec.SSHKey.Public, ex.clusterInfo.JobIPs, ex.sshPort,
500+
ex.jobSpec.SSHKey.Private, ex.clusterInfo.JobIPs, ex.sshd.Port(),
508501
uid, gid, userSSHDir,
509502
)
503+
if err == nil {
504+
err = ex.sshd.AddAuthorizedKeys(ctx, ex.jobSpec.SSHKey.Public)
505+
}
510506
if err != nil {
511507
log.Warning(ctx, "failed to configure SSH", "err", err)
512508
}
@@ -914,7 +910,7 @@ func includeDstackProfile(profilePath string, dstackProfilePath string) error {
914910
return nil
915911
}
916912

917-
func configureSSH(private string, public string, ips []string, port int, uid int, gid int, sshDir string) error {
913+
func configureSSH(private string, ips []string, port int, uid int, gid int, sshDir string) error {
918914
privatePath := filepath.Join(sshDir, "dstack_job")
919915
privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600)
920916
if err != nil {
@@ -928,19 +924,9 @@ func configureSSH(private string, public string, ips []string, port int, uid int
928924
return fmt.Errorf("write private key: %w", err)
929925
}
930926

931-
akPath := filepath.Join(sshDir, "authorized_keys")
932-
akFile, err := os.OpenFile(akPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
933-
if err != nil {
934-
return fmt.Errorf("open authorized_keys: %w", err)
935-
}
936-
defer akFile.Close()
937-
if err := os.Chown(akPath, uid, gid); err != nil {
938-
return fmt.Errorf("chown authorized_keys: %w", err)
939-
}
940-
if _, err := akFile.WriteString(public); err != nil {
941-
return fmt.Errorf("write public key: %w", err)
942-
}
943-
927+
// TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job
928+
// and add "Include ~/.dstack/ssh/config.d/*" directive to ~/.ssh/config if not present
929+
// instead of appending job hosts config directly (don't bloat user's ssh_config)
944930
configPath := filepath.Join(sshDir, "config")
945931
configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
946932
if err != nil {
@@ -963,104 +949,3 @@ func configureSSH(private string, public string, ips []string, port int, uid int
963949
}
964950
return nil
965951
}
966-
967-
// A makeshift solution to deliver authorized_keys to a non-root user
968-
// without modifying the existing API/bootstrap process
969-
// TODO: implement key delivery properly, i.e. submit keys to the runner and have the runner write them,
970-
// not the outer sh script that launches sshd and runner
971-
func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) error {
972-
srcFile, err := os.Open(srcPath)
973-
if err != nil {
974-
return fmt.Errorf("open source authorized_keys: %w", err)
975-
}
976-
defer srcFile.Close()
977-
978-
dstExists := false
979-
info, err := os.Stat(dstPath)
980-
if err == nil {
981-
dstExists = true
982-
if info.IsDir() {
983-
return fmt.Errorf("is a directory: %s", dstPath)
984-
}
985-
if err = os.Chmod(dstPath, 0o600); err != nil {
986-
return fmt.Errorf("chmod destination authorized_keys: %w", err)
987-
}
988-
} else if !errors.Is(err, os.ErrNotExist) {
989-
return fmt.Errorf("stat destination authorized_keys: %w", err)
990-
}
991-
992-
dstFile, err := os.OpenFile(dstPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
993-
if err != nil {
994-
return fmt.Errorf("open destination authorized_keys: %w", err)
995-
}
996-
defer dstFile.Close()
997-
998-
if dstExists {
999-
// visually separate our keys from existing ones
1000-
if _, err := dstFile.WriteString("\n\n"); err != nil {
1001-
return fmt.Errorf("write separator to authorized_keys: %w", err)
1002-
}
1003-
}
1004-
if _, err := io.Copy(dstFile, srcFile); err != nil {
1005-
return fmt.Errorf("copy authorized_keys: %w", err)
1006-
}
1007-
if err = os.Chown(dstPath, uid, gid); err != nil {
1008-
return fmt.Errorf("chown destination authorized_keys: %w", err)
1009-
}
1010-
1011-
return nil
1012-
}
1013-
1014-
// backupFile renames `/path/to/file` to `/path/to/file.dstack.bak`,
1015-
// creates a new file with the same content, and returns restore function that
1016-
// renames the backup back to the original name.
1017-
// If the original file does not exist, restore function removes the file if it is created.
1018-
// NB: A newly created file has default uid:gid and permissions, probably not
1019-
// the same as the original file.
1020-
func backupFile(ctx context.Context, path string) func(context.Context) {
1021-
var existed bool
1022-
backupPath := path + ".dstack.bak"
1023-
1024-
restoreFunc := func(ctx context.Context) {
1025-
if !existed {
1026-
err := os.Remove(path)
1027-
if err != nil && !errors.Is(err, os.ErrNotExist) {
1028-
log.Error(ctx, "failed to remove", "path", path, "err", err)
1029-
}
1030-
return
1031-
}
1032-
err := os.Rename(backupPath, path)
1033-
if err != nil && !errors.Is(err, os.ErrNotExist) {
1034-
log.Error(ctx, "failed to restore", "path", path, "err", err)
1035-
}
1036-
}
1037-
1038-
err := os.Rename(path, backupPath)
1039-
if errors.Is(err, os.ErrNotExist) {
1040-
existed = false
1041-
return restoreFunc
1042-
}
1043-
existed = true
1044-
if err != nil {
1045-
log.Error(ctx, "failed to back up", "path", path, "err", err)
1046-
return restoreFunc
1047-
}
1048-
1049-
src, err := os.Open(backupPath)
1050-
if err != nil {
1051-
log.Error(ctx, "failed to open backup src", "path", backupPath, "err", err)
1052-
return restoreFunc
1053-
}
1054-
defer src.Close()
1055-
dst, err := os.Create(path)
1056-
if err != nil {
1057-
log.Error(ctx, "failed to open backup dest", "path", path, "err", err)
1058-
return restoreFunc
1059-
}
1060-
defer dst.Close()
1061-
_, err = io.Copy(dst, src)
1062-
if err != nil {
1063-
log.Error(ctx, "failed to copy backup", "path", backupPath, "err", err)
1064-
}
1065-
return restoreFunc
1066-
}

runner/internal/executor/executor_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
208208
_ = os.Mkdir(temp, 0o700)
209209
home := filepath.Join(baseDir, "home")
210210
_ = os.Mkdir(home, 0o700)
211-
ex, _ := NewRunExecutor(temp, home, 10022)
211+
ex, _ := NewRunExecutor(temp, home, new(sshdMock))
212212
ex.SetJob(body)
213213
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
214214
ex.setJobWorkingDir(context.Background())
@@ -341,6 +341,24 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {
341341
}
342342
}
343343

344+
type sshdMock struct{}
345+
346+
func (d *sshdMock) Port() int {
347+
return 0
348+
}
349+
350+
func (d *sshdMock) Start(context.Context) error {
351+
return nil
352+
}
353+
354+
func (d *sshdMock) Stop(context.Context) error {
355+
return nil
356+
}
357+
358+
func (d *sshdMock) AddAuthorizedKeys(context.Context, ...string) error {
359+
return nil
360+
}
361+
344362
func combineLogMessages(logHistory []schemas.LogEvent) string {
345363
var logOutput bytes.Buffer
346364
for _, logEvent := range logHistory {

runner/internal/runner/api/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/dstackai/dstack/runner/internal/executor"
1212
"github.com/dstackai/dstack/runner/internal/log"
1313
"github.com/dstackai/dstack/runner/internal/metrics"
14+
"github.com/dstackai/dstack/runner/internal/ssh"
1415
)
1516

1617
type Server struct {
@@ -33,9 +34,9 @@ type Server struct {
3334
version string
3435
}
3536

36-
func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshPort int, version string) (*Server, error) {
37+
func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) {
3738
r := api.NewRouter()
38-
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshPort)
39+
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd)
3940
if err != nil {
4041
return nil, err
4142
}

0 commit comments

Comments
 (0)