@@ -257,6 +257,8 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
257257 gpus_per_node_num := ex .clusterInfo .GPUSPerJob
258258 gpus_num := nodes_num * gpus_per_node_num
259259
260+ mpiHostfilePath := filepath .Join (ex .homeDir , ".dstack/mpi/hostfile" )
261+
260262 jobEnvs := map [string ]string {
261263 "DSTACK_RUN_ID" : ex .run .Id ,
262264 "DSTACK_JOB_ID" : ex .jobSubmission .Id ,
@@ -268,6 +270,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
268270 "DSTACK_NODES_NUM" : strconv .Itoa (nodes_num ),
269271 "DSTACK_GPUS_PER_NODE" : strconv .Itoa (gpus_per_node_num ),
270272 "DSTACK_GPUS_NUM" : strconv .Itoa (gpus_num ),
273+ "DSTACK_MPI_HOSTFILE" : mpiHostfilePath ,
271274 }
272275
273276 // Call buildLDLibraryPathEnv and update jobEnvs if no error occurs
@@ -390,6 +393,11 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
390393 }
391394 }
392395
396+ err = writeMpiHostfile (ctx , ex .clusterInfo .JobIPs , gpus_per_node_num , mpiHostfilePath )
397+ if err != nil {
398+ return err
399+ }
400+
393401 cmd .Env = envMap .Render ()
394402
395403 log .Trace (ctx , "Starting exec" , "cmd" , cmd .String (), "working_dir" , cmd .Dir , "env" , cmd .Env )
@@ -696,6 +704,34 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
696704 return sshDir , nil
697705}
698706
707+ func writeMpiHostfile (ctx context.Context , ips []string , gpus_per_node int , path string ) error {
708+ if err := os .MkdirAll (filepath .Dir (path ), 0o755 ); err != nil {
709+ return err
710+ }
711+ file , err := os .OpenFile (path , os .O_CREATE | os .O_TRUNC | os .O_WRONLY , 0o644 )
712+ if err != nil {
713+ return err
714+ }
715+ defer file .Close ()
716+ nonEmptyIps := []string {}
717+ for _ , ip := range ips {
718+ if ip != "" {
719+ nonEmptyIps = append (nonEmptyIps , ip )
720+ }
721+ }
722+ if len (nonEmptyIps ) == len (ips ) {
723+ for _ , ip := range nonEmptyIps {
724+ line := fmt .Sprintf ("%s slots=%d\n " , ip , gpus_per_node )
725+ if _ , err = file .WriteString (line ); err != nil {
726+ return err
727+ }
728+ }
729+ } else {
730+ log .Info (ctx , "creating empty MPI hostfile: no internal IPs assigned" )
731+ }
732+ return nil
733+ }
734+
699735func writeDstackProfile (env map [string ]string , path string ) error {
700736 file , err := os .OpenFile (path , os .O_CREATE | os .O_TRUNC | os .O_WRONLY , 0o644 )
701737 if err != nil {
0 commit comments