Skip to content

Commit 42c5a29

Browse files
committed
Add concurrent limitation
Allow the user to define how many workers run in parallel, rather than attempting to apply workers to every job in the array. Added validation for the max-concurrent setting (only a positive int or None is accepted).
1 parent d32a2e6 commit 42c5a29

2 files changed

Lines changed: 159 additions & 8 deletions

File tree

arc/job/pipe/pipe_run.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import stat
1818
import sys
1919
import time
20+
from numbers import Integral
2021
from typing import Dict, List, Optional
2122

2223
import arc.parser.parser as parser
@@ -51,16 +52,30 @@ class PipeRun:
5152
run_id (str): Unique identifier for this pipe run.
5253
tasks (List[TaskSpec]): Task specifications to execute.
5354
cluster_software (str): Cluster scheduler type.
54-
max_workers (int): Maximum number of concurrent array workers.
55+
max_workers (int): Maximum total array worker slots (array size).
56+
max_concurrent (Optional[int]): Max workers running simultaneously
57+
(like PBS ``%N``). ``None`` disables throttling.
5558
max_attempts (int): Maximum retry attempts per task.
5659
"""
5760

61+
@staticmethod
62+
def _validate_max_concurrent(max_concurrent: Optional[int]) -> None:
63+
"""Accept ``None`` or a positive integer for throttle settings."""
64+
if max_concurrent is None:
65+
return
66+
if isinstance(max_concurrent, bool) or not isinstance(max_concurrent, Integral):
67+
raise ValueError('PipeRun max_concurrent must be None or a positive integer.')
68+
if max_concurrent > 0:
69+
return
70+
raise ValueError('PipeRun max_concurrent must be None or a positive integer.')
71+
5872
def __init__(self,
5973
project_directory: str,
6074
run_id: str,
6175
tasks: List[TaskSpec],
6276
cluster_software: str,
6377
max_workers: int = 100,
78+
max_concurrent: Optional[int] = None,
6479
max_attempts: int = 3,
6580
pipe_root: Optional[str] = None,
6681
):
@@ -69,6 +84,8 @@ def __init__(self,
6984
self.tasks = tasks
7085
self.cluster_software = cluster_software
7186
self.max_workers = max_workers
87+
self._validate_max_concurrent(max_concurrent)
88+
self.max_concurrent = None if max_concurrent is None else int(max_concurrent)
7289
self.max_attempts = max_attempts
7390
self.pipe_root = pipe_root if pipe_root is not None \
7491
else os.path.join(project_directory, 'calcs', 'pipe_' + run_id)
@@ -103,6 +120,7 @@ def _save_run_metadata(self) -> None:
103120
'status': self.status.value,
104121
'cluster_software': self.cluster_software,
105122
'max_workers': self.max_workers,
123+
'max_concurrent': self.max_concurrent,
106124
'max_attempts': self.max_attempts,
107125
'task_family': task_family,
108126
'engine': engine,
@@ -146,6 +164,7 @@ def from_dir(cls, pipe_root: str) -> 'PipeRun':
146164
tasks=tasks,
147165
cluster_software=data['cluster_software'],
148166
max_workers=data.get('max_workers', 100),
167+
max_concurrent=data.get('max_concurrent'),
149168
max_attempts=data.get('max_attempts', 3),
150169
pipe_root=pipe_root,
151170
)
@@ -192,12 +211,41 @@ def _submission_resources(self):
192211
Derive resource settings from the homogeneous task list.
193212
194213
Returns:
195-
Tuple[int, int, int]: ``(cpus, memory_mb, array_size)``
214+
Tuple[int, int, int, Optional[int]]:
215+
``(cpus, memory_mb, array_size, throttle)`` where ``throttle``
216+
caps workers running simultaneously (clamped to ``array_size``),
217+
or ``None`` if unthrottled.
196218
"""
197219
cpus = self.tasks[0].required_cores if self.tasks else 1
198220
memory_mb = self.tasks[0].required_memory_mb if self.tasks else 4096
199221
array_size = min(self.max_workers, len(self.tasks)) if self.tasks else self.max_workers
200-
return cpus, memory_mb, array_size
222+
throttle = None
223+
if self.max_concurrent is not None and self.max_concurrent > 0:
224+
throttle = min(self.max_concurrent, array_size)
225+
return cpus, memory_mb, array_size, throttle
226+
227+
def _render_throttle(self, array_size: int, throttle: Optional[int]) -> Dict[str, str]:
    """
    Build the scheduler-specific array-range and extra-directive strings.

    ``'oge'`` is handled exactly like ``'sge'``. A falsy ``throttle`` produces no throttle text.

    Per scheduler:
        * ``slurm`` / ``pbs``: the range is ``1-<array_size>``, with ``%<throttle>`` appended when throttled.
        * ``sge``: the range is ``1-<array_size>``; the throttle goes into a separate ``#$ -tc <throttle>`` line.
        * ``htcondor``: the "range" is the bare ``<array_size>`` count; the throttle is ``max_materialize = <throttle>``.

    Args:
        array_size (int): Number of array slots to request.
        throttle (Optional[int]): Cap on simultaneously running workers, or ``None`` for no cap.

    Returns:
        Dict[str, str]: A dictionary with the keys ``'array_range'`` and ``'extra_directives'``.

    Raises:
        NotImplementedError: If ``self.cluster_software`` is none of the schedulers listed above.
    """
    scheduler = self.cluster_software
    if scheduler == 'oge':
        scheduler = 'sge'
    throttled = bool(throttle)

    if scheduler in ('slurm', 'pbs'):
        # The throttle is part of the range itself, so there is no extra directive.
        span = f'1-{array_size}'
        if throttled:
            span = f'{span}%{throttle}'
        return {'array_range': span, 'extra_directives': ''}

    if scheduler == 'sge':
        directive = f'#$ -tc {throttle}' if throttled else ''
        return {'array_range': f'1-{array_size}', 'extra_directives': directive}

    if scheduler == 'htcondor':
        directive = f'max_materialize = {throttle}' if throttled else ''
        return {'array_range': str(array_size), 'extra_directives': directive}

    raise NotImplementedError(f'No throttle rendering for {self.cluster_software}')
201249

202250
def write_submit_script(self) -> str:
203251
"""
@@ -215,7 +263,8 @@ def write_submit_script(self) -> str:
215263
raise NotImplementedError(
216264
f'No pipe submit template for cluster software: {self.cluster_software}. '
217265
f'Available templates: {list(pipe_submit.keys())}')
218-
cpus, memory_mb, array_size = self._submission_resources()
266+
cpus, memory_mb, array_size, throttle = self._submission_resources()
267+
throttle_fields = self._render_throttle(array_size, throttle)
219268
server = servers_dict.get('local', {})
220269
queue, _ = next(iter(server.get('queues', {}).items()), ('', None))
221270
engine = self.tasks[0].engine if self.tasks else ''
@@ -226,7 +275,8 @@ def write_submit_script(self) -> str:
226275
env_setup = f'{env_setup}\n{scratch_export}' if env_setup else scratch_export
227276
content = pipe_submit[template_key].format(
228277
name=f'pipe_{self.run_id}',
229-
max_task_num=array_size,
278+
array_range=throttle_fields['array_range'],
279+
extra_directives=throttle_fields['extra_directives'],
230280
pipe_root=self.pipe_root,
231281
python_exe=sys.executable,
232282
cpus=cpus,

arc/job/pipe/pipe_run_test.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tempfile
1212
import time
1313
import unittest
14+
import uuid
1415

1516
from arc.job.adapters.mockter import MockAdapter
1617
from arc.job.pipe.pipe_state import TaskState, PipeRunState, TaskSpec, read_task_state, update_task_state
@@ -134,12 +135,14 @@ def setUp(self):
134135
def tearDown(self):
    """Delete the temporary directory used by the test, ignoring any removal errors."""
    scratch_dir = self.tmpdir
    shutil.rmtree(scratch_dir, ignore_errors=True)
136137

137-
def _make_run(self, cluster_software, max_workers=10, n_tasks=None):
138+
def _make_run(self, cluster_software, max_workers=10, n_tasks=None,
139+
max_concurrent=None, run_id=None):
138140
n = n_tasks if n_tasks is not None else max_workers
139141
tasks = [_make_spec(f't_{i}') for i in range(n)]
140-
run = PipeRun(project_directory=self.tmpdir, run_id='sub_test',
142+
run = PipeRun(project_directory=self.tmpdir,
143+
run_id=run_id or f'sub_test_{uuid.uuid4().hex[:8]}',
141144
tasks=tasks, cluster_software=cluster_software,
142-
max_workers=max_workers)
145+
max_workers=max_workers, max_concurrent=max_concurrent)
143146
run.stage()
144147
return run
145148

@@ -170,6 +173,104 @@ def test_htcondor_content(self):
170173
content = f.read()
171174
self.assertIn('queue 12', content)
172175

176+
def test_slurm_throttle(self):
    """Check that 100 SLURM tasks throttled to 8 yield ``#SBATCH --array=1-100%8``."""
    pipe = self._make_run('slurm', max_workers=100, n_tasks=100, max_concurrent=8)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    self.assertIn('#SBATCH --array=1-100%8', script)
181+
182+
def test_pbs_throttle(self):
    """Check that 50 PBS tasks throttled to 4 yield ``#PBS -J 1-50%4``."""
    pipe = self._make_run('pbs', max_workers=50, n_tasks=50, max_concurrent=4)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    self.assertIn('#PBS -J 1-50%4', script)
187+
188+
def test_sge_throttle_uses_tc_directive(self):
    """Check that an SGE script keeps a plain ``-t 1-20`` range and adds a ``-tc 5`` line."""
    pipe = self._make_run('sge', max_workers=20, n_tasks=20, max_concurrent=5)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    for expected_line in ('#$ -t 1-20', '#$ -tc 5'):
        self.assertIn(expected_line, script)
194+
195+
def test_htcondor_throttle_uses_max_materialize(self):
    """Check that an HTCondor script keeps the bare ``queue 12`` count and adds ``max_materialize = 3``."""
    pipe = self._make_run('htcondor', max_workers=12, n_tasks=12, max_concurrent=3)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    for expected_line in ('queue 12', 'max_materialize = 3'):
        self.assertIn(expected_line, script)
201+
202+
def test_throttle_clamped_to_array_size(self):
    """Check that a ``max_concurrent`` larger than the array is clamped to the array size."""
    # 99 requested against 6 slots: the rendered throttle must be 6 (a no-op throttle).
    pipe = self._make_run('slurm', max_workers=6, n_tasks=6, max_concurrent=99)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    self.assertIn('#SBATCH --array=1-6%6', script)
208+
209+
def test_unthrottled_has_no_throttle_markers(self):
    """
    Regression guard: unthrottled submit scripts contain no throttle syntax.

    For each scheduler, a run with ``max_concurrent=None`` is rendered and the script is checked for:
    an array line without ``%``, no ``-tc `` directive, and no ``max_materialize`` setting.
    """
    array_line_markers = {
        'slurm': '#SBATCH --array=',
        'pbs': '#PBS -J ',
        'sge': '#$ -t ',
        'htcondor': 'queue ',
    }
    for cs, marker in array_line_markers.items():
        with self.subTest(cluster_software=cs):
            run = self._make_run(cs, max_workers=10, n_tasks=10, max_concurrent=None)
            with open(run.write_submit_script()) as f:
                content = f.read()
            # With a default of ``None``, a script lacking the array line fails with a readable
            # assertion message instead of aborting the sub-test with a bare ``StopIteration``.
            array_line = next((line for line in content.splitlines() if marker in line), None)
            self.assertIsNotNone(array_line, f'No line containing {marker!r} in the {cs} submit script')
            self.assertNotIn('%', array_line)
            self.assertNotIn('-tc ', content)
            self.assertNotIn('max_materialize', content)
226+
227+
def test_oge_routes_to_sge_throttle(self):
    """Check that ``cluster_software='oge'`` produces the SGE ``-t`` range and ``-tc`` throttle lines."""
    pipe = self._make_run('oge', max_workers=15, n_tasks=15, max_concurrent=4)
    script_path = pipe.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    for expected_line in ('#$ -t 1-15', '#$ -tc 4'):
        self.assertIn(expected_line, script)
234+
235+
def test_invalid_max_concurrent_values_raise(self):
    """Check that anything other than ``None`` or a positive integer is rejected with a ``ValueError``."""
    rejected_values = (-1, 0, -2, 1.5, '3', True, False)
    for candidate in rejected_values:
        with self.subTest(max_concurrent=candidate), \
                self.assertRaisesRegex(ValueError, 'max_concurrent'):
            PipeRun(
                project_directory=self.tmpdir,
                run_id=f'invalid_{uuid.uuid4().hex[:8]}',
                tasks=[_make_spec('t_0')],
                cluster_software='slurm',
                max_concurrent=candidate,
            )
245+
246+
def test_render_throttle_branching_matrix(self):
    """Exercise ``_render_throttle`` directly for each scheduler, with and without a throttle."""

    def rendered(array_range, extra_directives=''):
        """Build the dictionary ``_render_throttle`` is expected to return."""
        return {'array_range': array_range, 'extra_directives': extra_directives}

    pipe = self._make_run('slurm', max_workers=1, n_tasks=1)
    matrix = (
        ('slurm', 100, None, rendered('1-100')),
        ('slurm', 100, 8, rendered('1-100%8')),
        ('pbs', 50, 4, rendered('1-50%4')),
        ('sge', 20, None, rendered('1-20')),
        ('sge', 20, 5, rendered('1-20', '#$ -tc 5')),
        ('oge', 20, 5, rendered('1-20', '#$ -tc 5')),
        ('htcondor', 12, None, rendered('12')),
        ('htcondor', 12, 3, rendered('12', 'max_materialize = 3')),
    )
    for scheduler, size, cap, want in matrix:
        with self.subTest(cluster_software=scheduler, throttle=cap):
            # One staged run is reused; only its scheduler name is switched per case.
            pipe.cluster_software = scheduler
            self.assertEqual(pipe._render_throttle(size, cap), want)
263+
264+
def test_from_dir_round_trip_preserves_max_concurrent(self):
    """Check that ``max_concurrent`` survives a reload via ``PipeRun.from_dir`` and is rendered again."""
    original = self._make_run('slurm', max_workers=20, n_tasks=20, max_concurrent=7)
    original.write_submit_script()

    restored = PipeRun.from_dir(original.pipe_root)
    self.assertEqual(restored.max_concurrent, 7)

    # The reloaded run must render the same throttle into its own submit script.
    script_path = restored.write_submit_script()
    with open(script_path) as handle:
        script = handle.read()
    self.assertIn('#SBATCH --array=1-20%7', script)
273+
173274
def test_overwrite_is_safe(self):
174275
run = self._make_run('slurm')
175276
p1 = run.write_submit_script()

0 commit comments

Comments
 (0)