@@ -29,6 +29,7 @@ import (
2929 "github.com/dstackai/dstack/runner/internal/connections"
3030 "github.com/dstackai/dstack/runner/internal/log"
3131 "github.com/dstackai/dstack/runner/internal/schemas"
32+ "github.com/dstackai/dstack/runner/internal/ssh"
3233 "github.com/dstackai/dstack/runner/internal/types"
3334)
3435
@@ -54,7 +55,8 @@ type RunExecutor struct {
5455 tempDir string
5556 homeDir string
5657 archiveDir string
57- sshPort int
58+ sshd ssh.SshdManager
59+
5860 currentUid uint32
5961
6062 run schemas.Run
@@ -89,7 +91,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
// Track is a no-op: the stub tracker ignores the ticker and records nothing.
8991func (s * stubConnectionTracker ) Track (ticker <- chan time.Time ) {}
// Stop is a no-op: the stub tracker has nothing to shut down.
9092func (s * stubConnectionTracker ) Stop () {}
9193
92- func NewRunExecutor (tempDir string , homeDir string , sshPort int ) (* RunExecutor , error ) {
94+ func NewRunExecutor (tempDir string , homeDir string , sshd ssh. SshdManager ) (* RunExecutor , error ) {
9395 mu := & sync.RWMutex {}
9496 timestamp := NewMonotonicTimestamp ()
9597 user , err := osuser .Current ()
@@ -110,7 +112,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
110112 return nil , fmt .Errorf ("initialize procfs: %w" , err )
111113 }
112114 connectionTracker = connections .NewConnectionTracker (connections.ConnectionTrackerConfig {
113- Port : uint64 (sshPort ),
115+ Port : uint64 (sshd . Port () ),
114116 MinConnDuration : 10 * time .Second , // shorter connections are likely from dstack-server
115117 Procfs : proc ,
116118 })
@@ -123,7 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
123125 tempDir : tempDir ,
124126 homeDir : homeDir ,
125127 archiveDir : filepath .Join (tempDir , "file_archives" ),
126- sshPort : sshPort ,
128+ sshd : sshd ,
127129 currentUid : uid ,
128130 jobUid : - 1 ,
129131 jobGid : - 1 ,
@@ -466,8 +468,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
466468 }
467469
468470 // As of 2024-11-29, ex.homeDir is always set to /root
469- rootSSHDir , err := prepareSSHDir (- 1 , - 1 , ex .homeDir )
470- if err != nil {
471+ if _ , err := prepareSSHDir (- 1 , - 1 , ex .homeDir ); err != nil {
471472 log .Warning (ctx , "failed to prepare ssh dir" , "home" , ex .homeDir , "err" , err )
472473 }
473474 userSSHDir := ""
@@ -484,14 +485,6 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
484485 userSSHDir , err = prepareSSHDir (uid , gid , homeDir )
485486 if err != nil {
486487 log .Warning (ctx , "failed to prepare ssh dir" , "home" , homeDir , "err" , err )
487- } else {
488- rootSSHKeysPath := filepath .Join (rootSSHDir , "authorized_keys" )
489- userSSHKeysPath := filepath .Join (userSSHDir , "authorized_keys" )
490- restoreUserSSHKeys := backupFile (ctx , userSSHKeysPath )
491- defer restoreUserSSHKeys (ctx )
492- if err := copyAuthorizedKeys (rootSSHKeysPath , uid , gid , userSSHKeysPath ); err != nil {
493- log .Warning (ctx , "failed to copy authorized keys" , "path" , homeDir , "err" , err )
494- }
495488 }
496489 } else {
497490 log .Trace (ctx , "homeDir is not accessible, skipping provisioning" , "path" , homeDir )
@@ -504,9 +497,12 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
504497
505498 if ex .jobSpec .SSHKey != nil && userSSHDir != "" {
506499 err := configureSSH (
507- ex .jobSpec .SSHKey .Private , ex .jobSpec . SSHKey . Public , ex . clusterInfo .JobIPs , ex .sshPort ,
500+ ex .jobSpec .SSHKey .Private , ex .clusterInfo .JobIPs , ex .sshd . Port () ,
508501 uid , gid , userSSHDir ,
509502 )
503+ if err == nil {
504+ err = ex .sshd .AddAuthorizedKeys (ctx , ex .jobSpec .SSHKey .Public )
505+ }
510506 if err != nil {
511507 log .Warning (ctx , "failed to configure SSH" , "err" , err )
512508 }
@@ -914,7 +910,7 @@ func includeDstackProfile(profilePath string, dstackProfilePath string) error {
914910 return nil
915911}
916912
917- func configureSSH (private string , public string , ips []string , port int , uid int , gid int , sshDir string ) error {
913+ func configureSSH (private string , ips []string , port int , uid int , gid int , sshDir string ) error {
918914 privatePath := filepath .Join (sshDir , "dstack_job" )
919915 privateFile , err := os .OpenFile (privatePath , os .O_TRUNC | os .O_WRONLY | os .O_CREATE , 0o600 )
920916 if err != nil {
@@ -928,19 +924,9 @@ func configureSSH(private string, public string, ips []string, port int, uid int
928924 return fmt .Errorf ("write private key: %w" , err )
929925 }
930926
931- akPath := filepath .Join (sshDir , "authorized_keys" )
932- akFile , err := os .OpenFile (akPath , os .O_APPEND | os .O_WRONLY | os .O_CREATE , 0o600 )
933- if err != nil {
934- return fmt .Errorf ("open authorized_keys: %w" , err )
935- }
936- defer akFile .Close ()
937- if err := os .Chown (akPath , uid , gid ); err != nil {
938- return fmt .Errorf ("chown authorized_keys: %w" , err )
939- }
940- if _ , err := akFile .WriteString (public ); err != nil {
941- return fmt .Errorf ("write public key: %w" , err )
942- }
943-
927+ // TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job
928+ // and add "Include ~/.dstack/ssh/config.d/*" directive to ~/.ssh/config if not present
929+ // instead of appending job hosts config directly (don't bloat user's ssh_config)
944930 configPath := filepath .Join (sshDir , "config" )
945931 configFile , err := os .OpenFile (configPath , os .O_APPEND | os .O_WRONLY | os .O_CREATE , 0o600 )
946932 if err != nil {
@@ -963,104 +949,3 @@ func configureSSH(private string, public string, ips []string, port int, uid int
963949 }
964950 return nil
965951}
966-
// copyAuthorizedKeys appends the content of the authorized_keys file at srcPath
// to the file at dstPath and makes uid:gid the owner of dstPath.
//
// If dstPath does not exist, it is created with mode 0600. If it already exists,
// its mode is reset to 0600 and a blank line is written before the copied keys
// to visually separate them from the existing ones. An error is returned if
// dstPath is a directory or if any filesystem operation fails, including
// closing the destination after writing.
//
// A makeshift solution to deliver authorized_keys to a non-root user
// without modifying the existing API/bootstrap process
// TODO: implement key delivery properly, i.e. submit keys to and write by the runner,
// not the outer sh script that launches sshd and runner
func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) (err error) {
	srcFile, err := os.Open(srcPath)
	if err != nil {
		return fmt.Errorf("open source authorized_keys: %w", err)
	}
	defer srcFile.Close()

	dstExists := false
	info, err := os.Stat(dstPath)
	switch {
	case err == nil:
		if info.IsDir() {
			return fmt.Errorf("is a directory: %s", dstPath)
		}
		if err = os.Chmod(dstPath, 0o600); err != nil {
			return fmt.Errorf("chmod destination authorized_keys: %w", err)
		}
		dstExists = true
	case !errors.Is(err, os.ErrNotExist):
		return fmt.Errorf("stat destination authorized_keys: %w", err)
	}

	dstFile, err := os.OpenFile(dstPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
	if err != nil {
		return fmt.Errorf("open destination authorized_keys: %w", err)
	}
	// The file has been written to, so a failed close may mean lost data:
	// report it unless an earlier error is already being returned.
	defer func() {
		if closeErr := dstFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("close destination authorized_keys: %w", closeErr)
		}
	}()

	if dstExists {
		// visually separate our keys from existing ones
		if _, err := dstFile.WriteString("\n\n"); err != nil {
			return fmt.Errorf("write separator to authorized_keys: %w", err)
		}
	}
	if _, err := io.Copy(dstFile, srcFile); err != nil {
		return fmt.Errorf("copy authorized_keys: %w", err)
	}
	// Change ownership through the open handle so that the file that was
	// just written is the one being chowned.
	if err := dstFile.Chown(uid, gid); err != nil {
		return fmt.Errorf("chown destination authorized_keys: %w", err)
	}

	return nil
}
1013-
1014- // backupFile renames `/path/to/file` to `/path/to/file.dstack.bak`,
1015- // creates a new file with the same content, and returns restore function that
1016- // renames the backup back to the original name.
1017- // If the original file does not exist, restore function removes the file if it is created.
1018- // NB: A newly created file has default uid:gid and permissions, probably not
1019- // the same as the original file.
1020- func backupFile (ctx context.Context , path string ) func (context.Context ) {
1021- var existed bool
1022- backupPath := path + ".dstack.bak"
1023-
1024- restoreFunc := func (ctx context.Context ) {
1025- if ! existed {
1026- err := os .Remove (path )
1027- if err != nil && ! errors .Is (err , os .ErrNotExist ) {
1028- log .Error (ctx , "failed to remove" , "path" , path , "err" , err )
1029- }
1030- return
1031- }
1032- err := os .Rename (backupPath , path )
1033- if err != nil && ! errors .Is (err , os .ErrNotExist ) {
1034- log .Error (ctx , "failed to restore" , "path" , path , "err" , err )
1035- }
1036- }
1037-
1038- err := os .Rename (path , backupPath )
1039- if errors .Is (err , os .ErrNotExist ) {
1040- existed = false
1041- return restoreFunc
1042- }
1043- existed = true
1044- if err != nil {
1045- log .Error (ctx , "failed to back up" , "path" , path , "err" , err )
1046- return restoreFunc
1047- }
1048-
1049- src , err := os .Open (backupPath )
1050- if err != nil {
1051- log .Error (ctx , "failed to open backup src" , "path" , backupPath , "err" , err )
1052- return restoreFunc
1053- }
1054- defer src .Close ()
1055- dst , err := os .Create (path )
1056- if err != nil {
1057- log .Error (ctx , "failed to open backup dest" , "path" , path , "err" , err )
1058- return restoreFunc
1059- }
1060- defer dst .Close ()
1061- _ , err = io .Copy (dst , src )
1062- if err != nil {
1063- log .Error (ctx , "failed to copy backup" , "path" , backupPath , "err" , err )
1064- }
1065- return restoreFunc
1066- }
0 commit comments