|
11 | 11 | import tempfile |
12 | 12 | import time |
13 | 13 | import unittest |
| 14 | +import uuid |
14 | 15 |
|
15 | 16 | from arc.job.adapters.mockter import MockAdapter |
16 | 17 | from arc.job.pipe.pipe_state import TaskState, PipeRunState, TaskSpec, read_task_state, update_task_state |
@@ -134,12 +135,14 @@ def setUp(self): |
134 | 135 | def tearDown(self): |
135 | 136 | shutil.rmtree(self.tmpdir, ignore_errors=True) |
136 | 137 |
|
137 | | - def _make_run(self, cluster_software, max_workers=10, n_tasks=None): |
| 138 | + def _make_run(self, cluster_software, max_workers=10, n_tasks=None, |
| 139 | + max_concurrent=None, run_id=None): |
138 | 140 | n = n_tasks if n_tasks is not None else max_workers |
139 | 141 | tasks = [_make_spec(f't_{i}') for i in range(n)] |
140 | | - run = PipeRun(project_directory=self.tmpdir, run_id='sub_test', |
| 142 | + run = PipeRun(project_directory=self.tmpdir, |
| 143 | + run_id=run_id or f'sub_test_{uuid.uuid4().hex[:8]}', |
141 | 144 | tasks=tasks, cluster_software=cluster_software, |
142 | | - max_workers=max_workers) |
| 145 | + max_workers=max_workers, max_concurrent=max_concurrent) |
143 | 146 | run.stage() |
144 | 147 | return run |
145 | 148 |
|
@@ -170,6 +173,104 @@ def test_htcondor_content(self): |
170 | 173 | content = f.read() |
171 | 174 | self.assertIn('queue 12', content) |
172 | 175 |
|
| 176 | + def test_slurm_throttle(self): |
| 177 | + run = self._make_run('slurm', max_workers=100, n_tasks=100, max_concurrent=8) |
| 178 | + with open(run.write_submit_script()) as f: |
| 179 | + content = f.read() |
| 180 | + self.assertIn('#SBATCH --array=1-100%8', content) |
| 181 | + |
| 182 | + def test_pbs_throttle(self): |
| 183 | + run = self._make_run('pbs', max_workers=50, n_tasks=50, max_concurrent=4) |
| 184 | + with open(run.write_submit_script()) as f: |
| 185 | + content = f.read() |
| 186 | + self.assertIn('#PBS -J 1-50%4', content) |
| 187 | + |
| 188 | + def test_sge_throttle_uses_tc_directive(self): |
| 189 | + run = self._make_run('sge', max_workers=20, n_tasks=20, max_concurrent=5) |
| 190 | + with open(run.write_submit_script()) as f: |
| 191 | + content = f.read() |
| 192 | + self.assertIn('#$ -t 1-20', content) |
| 193 | + self.assertIn('#$ -tc 5', content) |
| 194 | + |
| 195 | + def test_htcondor_throttle_uses_max_materialize(self): |
| 196 | + run = self._make_run('htcondor', max_workers=12, n_tasks=12, max_concurrent=3) |
| 197 | + with open(run.write_submit_script()) as f: |
| 198 | + content = f.read() |
| 199 | + self.assertIn('queue 12', content) |
| 200 | + self.assertIn('max_materialize = 3', content) |
| 201 | + |
| 202 | + def test_throttle_clamped_to_array_size(self): |
| 203 | + # max_concurrent > array_size should clamp to array_size (no-op throttle). |
| 204 | + run = self._make_run('slurm', max_workers=6, n_tasks=6, max_concurrent=99) |
| 205 | + with open(run.write_submit_script()) as f: |
| 206 | + content = f.read() |
| 207 | + self.assertIn('#SBATCH --array=1-6%6', content) |
| 208 | + |
| 209 | + def test_unthrottled_has_no_throttle_markers(self): |
| 210 | + """Regression guard: unthrottled submit scripts contain no throttle syntax.""" |
| 211 | + array_line_markers = { |
| 212 | + 'slurm': '#SBATCH --array=', |
| 213 | + 'pbs': '#PBS -J ', |
| 214 | + 'sge': '#$ -t ', |
| 215 | + 'htcondor': 'queue ', |
| 216 | + } |
| 217 | + for cs, marker in array_line_markers.items(): |
| 218 | + with self.subTest(cluster_software=cs): |
| 219 | + run = self._make_run(cs, max_workers=10, n_tasks=10, max_concurrent=None) |
| 220 | + with open(run.write_submit_script()) as f: |
| 221 | + content = f.read() |
| 222 | + array_line = next(line for line in content.splitlines() if marker in line) |
| 223 | + self.assertNotIn('%', array_line) |
| 224 | + self.assertNotIn('-tc ', content) |
| 225 | + self.assertNotIn('max_materialize', content) |
| 226 | + |
| 227 | + def test_oge_routes_to_sge_throttle(self): |
| 228 | + """cluster_software='oge' should use the SGE throttle directive.""" |
| 229 | + run = self._make_run('oge', max_workers=15, n_tasks=15, max_concurrent=4) |
| 230 | + with open(run.write_submit_script()) as f: |
| 231 | + content = f.read() |
| 232 | + self.assertIn('#$ -t 1-15', content) |
| 233 | + self.assertIn('#$ -tc 4', content) |
| 234 | + |
| 235 | + def test_invalid_max_concurrent_values_raise(self): |
| 236 | + """Only None and positive integers are accepted.""" |
| 237 | + for max_concurrent in (-1, 0, -2, 1.5, '3', True, False): |
| 238 | + with self.subTest(max_concurrent=max_concurrent): |
| 239 | + with self.assertRaisesRegex(ValueError, 'max_concurrent'): |
| 240 | + PipeRun(project_directory=self.tmpdir, |
| 241 | + run_id=f'invalid_{uuid.uuid4().hex[:8]}', |
| 242 | + tasks=[_make_spec('t_0')], |
| 243 | + cluster_software='slurm', |
| 244 | + max_concurrent=max_concurrent) |
| 245 | + |
| 246 | + def test_render_throttle_branching_matrix(self): |
| 247 | + """Unit-test _render_throttle directly across scheduler × throttle combos.""" |
| 248 | + run = self._make_run('slurm', max_workers=1, n_tasks=1) |
| 249 | + cases = [ |
| 250 | + ('slurm', 100, None, {'array_range': '1-100', 'extra_directives': ''}), |
| 251 | + ('slurm', 100, 8, {'array_range': '1-100%8', 'extra_directives': ''}), |
| 252 | + ('pbs', 50, 4, {'array_range': '1-50%4', 'extra_directives': ''}), |
| 253 | + ('sge', 20, None, {'array_range': '1-20', 'extra_directives': ''}), |
| 254 | + ('sge', 20, 5, {'array_range': '1-20', 'extra_directives': '#$ -tc 5'}), |
| 255 | + ('oge', 20, 5, {'array_range': '1-20', 'extra_directives': '#$ -tc 5'}), |
| 256 | + ('htcondor', 12, None, {'array_range': '12', 'extra_directives': ''}), |
| 257 | + ('htcondor', 12, 3, {'array_range': '12', 'extra_directives': 'max_materialize = 3'}), |
| 258 | + ] |
| 259 | + for cs, array_size, throttle, expected in cases: |
| 260 | + with self.subTest(cluster_software=cs, throttle=throttle): |
| 261 | + run.cluster_software = cs |
| 262 | + self.assertEqual(run._render_throttle(array_size, throttle), expected) |
| 263 | + |
| 264 | + def test_from_dir_round_trip_preserves_max_concurrent(self): |
| 265 | + """Persist max_concurrent through run.json so crash-recovered runs keep their throttle.""" |
| 266 | + run = self._make_run('slurm', max_workers=20, n_tasks=20, max_concurrent=7) |
| 267 | + run.write_submit_script() |
| 268 | + reloaded = PipeRun.from_dir(run.pipe_root) |
| 269 | + self.assertEqual(reloaded.max_concurrent, 7) |
| 270 | + with open(reloaded.write_submit_script()) as f: |
| 271 | + content = f.read() |
| 272 | + self.assertIn('#SBATCH --array=1-20%7', content) |
| 273 | + |
173 | 274 | def test_overwrite_is_safe(self): |
174 | 275 | run = self._make_run('slurm') |
175 | 276 | p1 = run.write_submit_script() |
|
0 commit comments