Skip to content

Commit 71dcd5d

Browse files
committed
Implement DSTACK_MPI_HOSTFILE
1 parent 1e16fe1 commit 71dcd5d

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

runner/internal/executor/executor.go

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,8 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
257257
gpus_per_node_num := ex.clusterInfo.GPUSPerJob
258258
gpus_num := nodes_num * gpus_per_node_num
259259

260+
mpiHostfilePath := filepath.Join(ex.homeDir, ".dstack/mpi/hostfile")
261+
260262
jobEnvs := map[string]string{
261263
"DSTACK_RUN_ID": ex.run.Id,
262264
"DSTACK_JOB_ID": ex.jobSubmission.Id,
@@ -268,6 +270,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
268270
"DSTACK_NODES_NUM": strconv.Itoa(nodes_num),
269271
"DSTACK_GPUS_PER_NODE": strconv.Itoa(gpus_per_node_num),
270272
"DSTACK_GPUS_NUM": strconv.Itoa(gpus_num),
273+
"DSTACK_MPI_HOSTFILE": mpiHostfilePath,
271274
}
272275

273276
// Call buildLDLibraryPathEnv and update jobEnvs if no error occurs
@@ -390,6 +393,11 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
390393
}
391394
}
392395

396+
err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpus_per_node_num, mpiHostfilePath)
397+
if err != nil {
398+
return err
399+
}
400+
393401
cmd.Env = envMap.Render()
394402

395403
log.Trace(ctx, "Starting exec", "cmd", cmd.String(), "working_dir", cmd.Dir, "env", cmd.Env)
@@ -696,6 +704,34 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
696704
return sshDir, nil
697705
}
698706

707+
func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path string) error {
708+
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
709+
return err
710+
}
711+
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
712+
if err != nil {
713+
return err
714+
}
715+
defer file.Close()
716+
nonEmptyIps := []string{}
717+
for _, ip := range ips {
718+
if ip != "" {
719+
nonEmptyIps = append(nonEmptyIps, ip)
720+
}
721+
}
722+
if len(nonEmptyIps) == len(ips) {
723+
for _, ip := range nonEmptyIps {
724+
line := fmt.Sprintf("%s slots=%d\n", ip, gpus_per_node)
725+
if _, err = file.WriteString(line); err != nil {
726+
return err
727+
}
728+
}
729+
} else {
730+
log.Info(ctx, "creating empty MPI hostfile: no internal IPs assigned")
731+
}
732+
return nil
733+
}
734+
699735
func writeDstackProfile(env map[string]string, path string) error {
700736
file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
701737
if err != nil {

0 commit comments

Comments
 (0)