
Commit 573fb05

Donglai Wei and claude committed

Update decode_large: chunk-index, progress display, yaml config

- decode_large.py: progress formatting, serial progress thread, wait loop with stage counts, chunk-range support
- decode_large_worker.sh: NFS cache refresh, simplified chunk-index mode
- waterz_decoding_large.yaml: updated params

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3f1a7f5 commit 573fb05

3 files changed, 135 additions & 32 deletions


scripts/decode_large.py

Lines changed: 113 additions & 22 deletions
@@ -23,6 +23,27 @@
 
 import yaml
 
+_STAGE_ORDER = ["decode", "fragment", "offsets", "stitch", "connect",
+                "build_rg", "merge_rg", "agglomerate", "relabel", "apply", "assemble"]
+
+
+def _format_progress(counts):
+    """Format stage_counts() in pipeline order."""
+    order = {s: i for i, s in enumerate(_STAGE_ORDER)}
+    stages = sorted(counts.keys(), key=lambda s: (order.get(s, 999), s))
+    parts = []
+    for stage in stages:
+        sc = counts[stage]
+        done = sc.get("succeeded", 0)
+        total = sum(sc.values())
+        running = sc.get("running", 0)
+        status = f"{stage}: {done}/{total}"
+        if running:
+            status += f" ({running} running)"
+        parts.append(status)
+    return f" Progress: {' | '.join(parts)}"
+
+
 def _worker_fn(args_tuple):
     """Worker function for parallel decode (takes all args as tuple for spawn compatibility)."""
     worker_idx, workflow_root, idle_timeout, max_tasks = args_tuple
@@ -52,6 +73,10 @@ def main():
     parser.add_argument("--assemble", action="store_true", help="Assemble final output volume")
     parser.add_argument("--parallel", type=int, default=None,
                         help="Run N worker processes on this machine")
+    parser.add_argument("--sbatch", action="store_true",
+                        help="Force SLURM submission (overrides backend config)")
+    parser.add_argument("--local", action="store_true",
+                        help="Force local multiprocess (overrides backend config)")
     parser.add_argument("--max-tasks", type=int, default=None, help="Max tasks per worker")
     parser.add_argument("--idle-timeout", type=float, default=60.0, help="Worker idle timeout (seconds)")
     parser.add_argument("--worker-id", type=str, default=None, help="Worker identifier")
@@ -106,6 +131,28 @@ def main():
     if n_stale:
         print(f"Reset {n_stale} stale RUNNING tasks (older than {args.stale_timeout}s).")
 
+    # Recover decode tasks that completed outside the orchestrator (e.g. --chunk-index)
+    n_recovered = 0
+    for chunk in runner.chunks:
+        output_path = runner._raw_chunk_path(chunk.key)
+        if not output_path.exists():
+            continue
+        task_id = f"decode:{chunk.key}"
+        try:
+            record = runner.orchestrator.get_record(task_id)
+            if record.state.value == "succeeded":
+                continue
+            max_id = runner._read_chunk_max(output_path)
+            runner.orchestrator.force_complete(
+                task_id, result={"chunk_path": str(output_path), "max_id": max_id},
+            )
+            n_recovered += 1
+        except Exception as e:
+            print(f" Warning: {chunk.key}: corrupt output ({e}), deleting")
+            output_path.unlink(missing_ok=True)
+    if n_recovered:
+        print(f"Recovered {n_recovered} decode tasks from existing chunk files.")
+
     chunks = runner.chunks
     borders = runner.borders
     print(f"Volume shape: {config.volume_shape}")
@@ -119,6 +166,63 @@ def main():
         print("Workflow initialized. Launch workers to execute tasks.")
         return
 
+    # Determine execution backend: CLI flags override YAML config
+    if args.sbatch:
+        backend = "slurm"
+    elif args.local:
+        backend = "multiprocess"
+    else:
+        backend = large_cfg.get("backend", "multiprocess")
+
+    if backend == "slurm":
+        import subprocess, tempfile, textwrap
+
+        slurm_cfg = large_cfg.get("slurm", {})
+        partition = slurm_cfg.get("partition", "weilab")
+        mem = slurm_cfg.get("mem", "64G")
+        cpus = slurm_cfg.get("cpus_per_task", 2)
+        time_limit = slurm_cfg.get("time", "12:00:00")
+        n_chunks = len(chunks)
+
+        script_path = os.path.abspath(sys.argv[0])
+        config_path = os.path.abspath(args.config)
+        work_dir = os.getcwd()
+        output_dir = os.path.join(work_dir, "slurm_outputs")
+        os.makedirs(output_dir, exist_ok=True)
+
+        sbatch_script = textwrap.dedent(f"""\
+            #!/bin/bash
+            #SBATCH --job-name=waterz_worker
+            #SBATCH --partition={partition}
+            #SBATCH --mem={mem}
+            #SBATCH --cpus-per-task={cpus}
+            #SBATCH --time={time_limit}
+            #SBATCH --array=0-{n_chunks - 1}
+            #SBATCH --output={output_dir}/waterz_worker_%A_%a.out
+            #SBATCH --error={output_dir}/waterz_worker_%A_%a.err
+
+            source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc
+            cd {work_dir}
+            export CCACHE_DISABLE=1
+            export OMP_NUM_THREADS=1
+            export OPENBLAS_NUM_THREADS=1
+            export MKL_NUM_THREADS=1
+
+            python {script_path} --config {config_path} --worker --no-reset-stale
+            """)
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
+            f.write(sbatch_script)
+            tmp_path = f.name
+
+        result = subprocess.run(["sbatch", tmp_path], capture_output=True, text=True)
+        os.unlink(tmp_path)
+        print(result.stdout.strip())
+        if result.returncode != 0:
+            print(result.stderr.strip(), file=sys.stderr)
+            sys.exit(result.returncode)
+        return
+
     # Direct chunk assignment (no orchestrator competition)
     chunk_index = args.chunk_index
     if chunk_index is None and os.environ.get("SLURM_ARRAY_TASK_ID"):
@@ -136,6 +240,10 @@ def main():
             print(f"Chunk index {idx} out of range (0-{len(chunks)-1}), skipping")
             continue
         chunk = chunks[idx]
+        output_path = runner._raw_chunk_path(chunk.key)
+        if output_path.exists():
+            print(f"Chunk {idx}/{len(chunks)} ({chunk.key}): already exists, skipping")
+            continue
         print(f"Decoding chunk {idx}/{len(chunks)}: {chunk.key}")
         from waterz.orchestrator import TaskRecord, TaskSpec
         record = TaskRecord(spec=TaskSpec(name=f"decode_{chunk.key}", stage="decode", key=chunk.key))
@@ -165,16 +273,7 @@ def main():
             counts = runner.orchestrator.stage_counts()
            now = _time.monotonic()
             if now - last_print >= 10:
-                parts = []
-                for stage, sc in sorted(counts.items()):
-                    done = sc.get("succeeded", 0)
-                    total = sum(sc.values())
-                    running = sc.get("running", 0)
-                    status = f"{stage}: {done}/{total}"
-                    if running:
-                        status += f" ({running} running)"
-                    parts.append(status)
-                print(f" Progress: {' | '.join(parts)}", flush=True)
+                print(_format_progress(counts), flush=True)
                 last_print = now
 
             all_terminal = all(
@@ -198,14 +297,15 @@ def main():
         print(f"Output: {config.resolved_output_path}")
         return
 
-    if args.parallel and args.parallel > 1:
+    n_parallel = args.parallel or large_cfg.get("num_workers", 1)
+    if n_parallel > 1:
         import multiprocessing as mp
 
         workflow_root = large_cfg["workflow_root"]
         idle_timeout = args.idle_timeout or 120
         max_tasks = args.max_tasks
 
-        n_workers = args.parallel
+        n_workers = n_parallel
         print(f"Running parallel decode with {n_workers} workers...")
 
         worker_args = [
@@ -226,16 +326,7 @@ def main():
         def _progress_loop():
             while not stop_progress.wait(10):
                 counts = runner.orchestrator.stage_counts()
-                parts = []
-                for stage, sc in sorted(counts.items()):
-                    done = sc.get("succeeded", 0)
-                    total = sum(sc.values())
-                    running = sc.get("running", 0)
-                    status = f"{stage}: {done}/{total}"
-                    if running:
-                        status += f" ({running} running)"
-                    parts.append(status)
-                print(f" Progress: {' | '.join(parts)}", flush=True)
+                print(_format_progress(counts), flush=True)
 
         t = threading.Thread(target=_progress_loop, daemon=True)
         t.start()
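As a sanity check of the deduplication, the new _format_progress helper can be exercised standalone. A minimal sketch: the helper is copied verbatim from the diff above, while the counts dict is a made-up stand-in for orchestrator.stage_counts() (state keys other than "succeeded" and "running" are illustrative):

# _format_progress copied from the diff above; `counts` below is hypothetical.
_STAGE_ORDER = ["decode", "fragment", "offsets", "stitch", "connect",
                "build_rg", "merge_rg", "agglomerate", "relabel", "apply", "assemble"]

def _format_progress(counts):
    """Format stage_counts() in pipeline order."""
    order = {s: i for i, s in enumerate(_STAGE_ORDER)}
    stages = sorted(counts.keys(), key=lambda s: (order.get(s, 999), s))
    parts = []
    for stage in stages:
        sc = counts[stage]
        done = sc.get("succeeded", 0)
        total = sum(sc.values())
        running = sc.get("running", 0)
        status = f"{stage}: {done}/{total}"
        if running:
            status += f" ({running} running)"
        parts.append(status)
    return f" Progress: {' | '.join(parts)}"

counts = {
    "stitch": {"succeeded": 3, "pending": 5},
    "decode": {"succeeded": 10, "running": 2, "pending": 4},
}
print(_format_progress(counts))
# -> " Progress: decode: 10/16 (2 running) | stitch: 3/8"

Note that stages print in pipeline order (decode before stitch), not alphabetically as the old inline loop did with sorted(counts.items()).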

scripts/decode_large_worker.sh

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/bin/bash
 #SBATCH --job-name=waterz_worker
-#SBATCH --mem=64G
+#SBATCH --mem=150G
 #SBATCH --cpus-per-task=2
 #SBATCH --time=12:00:00
 #SBATCH --output=slurm_outputs/waterz_worker_%A_%a.out
@@ -18,6 +18,9 @@ cd /projects/weilab/weidf/lib/pytorch_connectomics
 
 export CCACHE_DISABLE=1
 
+# Force NFS cache refresh so all nodes see the latest witty .so files
+ls -la ~/.cache/witty/ > /dev/null 2>&1
+
 echo "Worker ${SLURM_ARRAY_TASK_ID} of ${SLURM_ARRAY_TASK_COUNT} on $(hostname)"
 echo "Config: ${CONFIG}"
 echo "Start: $(date)"

tutorials/waterz_decoding_large.yaml

Lines changed: 18 additions & 9 deletions
@@ -14,7 +14,7 @@ description: >
   python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --init-only
 
   # Step 2: Launch workers (each claims and executes tasks)
-  sbatch --array=0-7 scripts/decode_large_worker.sh tutorials/waterz_decoding_large.yaml
+  sbatch --array=0-25 scripts/decode_large_worker.sh tutorials/waterz_decoding_large.yaml
 
   # Step 3: Wait for completion + assemble
   python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --wait --assemble
@@ -28,28 +28,28 @@ large_decode:
   workflow_root: "/projects/weilab/weidf/lib/pytorch_connectomics/outputs/neuron_liconn_mit/DL228B/"  # SET THIS
 
   # Chunk layout
-  chunk_shape: [80, 2066, 2066]  # voxels per chunk (Z, Y, X)
+  chunk_shape: [40, 2066, 2066]  # voxels per chunk (Z, Y, X)
   overlap: [0, 0, 0]             # overlap per axis; [8,8,8] for overlap pipeline
 
   # Waterz agglomeration parameters
   thresholds: [0.4]
   merge_function: aff85_his256
-  aff_threshold_low: 0
-  aff_threshold_high: 1
+  aff_threshold_low: 0.1
+  aff_threshold_high: 0.9
   border_threshold: 0.3
   channel_order: xyz
-  #use_aff_uint8: true   # uint8 affinities (4x less aff memory)
-  #use_seg_uint32: true  # uint32 segment IDs (2x less seg memory)
+  use_aff_uint8: true   # uint8 affinities (4x less aff memory)
+  use_seg_uint32: true  # uint32 segment IDs (2x less seg memory)
 
   # Fragment initialization (per chunk, overlap pipeline only)
   #compute_fragments: true       # 2D slice-by-slice mahotas watershed
   #seed_method: maxima_distance  # maxima_distance[-T], minima[-T], grid[-N]
 
   # Dust merge (per chunk, after agglomeration)
   dust_merge: true
-  dust_merge_size: 400
-  dust_merge_affinity: 0.3
-  dust_remove_size: 200
+  dust_merge_size: 1500
+  dust_merge_affinity: 0.1
+  dust_remove_size: 600
 
   # Border stitching / overlap merge (same params as face_merge_pairs)
   min_overlap: 10                # min overlap pixels to consider a pair
@@ -58,6 +58,15 @@ large_decode:
   one_sided_min_size: 0          # min segment size in face for one-sided merge
   affinity_threshold: 0.0        # min boundary affinity (0=disabled)
 
+  # Execution
+  backend: multiprocess          # "multiprocess" or "slurm"
+  num_workers: 1                 # local parallel workers (--parallel default)
+  slurm:
+    partition: short
+    mem: 64G
+    cpus_per_task: 1
+    time: "12:00:00"
+
   # Output
   write_output: true
   output_path: ""                # default: workflow_root/assembled.h5
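The new Execution block is the YAML side of the backend selection added to decode_large.py; the --sbatch and --local flags take precedence over it. A distilled sketch of that precedence (resolve_backend is a hypothetical name; decode_large.py inlines this logic):

# Sketch of the backend precedence added in this commit.
def resolve_backend(sbatch_flag, local_flag, large_cfg):
    if sbatch_flag:   # --sbatch always forces SLURM submission
        return "slurm"
    if local_flag:    # --local always forces local multiprocess
        return "multiprocess"
    return large_cfg.get("backend", "multiprocess")  # fall back to YAML

assert resolve_backend(True, False, {"backend": "multiprocess"}) == "slurm"
assert resolve_backend(False, True, {"backend": "slurm"}) == "multiprocess"
assert resolve_backend(False, False, {}) == "multiprocess"

The worker count falls back the same way: n_parallel = args.parallel or large_cfg.get("num_workers", 1), so the CLI flag overrides num_workers, which defaults to 1.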
