 
 import yaml
 
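+# Pipeline stages in execution order; _format_progress() sorts status lines by this order.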
+_STAGE_ORDER = ["decode", "fragment", "offsets", "stitch", "connect",
+                "build_rg", "merge_rg", "agglomerate", "relabel", "apply", "assemble"]
+
+
+def _format_progress(counts):
+    """Format stage_counts() in pipeline order."""
+    order = {s: i for i, s in enumerate(_STAGE_ORDER)}
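+    # Stages missing from _STAGE_ORDER sort last (rank 999), alphabetically among themselves.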
+    stages = sorted(counts.keys(), key=lambda s: (order.get(s, 999), s))
+    parts = []
+    for stage in stages:
+        sc = counts[stage]
+        done = sc.get("succeeded", 0)
+        total = sum(sc.values())
+        running = sc.get("running", 0)
+        status = f"{stage}: {done}/{total}"
+        if running:
+            status += f" ({running} running)"
+        parts.append(status)
+    return f"  Progress: {' | '.join(parts)}"
+
+
 def _worker_fn(args_tuple):
     """Worker function for parallel decode (takes all args as tuple for spawn compatibility)."""
     worker_idx, workflow_root, idle_timeout, max_tasks = args_tuple
@@ -52,6 +73,10 @@ def main():
     parser.add_argument("--assemble", action="store_true", help="Assemble final output volume")
     parser.add_argument("--parallel", type=int, default=None,
                         help="Run N worker processes on this machine")
+    parser.add_argument("--sbatch", action="store_true",
+                        help="Force SLURM submission (overrides backend config)")
+    parser.add_argument("--local", action="store_true",
+                        help="Force local multiprocess (overrides backend config)")
     parser.add_argument("--max-tasks", type=int, default=None, help="Max tasks per worker")
     parser.add_argument("--idle-timeout", type=float, default=60.0, help="Worker idle timeout (seconds)")
     parser.add_argument("--worker-id", type=str, default=None, help="Worker identifier")
@@ -106,6 +131,28 @@ def main():
     if n_stale:
         print(f"Reset {n_stale} stale RUNNING tasks (older than {args.stale_timeout}s).")
 
+    # Recover decode tasks that completed outside the orchestrator (e.g. --chunk-index)
+    n_recovered = 0
+    for chunk in runner.chunks:
+        output_path = runner._raw_chunk_path(chunk.key)
+        if not output_path.exists():
+            continue
+        task_id = f"decode:{chunk.key}"
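+        # Output exists on disk; if the orchestrator has no success record, validate and backfill it.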
+        try:
+            record = runner.orchestrator.get_record(task_id)
+            if record.state.value == "succeeded":
+                continue
+            max_id = runner._read_chunk_max(output_path)
+            runner.orchestrator.force_complete(
+                task_id, result={"chunk_path": str(output_path), "max_id": max_id},
+            )
+            n_recovered += 1
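+        # An unreadable chunk file is treated as corrupt: delete it so the decode task runs again.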
+        except Exception as e:
+            print(f"  Warning: {chunk.key}: corrupt output ({e}), deleting")
+            output_path.unlink(missing_ok=True)
+    if n_recovered:
+        print(f"Recovered {n_recovered} decode tasks from existing chunk files.")
+
     chunks = runner.chunks
     borders = runner.borders
     print(f"Volume shape: {config.volume_shape}")
@@ -119,6 +166,63 @@ def main():
         print("Workflow initialized. Launch workers to execute tasks.")
         return
 
+    # Determine execution backend: CLI flags override YAML config
+    if args.sbatch:
+        backend = "slurm"
+    elif args.local:
+        backend = "multiprocess"
+    else:
+        backend = large_cfg.get("backend", "multiprocess")
+
+    if backend == "slurm":
+        import subprocess, tempfile, textwrap
+
+        slurm_cfg = large_cfg.get("slurm", {})
+        partition = slurm_cfg.get("partition", "weilab")
+        mem = slurm_cfg.get("mem", "64G")
+        cpus = slurm_cfg.get("cpus_per_task", 2)
+        time_limit = slurm_cfg.get("time", "12:00:00")
+        n_chunks = len(chunks)
+
+        script_path = os.path.abspath(sys.argv[0])
+        config_path = os.path.abspath(args.config)
+        work_dir = os.getcwd()
+        output_dir = os.path.join(work_dir, "slurm_outputs")
+        os.makedirs(output_dir, exist_ok=True)
+
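+        # One SLURM array task per chunk; each task runs this script in --worker mode.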
+        sbatch_script = textwrap.dedent(f"""\
+            #!/bin/bash
+            #SBATCH --job-name=waterz_worker
+            #SBATCH --partition={partition}
+            #SBATCH --mem={mem}
+            #SBATCH --cpus-per-task={cpus}
+            #SBATCH --time={time_limit}
+            #SBATCH --array=0-{n_chunks - 1}
+            #SBATCH --output={output_dir}/waterz_worker_%A_%a.out
+            #SBATCH --error={output_dir}/waterz_worker_%A_%a.err
+
+            source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc
+            cd {work_dir}
+            export CCACHE_DISABLE=1
+            export OMP_NUM_THREADS=1
+            export OPENBLAS_NUM_THREADS=1
+            export MKL_NUM_THREADS=1
+
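+            # --no-reset-stale: stale-task reset already ran once in the submitting process.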
+            python {script_path} --config {config_path} --worker --no-reset-stale
+            """)
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
+            f.write(sbatch_script)
+            tmp_path = f.name
+
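+        # Submit via sbatch, then delete the temp script; propagate a non-zero exit code to the caller.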
+        result = subprocess.run(["sbatch", tmp_path], capture_output=True, text=True)
+        os.unlink(tmp_path)
+        print(result.stdout.strip())
+        if result.returncode != 0:
+            print(result.stderr.strip(), file=sys.stderr)
+            sys.exit(result.returncode)
+        return
+
     # Direct chunk assignment (no orchestrator competition)
     chunk_index = args.chunk_index
     if chunk_index is None and os.environ.get("SLURM_ARRAY_TASK_ID"):
@@ -136,6 +240,10 @@ def main():
             print(f"Chunk index {idx} out of range (0-{len(chunks)-1}), skipping")
             continue
         chunk = chunks[idx]
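+        # Skip chunks already decoded to disk; the recovery pass at startup backfills their records.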
+        output_path = runner._raw_chunk_path(chunk.key)
+        if output_path.exists():
+            print(f"Chunk {idx}/{len(chunks)} ({chunk.key}): already exists, skipping")
+            continue
         print(f"Decoding chunk {idx}/{len(chunks)}: {chunk.key}")
         from waterz.orchestrator import TaskRecord, TaskSpec
         record = TaskRecord(spec=TaskSpec(name=f"decode_{chunk.key}", stage="decode", key=chunk.key))
@@ -165,16 +273,7 @@ def main():
         counts = runner.orchestrator.stage_counts()
         now = _time.monotonic()
         if now - last_print >= 10:
-            parts = []
-            for stage, sc in sorted(counts.items()):
-                done = sc.get("succeeded", 0)
-                total = sum(sc.values())
-                running = sc.get("running", 0)
-                status = f"{stage}: {done}/{total}"
-                if running:
-                    status += f" ({running} running)"
-                parts.append(status)
-            print(f"  Progress: {' | '.join(parts)}", flush=True)
+            print(_format_progress(counts), flush=True)
             last_print = now
 
         all_terminal = all(
@@ -198,14 +297,15 @@ def main():
         print(f"Output: {config.resolved_output_path}")
         return
 
-    if args.parallel and args.parallel > 1:
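+    # Worker count: --parallel from the CLI wins, falling back to num_workers in the config (default 1).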
+    n_parallel = args.parallel or large_cfg.get("num_workers", 1)
+    if n_parallel > 1:
         import multiprocessing as mp
 
         workflow_root = large_cfg["workflow_root"]
         idle_timeout = args.idle_timeout or 120
         max_tasks = args.max_tasks
 
-        n_workers = args.parallel
+        n_workers = n_parallel
         print(f"Running parallel decode with {n_workers} workers...")
 
         worker_args = [
@@ -226,16 +326,7 @@ def main():
         def _progress_loop():
             while not stop_progress.wait(10):
                 counts = runner.orchestrator.stage_counts()
-                parts = []
-                for stage, sc in sorted(counts.items()):
-                    done = sc.get("succeeded", 0)
-                    total = sum(sc.values())
-                    running = sc.get("running", 0)
-                    status = f"{stage}: {done}/{total}"
-                    if running:
-                        status += f" ({running} running)"
-                    parts.append(status)
-                print(f"  Progress: {' | '.join(parts)}", flush=True)
+                print(_format_progress(counts), flush=True)
 
         t = threading.Thread(target=_progress_loop, daemon=True)
         t.start()