|
8 | 8 |
|
9 | 9 | import psutil |
10 | 10 | from dask.distributed import Client |
11 | | -from dask_jobqueue.slurm import SLURMCluster |
| 11 | +from dask_jobqueue import SLURMCluster |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class HPCDaskManager: |
@@ -37,18 +37,36 @@ def system_network_interface(self) -> str: |
37 | 37 | """ |
38 | 38 | Select the most appropriate network interface for HPC communication. |
39 | 39 |
|
| 40 | + If args.hpc_interface is provided, that value is used directly. Otherwise, |
| 41 | + commonly used HPC interfaces are preferred. Loopback and container interfaces |
| 42 | + are avoided because Dask workers on other nodes cannot connect to a scheduler |
| 43 | + advertised on 127.0.0.1. |
| 44 | +
|
40 | 45 | Returns: |
41 | 46 | str: Name of selected network interface. |
| 47 | +
|
| 48 | + Raises: |
| 49 | + RuntimeError: If no suitable non-loopback interface can be found. |
42 | 50 | """ |
| 51 | + configured = getattr(self.args, "hpc_interface", None) |
| 52 | + if configured: |
| 53 | + return configured |
| 54 | + |
43 | 55 | preferred_nics = ["bond0", "ib0", "hsn0", "eth0"] |
44 | 56 | interfaces = list(psutil.net_if_addrs().keys()) |
45 | 57 |
|
46 | 58 | for iface in preferred_nics: |
47 | 59 | if iface in interfaces: |
48 | 60 | return iface |
49 | 61 |
|
50 | | - # fallback to first available interface |
51 | | - return interfaces[0] |
| 62 | + for iface in interfaces: |
| 63 | + if not iface.startswith(("lo", "docker", "veth")): |
| 64 | + return iface |
| 65 | + |
| 66 | + raise RuntimeError( |
| 67 | + "Could not find a non-loopback network interface for Dask workers. " |
| 68 | + f"Available interfaces: {interfaces}. Set 'hpc_interface' in config.yaml." |
| 69 | + ) |
52 | 70 |
|
53 | 71 | def slurm_directives(self) -> tuple[list[str], list[str]]: |
54 | 72 | """ |
@@ -82,6 +100,9 @@ def slurm_prologues(self) -> list[str]: |
82 | 100 | args = self.args |
83 | 101 | prologue: list[str] = [] |
84 | 102 |
|
| 103 | + for module_name in getattr(args, "hpc_modules", None) or []: |
| 104 | + prologue.append(f"module load {module_name}") |
| 105 | + |
85 | 106 | prologue.append(f'eval "$({args.conda_path} shell.bash hook)"') |
86 | 107 |
|
87 | 108 | if args.conda_exec == "mamba": |
@@ -119,6 +140,7 @@ def configure_cluster(self) -> Client: |
119 | 140 | shebang="#!/bin/bash --login", |
120 | 141 | local_directory="$PWD", |
121 | 142 | interface=iface, |
| 143 | + scheduler_options={"interface": iface}, |
122 | 144 | job_script_prologue=prologue, |
123 | 145 | ) |
124 | 146 |
|
@@ -164,6 +186,10 @@ def submit_master(self) -> None: |
164 | 186 | f.write(f"#SBATCH --qos={self.args.hpc_qos}\n") |
165 | 187 |
|
166 | 188 | f.write("\n") |
| 189 | + |
| 190 | + for module_name in getattr(self.args, "hpc_modules", None) or []: |
| 191 | + f.write(f"module load {module_name}\n") |
| 192 | + |
167 | 193 | f.write(f'eval "$({self.args.conda_path} shell.bash hook)"\n') |
168 | 194 |
|
169 | 195 | if self.args.conda_exec == "mamba": |
|
0 commit comments