Skip to content

Commit 205d1f6

Browse files
Donglai Wei and claude committed
Fix temp dir pollution, verbose tuner logging, and ARE discrepancy
- Move temp directories to system tmpdir instead of repo root (abiss.py, run_abiss_single.py)
- Use print() for per-threshold ARE breakdown so it always shows
- Save intermediate _tta_prediction.h5 BEFORE crop_pad/affinity_crop so the tuner (which applies its own crops) works from the same pre-crop data as the test pipeline, fixing ARE discrepancy between test and tune

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent afa5ca0 commit 205d1f6

4 files changed

Lines changed: 33 additions & 32 deletions

File tree

connectomics/decoding/decoders/abiss.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@
1313
import shlex
1414
import subprocess
1515
import sys
16+
import tempfile
1617
from pathlib import Path
17-
from tempfile import TemporaryDirectory
1818
from typing import Any, Dict, List, Mapping, Optional, Sequence
1919

2020
import numpy as np
@@ -221,7 +221,7 @@ def decode_abiss(
221221
workspace_path.mkdir(parents=True, exist_ok=True)
222222
temp_ctx = None
223223
else:
224-
temp_ctx = TemporaryDirectory(prefix="decode_abiss_")
224+
temp_ctx = tempfile.TemporaryDirectory(prefix="decode_abiss_", dir=tempfile.gettempdir())
225225
workspace_path = Path(temp_ctx.name).resolve()
226226
launch_cwd = Path.cwd().resolve()
227227
package_root = Path(__file__).resolve().parents[2]

connectomics/decoding/tuning/optuna_tuner.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -619,14 +619,9 @@ def _abiss_batch_objective(
619619
for mt in self._abiss_all_merge_thresholds
620620
if mt_are[mt]
621621
)
622-
logger.info(
623-
"Trial %3d: best ARE=%.4f (mt=%.2f) Prec=%.4f Rec=%.4f | %s",
624-
self.trial_count,
625-
best_avg,
626-
best_mt,
627-
avg_prec,
628-
avg_rec,
629-
mt_summary,
622+
print(
623+
f" Trial {self.trial_count:3d}: best ARE={best_avg:.4f} (mt={best_mt:.2f}) "
624+
f"Prec={avg_prec:.4f} Rec={avg_rec:.4f} | {mt_summary}"
630625
)
631626

632627
return best_avg
@@ -1230,7 +1225,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
12301225
all_predictions = []
12311226
for pred_file in pred_files:
12321227
pred = read_volume(pred_file)
1233-
logger.info("Loaded %s: shape %s", Path(pred_file).name, pred.shape)
1228+
logger.info("Loaded %s: shape %s, dtype %s, range [%.4f, %.4f]",
1229+
Path(pred_file).name, pred.shape, pred.dtype, pred.min(), pred.max())
12341230
all_predictions.append(pred)
12351231

12361232
total_slices = sum(p.shape[1] for p in all_predictions)

connectomics/training/lightning/test_pipeline.py

Lines changed: 24 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -973,6 +973,30 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
973973
_distributed_tta_barrier(module)
974974
return torch.tensor(0.0, device=module.device)
975975
predictions_np = predictions.detach().cpu().float().numpy()
976+
977+
# Save intermediate predictions BEFORE crop_pad/affinity_crop so that
978+
# the tuner (which applies its own crops to both predictions and labels)
979+
# works from the same pre-crop data as the test pipeline.
980+
inference_duration = time.time() - inference_start
981+
logger.info(
982+
f"Inference completed in {inference_duration / 60:.2f} minutes ({inference_duration:.1f}s)"
983+
)
984+
985+
save_intermediate = bool(getattr(inference_cfg.save_prediction, "enabled", False))
986+
if save_intermediate:
987+
logger.info("[STAGE: Saving Intermediate Predictions]")
988+
save_start = time.time()
989+
predictions_to_save = apply_save_prediction_transform(module.cfg, predictions_np)
990+
write_outputs(
991+
module.cfg,
992+
predictions_to_save,
993+
filenames,
994+
suffix="tta_prediction",
995+
mode=mode,
996+
batch_meta=batch.get("image_meta_dict"),
997+
)
998+
logger.info(f"Intermediate predictions saved ({time.time() - save_start:.1f}s)")
999+
9761000
predictions_np = _apply_prediction_crop_pad_if_needed(
9771001
module,
9781002
predictions_np,
@@ -986,10 +1010,6 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
9861010
reference_spatial_shape=reference_spatial_shape,
9871011
item_name="predictions",
9881012
)
989-
inference_duration = time.time() - inference_start
990-
logger.info(
991-
f"Inference completed in {inference_duration / 60:.2f} minutes ({inference_duration:.1f}s)"
992-
)
9931013

9941014
logger.info("Prediction Summary:")
9951015
logger.info(f" Shape: {predictions_np.shape}")
@@ -998,21 +1018,6 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
9981018
logger.info(f" Max: {predictions_np.max():.6f}")
9991019
logger.info(f" Mean: {predictions_np.mean():.6f}")
10001020

1001-
save_intermediate = bool(getattr(inference_cfg.save_prediction, "enabled", False))
1002-
if save_intermediate:
1003-
logger.info("[STAGE: Saving Intermediate Predictions]")
1004-
save_start = time.time()
1005-
predictions_to_save = apply_save_prediction_transform(module.cfg, predictions_np)
1006-
write_outputs(
1007-
module.cfg,
1008-
predictions_to_save,
1009-
filenames,
1010-
suffix="tta_prediction",
1011-
mode=mode,
1012-
batch_meta=batch.get("image_meta_dict"),
1013-
)
1014-
logger.info(f"Intermediate predictions saved ({time.time() - save_start:.1f}s)")
1015-
10161021
# In tune mode, skip decoding — the Optuna tuner will handle it.
10171022
if mode == "tune":
10181023
logger.info("Tune mode: skipping decoding (Optuna tuner will handle it)")

scripts/run_abiss_single.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@
1010
import shutil
1111
import subprocess
1212
import sys
13+
import tempfile
1314
from pathlib import Path
14-
from tempfile import TemporaryDirectory
1515
from typing import Iterable, Optional
1616

1717
import numpy as np
@@ -282,7 +282,7 @@ def _run_abiss_ws(
282282
ws_dir.mkdir(parents=True, exist_ok=True)
283283
temp_ctx = None
284284
else:
285-
temp_ctx = TemporaryDirectory(prefix="abiss_single_")
285+
temp_ctx = tempfile.TemporaryDirectory(prefix="abiss_single_", dir=tempfile.gettempdir())
286286
ws_dir = Path(temp_ctx.name).resolve()
287287

288288
try:

0 commit comments

Comments
 (0)