Skip to content

Commit afa5ca0

Browse files
Donglai Wei and claude
committed
Fix ABISS batch mode, remove silent fallback, skip model build on cache hit
- ABISS batch merge-threshold: restore proper single-invocation mode (C++ binary already supports argv[8..N] batch thresholds natively); improve error message
- Remove silent _fallback_decode_connected_components that masked ABISS errors and produced garbage segmentations (0.918 vs 0.07 adapted_rand)
- Skip expensive model build + checkpoint load when cached _tta_prediction.h5 files exist for all test/tune volumes (nn.Identity lightweight module)
- Add _skip_inference guard in test_pipeline to error early if a cache miss occurs with a dummy model instead of running inference
- Fix on_test_end → on_test_epoch_end for Lightning self.log() compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 43a8722 commit afa5ca0

5 files changed

Lines changed: 81 additions & 55 deletions

File tree

connectomics/decoding/decoders/abiss.py

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -153,31 +153,6 @@ def _load_output(output_h5: Path, output_npy: Path, output_dataset: str) -> np.n
153153
return cast2dtype(seg)
154154

155155

156-
def _command_uses_run_abiss_single(cmd: str | Sequence[str]) -> bool:
157-
if isinstance(cmd, str):
158-
try:
159-
tokens = shlex.split(cmd)
160-
except ValueError:
161-
return "run_abiss_single.py" in cmd
162-
return any("run_abiss_single.py" in token for token in tokens)
163-
return any("run_abiss_single.py" in str(token) for token in cmd)
164-
165-
166-
def _fallback_decode_connected_components(pred: np.ndarray) -> np.ndarray:
167-
"""Lightweight fallback when ABISS executable dependencies are unavailable."""
168-
from scipy import ndimage as ndi
169-
170-
if pred.ndim == 4:
171-
if pred.shape[0] == 1:
172-
foreground = pred[0] > 0.5
173-
else:
174-
foreground = np.max(pred, axis=0) > 0.5
175-
else:
176-
foreground = pred > 0.5
177-
178-
labeled, _ = ndi.label(foreground.astype(np.uint8, copy=False))
179-
return cast2dtype(labeled.astype(np.uint64, copy=False))
180-
181156

182157
def decode_abiss(
183158
predictions: np.ndarray,
@@ -305,22 +280,14 @@ def decode_abiss(
305280
if env:
306281
proc_env.update({str(k): str(v) for k, v in env.items()})
307282

308-
try:
309-
subprocess.run(
310-
cmd,
311-
shell=use_shell,
312-
env=proc_env,
313-
cwd=str(workspace_path),
314-
check=check,
315-
timeout=timeout_sec,
316-
)
317-
except subprocess.CalledProcessError:
318-
if _command_uses_run_abiss_single(cmd):
319-
if batch_mt:
320-
seg = _fallback_decode_connected_components(pred)
321-
return {round(mt, 10): seg for mt in batch_mt}
322-
return _fallback_decode_connected_components(pred)
323-
raise
283+
subprocess.run(
284+
cmd,
285+
shell=use_shell,
286+
env=proc_env,
287+
cwd=str(workspace_path),
288+
check=check,
289+
timeout=timeout_sec,
290+
)
324291

325292
# Batch mode: read multiple output files written by run_abiss_single.
326293
if batch_mt:

connectomics/training/lightning/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def on_test_start(self):
767767
"and letting MONAI move window batches to the configured sw_device."
768768
)
769769

770-
def on_test_end(self) -> None:
770+
def on_test_epoch_end(self) -> None:
771771
"""Log aggregated test metrics after all ranks finish their assigned volumes."""
772772
log_test_epoch_metrics(self)
773773

connectomics/training/lightning/test_pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,16 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
921921
return torch.tensor(0.0, device=module.device)
922922

923923
logger.info("No cached predictions found, running inference")
924+
925+
# If the model is a lightweight dummy (e.g. nn.Identity), inference would
926+
# produce garbage. Error out early instead of crashing later in TTA.
927+
if getattr(module, "_skip_inference", False):
928+
raise RuntimeError(
929+
"Cached predictions expected but not found for this volume. "
930+
"Cannot run inference with a lightweight (dummy) model. "
931+
"Re-run with the real model checkpoint to generate predictions first."
932+
)
933+
924934
_log_volume_header(volume_name, "INFERENCE PLAN")
925935
logger.info(f"Input shape: {tuple(images.shape)}")
926936
logger.info(f"Input device: {images.device}")

scripts/main.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,41 @@ def _resolve_tta_result_path_override(cfg: Config) -> str:
260260
return ""
261261

262262

263+
def _has_tta_prediction_file(cfg: Config) -> bool:
264+
"""Return True if an explicit tta_result_path exists and is a valid HDF5 file."""
265+
tta_path = _resolve_tta_result_path_override(cfg)
266+
if not tta_path:
267+
return False
268+
pred_file = Path(tta_path).expanduser()
269+
if not pred_file.is_absolute():
270+
pred_file = Path.cwd() / pred_file
271+
return pred_file.exists() and _is_valid_hdf5_prediction_file(pred_file)
272+
273+
274+
def _has_cached_predictions_in_output_dir(cfg: Config, mode: str) -> bool:
275+
"""Return True if all expected _tta_prediction.h5 files exist in the output directory."""
276+
save_pred_cfg = getattr(cfg.inference, "save_prediction", None)
277+
if save_pred_cfg is None:
278+
return False
279+
output_dir = getattr(save_pred_cfg, "output_path", None)
280+
if not output_dir:
281+
return False
282+
283+
# Resolve test/tune image paths to derive expected prediction filenames.
284+
test_image_paths = resolve_test_image_paths(cfg)
285+
if not test_image_paths:
286+
return False
287+
288+
output_path = Path(output_dir)
289+
for image_path in test_image_paths:
290+
pred_file = output_path / f"{Path(image_path).stem}_tta_prediction.h5"
291+
if not pred_file.exists():
292+
return False
293+
if not _is_valid_hdf5_prediction_file(pred_file):
294+
return False
295+
return True
296+
297+
263298
def preflight_test_cache_hit(cfg: Config, datamodule) -> tuple[bool, str | None, int]:
264299
"""Check if test outputs already exist so inference (and ckpt restore) can be skipped."""
265300
save_pred_cfg = getattr(cfg.inference, "save_prediction", None)
@@ -868,23 +903,35 @@ def main():
868903
if try_cache_only_test_execution(cfg, args.mode, args.shard_id, args.num_shards):
869904
return
870905

871-
# Create model
872-
print(f"Creating model: {cfg.model.arch.type}")
873-
model = ConnectomicsModule(cfg)
874-
875-
# Count parameters
876-
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
877-
print(f" Model parameters: {num_params:,}")
906+
# Check for cached intermediate predictions early so we can skip both the
907+
# expensive model build and checkpoint restore for test/tune modes.
908+
tta_cached = args.mode in ("test", "tune", "tune-test") and (
909+
_has_tta_prediction_file(cfg)
910+
or _has_cached_predictions_in_output_dir(cfg, mode=args.mode)
911+
)
878912

879-
# Don't use checkpoint path if external weights were loaded (already in model state)
880-
# External weights are loaded during config setup via model.external_weights_path
881-
if args.external_prefix:
913+
# Create model
914+
if tta_cached:
915+
print(
916+
f" Cached intermediate predictions found; "
917+
f"creating lightweight module (skipping {cfg.model.arch.type} build)."
918+
)
919+
model = ConnectomicsModule(cfg, model=torch.nn.Identity())
920+
model._skip_inference = True
921+
ckpt_path = None
922+
elif args.external_prefix:
923+
print(f"Creating model: {cfg.model.arch.type}")
924+
model = ConnectomicsModule(cfg)
882925
print(
883926
" WARNING: External weights loaded - checkpoint path will not "
884927
"be used for training/testing"
885928
)
886929
ckpt_path = None
887930
else:
931+
print(f"Creating model: {cfg.model.arch.type}")
932+
model = ConnectomicsModule(cfg)
933+
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
934+
print(f" Model parameters: {num_params:,}")
888935
ckpt_path = modify_checkpoint_state(
889936
args.checkpoint,
890937
run_dir,

scripts/run_abiss_single.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,10 @@ def _run_abiss_ws(
304304
_ABISS_TAG,
305305
]
306306

307-
# Batch mode: multiple merge thresholds in one run (argv[8..N]).
307+
# Batch mode: pass multiple merge thresholds as argv[8..N].
308308
# The C++ binary computes watershed + region graph once, then
309-
# repeats the merge step for each threshold.
309+
# deep-copies and repeats the merge step for each threshold,
310+
# writing indexed output files (seg_{TAG}_{i}.data).
310311
use_batch = ws_merge_thresholds is not None and len(ws_merge_thresholds) > 1
311312
if use_batch:
312313
for mt in ws_merge_thresholds:
@@ -322,7 +323,8 @@ def _run_abiss_ws(
322323
seg_file = ws_dir / f"seg_{_ABISS_TAG}_{i}.data"
323324
if not seg_file.exists():
324325
raise FileNotFoundError(
325-
f"ABISS watershed did not produce expected output: {seg_file}"
326+
f"ABISS batch mode did not produce expected output: {seg_file}. "
327+
f"Ensure the ws binary at {ws_binary} supports multi-threshold mode."
326328
)
327329
seg_xyz = _read_segmentation_xyz(seg_file, output_xyz_shape, halo=1)
328330
results[round(mt, 10)] = np.transpose(seg_xyz, (2, 1, 0))

0 commit comments

Comments
 (0)