Skip to content

Commit b23fae8

Browse files
authored
[runner] Rework env variables exporting (#2535)
Replace `~/.ssh/environment` (processed by OpenSSH) with a shell script exporting env variables. The shell script is sourced in `/etc/profile`, which is only executed by login shell. Advantages: * multi-line variables support (e.g, `DSTACK_NODES_IPS`) * no hard limit of 1000 variables * shared by all users * vars are not redefined by `pam_env` (`/etc/environment`) * it's unlikely that `/tmp/dstack_profile` and `/etc/profile` are located on a volume, unlike `~/.ssh`; no need for cleanup Tested supported shells: * dash (Debian and derivatives, BusyBox distros, e.g., Alpine) * bash as sh (Fedora and downstream distros) * bash (almost every distro) Not supported shells (do not use `/etc/profile`): * zsh * fish Fixes: #2371
1 parent 95c723d commit b23fae8

5 files changed

Lines changed: 51 additions & 57 deletions

File tree

runner/internal/executor/executor.go

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,18 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
310310
// `env` interpolation feature is postponed to some future release
311311
envMap.Update(ex.jobSpec.Env, false)
312312

313+
const profilePath = "/etc/profile"
314+
const dstackProfilePath = "/tmp/dstack_profile"
315+
if err := writeDstackProfile(envMap, dstackProfilePath); err != nil {
316+
log.Warning(ctx, "failed to write dstack_profile", "path", dstackProfilePath, "err", err)
317+
} else if err := includeDstackProfile(profilePath, dstackProfilePath); err != nil {
318+
log.Warning(ctx, "failed to include dstack_profile", "path", profilePath, "err", err)
319+
}
320+
313321
// As of 2024-11-29, ex.homeDir is always set to /root
314322
rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir)
315323
if err != nil {
316324
log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err)
317-
} else {
318-
rootSSHEnvPath := filepath.Join(rootSSHDir, "environment")
319-
restoreRootSSHEnv := backupFile(ctx, rootSSHEnvPath)
320-
defer restoreRootSSHEnv(ctx)
321-
if err := writeSSHEnvironment(envMap, -1, -1, rootSSHEnvPath); err != nil {
322-
log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
323-
}
324325
}
325326
userSSHDir := ""
326327
uid := -1
@@ -337,12 +338,6 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
337338
if err != nil {
338339
log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
339340
} else {
340-
userSSHEnvPath := filepath.Join(userSSHDir, "environment")
341-
restoreUserSSHEnv := backupFile(ctx, userSSHEnvPath)
342-
defer restoreUserSSHEnv(ctx)
343-
if err := writeSSHEnvironment(envMap, uid, gid, userSSHEnvPath); err != nil {
344-
log.Warning(ctx, "failed to write SSH environment", "path", homeDir, "err", err)
345-
}
346341
rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys")
347342
userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys")
348343
restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath)
@@ -676,53 +671,40 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
676671
return sshDir, nil
677672
}
678673

679-
func writeSSHEnvironment(env map[string]string, uid int, gid int, envPath string) error {
680-
info, err := os.Stat(envPath)
681-
if err == nil {
682-
if info.IsDir() {
683-
return fmt.Errorf("is a directory: %s", envPath)
684-
}
685-
if err = os.Chmod(envPath, 0o600); err != nil {
686-
return err
687-
}
688-
} else if !errors.Is(err, os.ErrNotExist) {
689-
return err
690-
}
691-
692-
envFile, err := os.OpenFile(envPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
674+
func writeDstackProfile(env map[string]string, path string) error {
675+
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
693676
if err != nil {
694677
return err
695678
}
696-
defer envFile.Close()
679+
defer file.Close()
697680
for key, value := range env {
698681
switch key {
699-
case "USER", "HOME", "SHELL", "PWD", "_":
700-
continue
701-
}
702-
// sshd doesn't support multiline variable values in .ssh/environment:
703-
// VAR1=line1
704-
// line2
705-
// line3
706-
// VAR2=singleline
707-
// leads to:
708-
// Bad line 2 in /root/.ssh/environment
709-
// Bad line 3 in /root/.ssh/environment
710-
// Assuming a single trailing newline is not a big deal
711-
value := strings.TrimSuffix(value, "\n")
712-
// If there is any non-trailing newline, or more than one
713-
// trailing newline, skip the variable
714-
if strings.Contains(value, "\n") {
682+
case "HOSTNAME", "USER", "HOME", "SHELL", "SHLVL", "PWD", "_":
715683
continue
716684
}
717-
line := fmt.Sprintf("%s=%s\n", key, value)
718-
if _, err = envFile.WriteString(line); err != nil {
685+
line := fmt.Sprintf("export %s='%s'\n", key, strings.ReplaceAll(value, `'`, `'"'"'`))
686+
if _, err = file.WriteString(line); err != nil {
719687
return err
720688
}
721689
}
722-
if err = os.Chown(envPath, uid, gid); err != nil {
690+
if err = os.Chmod(path, 0o644); err != nil {
723691
return err
724692
}
693+
return nil
694+
}
725695

696+
func includeDstackProfile(profilePath string, dstackProfilePath string) error {
697+
file, err := os.OpenFile(profilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
698+
if err != nil {
699+
return err
700+
}
701+
defer file.Close()
702+
if _, err = file.WriteString(fmt.Sprintf("\n. '%s'\n", dstackProfilePath)); err != nil {
703+
return err
704+
}
705+
if err = os.Chmod(profilePath, 0o644); err != nil {
706+
return err
707+
}
726708
return nil
727709
}
728710

runner/internal/executor/executor_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"os"
10+
"os/exec"
1011
"path/filepath"
1112
"testing"
1213
"time"
@@ -207,3 +208,22 @@ func makeCodeTar(t *testing.T, path string) {
207208
}
208209
require.NoError(t, tw.Close())
209210
}
211+
212+
func TestWriteDstackProfile(t *testing.T) {
213+
testCases := []string{
214+
"",
215+
"string 'with 'single' quotes",
216+
"multi\nline\tstring",
217+
}
218+
tmp := t.TempDir()
219+
path := tmp + "/dstack_profile"
220+
script := fmt.Sprintf(`. '%s'; printf '%%s' "$VAR"`, path)
221+
for _, value := range testCases {
222+
env := map[string]string{"VAR": value}
223+
writeDstackProfile(env, path)
224+
cmd := exec.Command("/bin/sh", "-c", script)
225+
out, err := cmd.Output()
226+
assert.NoError(t, err)
227+
assert.Equal(t, value, string(out))
228+
}
229+
}

runner/internal/shim/docker.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,6 @@ func getSSHShellCommands(openSSHPort int, publicSSHKey string) []string {
896896
"chmod 700 ~/.ssh",
897897
fmt.Sprintf("echo '%s' > ~/.ssh/authorized_keys", publicSSHKey),
898898
"chmod 600 ~/.ssh/authorized_keys",
899-
`if [ -f ~/.profile ]; then sed -ie '1s@^@export PATH="'"$PATH"':$PATH"\n\n@' ~/.profile; fi`,
900899
// regenerate host keys
901900
"rm -rf /etc/ssh/ssh_host_*",
902901
"ssh-keygen -A > /dev/null",
@@ -914,7 +913,6 @@ func getSSHShellCommands(openSSHPort int, publicSSHKey string) []string {
914913
" -o PidFile=none"+
915914
" -o PasswordAuthentication=no"+
916915
" -o AllowTcpForwarding=yes"+
917-
" -o PermitUserEnvironment=yes"+
918916
" -o ClientAliveInterval=30"+
919917
" -o ClientAliveCountMax=4",
920918
openSSHPort,

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
555555
)
556556

557557

558-
def get_docker_commands(
559-
authorized_keys: List[str], fix_path_in_dot_profile: bool = True
560-
) -> List[str]:
558+
def get_docker_commands(authorized_keys: list[str]) -> list[str]:
561559
authorized_keys_content = "\n".join(authorized_keys).strip()
562560
commands = [
563561
# save and unset ld.so variables
@@ -581,9 +579,6 @@ def get_docker_commands(
581579
"chmod 700 ~/.ssh",
582580
f"echo '{authorized_keys_content}' > ~/.ssh/authorized_keys",
583581
"chmod 600 ~/.ssh/authorized_keys",
584-
r"""if [ -f ~/.profile ]; then sed -ie '1s@^@export PATH="'"$PATH"':$PATH"\n\n@' ~/.profile; fi"""
585-
if fix_path_in_dot_profile
586-
else ":",
587582
# regenerate host keys
588583
"rm -rf /etc/ssh/ssh_host_*",
589584
"ssh-keygen -A > /dev/null",
@@ -601,7 +596,6 @@ def get_docker_commands(
601596
" -o PidFile=none"
602597
" -o PasswordAuthentication=no"
603598
" -o AllowTcpForwarding=yes"
604-
" -o PermitUserEnvironment=yes"
605599
" -o ClientAliveInterval=30"
606600
" -o ClientAliveCountMax=4"
607601
),

src/dstack/_internal/core/backends/runpod/compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def _clean_stale_container_registry_auths(self) -> None:
260260

261261

262262
def _get_docker_args(authorized_keys: List[str]) -> str:
263-
commands = get_docker_commands(authorized_keys, False)
263+
commands = get_docker_commands(authorized_keys)
264264
command = " && ".join(commands)
265265
docker_args = {"cmd": [command], "entrypoint": ["/bin/sh", "-c"]}
266266
docker_args_escaped = json.dumps(json.dumps(docker_args)).strip('"')

0 commit comments

Comments
 (0)