
Commit a59936a
Donglai Wei and claude committed
Encode TTA pass count in prediction cache filename, remove tuner affinity crop
- Rename _tta_prediction.h5 → _tta_x{N}_prediction.h5, where N is the number of TTA passes (e.g. _tta_x1_prediction.h5 when TTA is disabled)
- Add compute_tta_passes(), tta_cache_suffix(), is_tta_cache_suffix() to connectomics.training.lightning.utils
- Remove _maybe_crop_affinity_array from the tuner — cached predictions are already fully cropped by the test pipeline; add shape validation that errors on mismatch instead of silently hiding it
- Replace all hardcoded "_tta_prediction.h5" strings with the dynamic suffix
- Fallback cache lookup uses a _tta_x*_prediction.h5 glob pattern

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent: 8994fea

8 files changed: 119 additions & 96 deletions


connectomics/config/schema/stages.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ class TuneOutputConfig:

     output_dir: Optional[str] = None
     output_pred: Optional[str] = None
-    cache_suffix: str = "_tta_prediction.h5"
+    cache_suffix: str = "_tta_x1_prediction.h5"
     save_all_trials: bool = False
     save_best_segmentation: bool = True
     save_study: bool = True

connectomics/decoding/tuning/optuna_tuner.py

Lines changed: 22 additions & 68 deletions
@@ -37,13 +37,8 @@
 except ImportError:
     OPTUNA_AVAILABLE = False

-from connectomics.data.process.affinity import (
-    affinity_deepem_crop_enabled,
-    compute_affinity_crop_pad,
-    crop_spatial_by_pad,
-    resolve_affinity_channel_groups_from_cfg,
-)
 from connectomics.metrics.metrics_seg import adapted_rand
+from connectomics.training.lightning.utils import tta_cache_suffix

 from ..registry import get_decoder
 from ..utils import remove_small_instances

@@ -53,26 +48,6 @@
 __all__ = ["OptunaDecodingTuner", "run_tuning", "load_and_apply_best_params"]


-def _maybe_crop_affinity_array(
-    data: np.ndarray,
-    *,
-    reference_spatial_shape: tuple[int, ...],
-    crop_pad: tuple[tuple[int, int], ...],
-) -> np.ndarray:
-    if not crop_pad:
-        return data
-    expected_cropped_shape = tuple(
-        int(reference_spatial_shape[axis]) - crop_pad[axis][0] - crop_pad[axis][1]
-        for axis in range(len(crop_pad))
-    )
-    data_spatial_shape = tuple(int(v) for v in data.shape[-len(crop_pad) :])
-    if data_spatial_shape == expected_cropped_shape:
-        return data
-    if data_spatial_shape != reference_spatial_shape:
-        return data
-    return crop_spatial_by_pad(data, crop_pad, item_name="tuning array")
-
-
 def _expand_tuning_paths(path_or_pattern: Any, *, field_name: str) -> list[str]:
     """Expand string/list path inputs used by the tuning loader."""
     import glob

@@ -119,9 +94,12 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
     """Force the pre-Optuna inference pass to cache raw predictions only."""
     inference_cfgs = []
     seen_inference_cfgs: set[int] = set()
+    primary_cfg = None
     for cfg_obj in cfg_objects:
         if cfg_obj is None:
             continue
+        if primary_cfg is None:
+            primary_cfg = cfg_obj
         inference_cfg = getattr(cfg_obj, "inference", None)
         if inference_cfg is None or id(inference_cfg) in seen_inference_cfgs:
             continue

@@ -131,6 +109,8 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
     if not inference_cfgs:
         raise ValueError("Missing runtime cfg.inference configuration required for tuning")

+    suffix = tta_cache_suffix(primary_cfg) if primary_cfg is not None else "_tta_x1_prediction.h5"
+
     backups = []
     for inference_cfg in inference_cfgs:
         save_prediction_cfg = getattr(inference_cfg, "save_prediction", None)

@@ -156,13 +136,13 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
         )

         save_prediction_cfg.enabled = True
-        save_prediction_cfg.cache_suffix = "_tta_prediction.h5"
+        save_prediction_cfg.cache_suffix = suffix
         inference_cfg.decoding = None
         if evaluation_cfg is not None:
             evaluation_cfg.enabled = False

     try:
-        yield "_tta_prediction.h5"
+        yield suffix
     finally:
         for backup in backups:
             inference_cfg = backup["inference_cfg"]

@@ -1164,7 +1144,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
     logger.info("[1/4] Running inference on tuning dataset...")

     tune_data = cfg.data
-    cache_suffix = "_tta_prediction.h5"
+    cache_suffix = tta_cache_suffix(cfg)

     output_pred_dir = cfg.inference.save_prediction.output_path
     predictions_dir = Path(output_pred_dir)

@@ -1294,45 +1274,19 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
            f"Mismatch: {len(all_predictions)} prediction files vs " f"{len(all_masks)} mask files"
        )

-    if affinity_deepem_crop_enabled(cfg):
-        groups = resolve_affinity_channel_groups_from_cfg(cfg)
-        all_offsets = []
-        for _, offsets in groups:
-            all_offsets.extend(offsets)
-        crop_pad = compute_affinity_crop_pad(all_offsets)
-        if crop_pad and any(before or after for before, after in crop_pad):
-            cropped_predictions = []
-            cropped_labels = []
-            cropped_masks = [] if all_masks is not None else None
-            for idx, pred in enumerate(all_predictions):
-                reference_spatial_shape = tuple(
-                    int(v) for v in all_labels[idx].shape[-len(crop_pad) :]
-                )
-                cropped_predictions.append(
-                    _maybe_crop_affinity_array(
-                        np.asarray(pred),
-                        reference_spatial_shape=reference_spatial_shape,
-                        crop_pad=crop_pad,
-                    )
-                )
-                cropped_labels.append(
-                    _maybe_crop_affinity_array(
-                        np.asarray(all_labels[idx]),
-                        reference_spatial_shape=reference_spatial_shape,
-                        crop_pad=crop_pad,
-                    )
-                )
-                if cropped_masks is not None:
-                    cropped_masks.append(
-                        _maybe_crop_affinity_array(
-                            np.asarray(all_masks[idx]),
-                            reference_spatial_shape=reference_spatial_shape,
-                            crop_pad=crop_pad,
-                        )
-                    )
-            all_predictions = cropped_predictions
-            all_labels = cropped_labels
-            all_masks = cropped_masks
+    # Validate that prediction and label spatial shapes match.
+    # Cached TTA prediction files are saved after crop_pad + affinity_crop
+    # in the test pipeline, so they should already align with the label volume.
+    for idx, pred in enumerate(all_predictions):
+        pred_spatial = tuple(int(v) for v in pred.shape[-3:])
+        label_spatial = tuple(int(v) for v in all_labels[idx].shape[-3:])
+        if pred_spatial != label_spatial:
+            raise ValueError(
+                f"Prediction/label spatial shape mismatch for volume {idx}: "
+                f"prediction {pred_spatial} vs label {label_spatial}. "
+                f"Cached predictions may be stale — regenerate TTA predictions "
+                f"by re-running inference with the real model checkpoint."
+            )

     # Step 4: Create tuner and run optimization (per-volume evaluation)
     logger.info("[4/5] Creating Optuna tuner...")

connectomics/training/lightning/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -17,10 +17,13 @@
 from .runtime import cleanup_run_directory, modify_checkpoint_state, setup_run_directory
 from .trainer import create_trainer
 from .utils import (
+    compute_tta_passes,
     extract_best_score_from_checkpoint,
+    is_tta_cache_suffix,
     parse_args,
     setup_config,
     setup_seed_everything,
+    tta_cache_suffix,
 )

 __all__ = [

@@ -41,4 +44,7 @@
     "parse_args",
     "setup_config",
     "extract_best_score_from_checkpoint",
+    "compute_tta_passes",
+    "tta_cache_suffix",
+    "is_tta_cache_suffix",
 ]

connectomics/training/lightning/model.py

Lines changed: 7 additions & 6 deletions
@@ -44,6 +44,7 @@
 from ...models import build_model
 from ...models.loss import create_loss, get_loss_metadata_for_module
 from ..debugging import DebugManager
+from .utils import is_tta_cache_suffix, tta_cache_suffix

 # Import training/inference components
 from ..loss import LossOrchestrator, build_loss_weighter, infer_num_loss_tasks_from_config

@@ -445,7 +446,7 @@ def _load_cached_predictions(
                     f"{len(filenames)} filenames; decoding will use the explicit file only."
                 )
                 # Treat explicit file as intermediate prediction so decoding still runs.
-                return pred, True, "_tta_prediction.h5"
+                return pred, True, tta_cache_suffix(self.cfg)
             except Exception as e:
                 logger.warning(
                     f"Failed to load explicit inference.tta_result_path file {pred_file}: {e}. "

@@ -467,11 +468,11 @@

         for filename in filenames:
             pred_file = output_dir / f"{filename}{cache_suffix}"
-            if not pred_file.exists() and mode == "test" and cache_suffix != "_tta_prediction.h5":
-                tta_pred_file = output_dir / f"{filename}_tta_prediction.h5"
-                if tta_pred_file.exists():
-                    pred_file = tta_pred_file
-                    loaded_suffix = "_tta_prediction.h5"
+            if not pred_file.exists() and mode == "test" and not is_tta_cache_suffix(cache_suffix):
+                tta_matches = sorted(output_dir.glob(f"{filename}_tta_x*_prediction.h5"))
+                if tta_matches:
+                    pred_file = tta_matches[-1]
+                    loaded_suffix = pred_file.name[len(filename):]

             if pred_file.exists():
                 try:
connectomics/training/lightning/utils.py

Lines changed: 58 additions & 0 deletions
@@ -346,10 +346,68 @@ def setup_seed_everything():
     return seed_everything


+def compute_tta_passes(cfg: Config, spatial_dims: int = 3) -> int:
+    """Return the total number of TTA inference passes from config.
+
+    This determines the multiplier in the cached prediction filename
+    (e.g. ``_tta_x16_prediction.h5``). When TTA is disabled the count is 1.
+    """
+    inference_cfg = getattr(cfg, "inference", None)
+    if inference_cfg is None:
+        return 1
+    tta_cfg = getattr(inference_cfg, "test_time_augmentation", None)
+    if tta_cfg is None or not bool(getattr(tta_cfg, "enabled", False)):
+        return 1
+
+    flip_axes_cfg = getattr(tta_cfg, "flip_axes", None)
+    rotation90_axes_cfg = getattr(tta_cfg, "rotation90_axes", None)
+
+    def _cfg_len(value):
+        if value is None or isinstance(value, str):
+            return 0
+        try:
+            return len(value)
+        except TypeError:
+            return 0
+
+    if flip_axes_cfg == "all" or flip_axes_cfg == []:
+        flip_variants = 2 ** spatial_dims if spatial_dims > 0 else 1
+    elif flip_axes_cfg is None:
+        flip_variants = 1
+    else:
+        flip_variants = 1 + _cfg_len(flip_axes_cfg)
+
+    if rotation90_axes_cfg == "all":
+        rotation_planes = 3 if spatial_dims == 3 else 1 if spatial_dims == 2 else 0
+    elif rotation90_axes_cfg is None:
+        rotation_planes = 0
+    else:
+        rotation_planes = _cfg_len(rotation90_axes_cfg)
+
+    passes_per_flip = 1 if rotation_planes == 0 else rotation_planes * 4
+    return flip_variants * passes_per_flip
+
+
+def tta_cache_suffix(cfg: Config, spatial_dims: int = 3) -> str:
+    """Return the TTA prediction cache suffix, e.g. ``_tta_x1_prediction.h5``."""
+    n = compute_tta_passes(cfg, spatial_dims=spatial_dims)
+    return f"_tta_x{n}_prediction.h5"
+
+
+def is_tta_cache_suffix(suffix: str | None) -> bool:
+    """Return True for any TTA intermediate prediction suffix (``_tta_x*_prediction.h5``)."""
+    if not suffix:
+        return False
+    return suffix.startswith("_tta_x") and suffix.endswith("_prediction.h5")
+
+
 __all__ = [
     "parse_args",
     "setup_config",
     "expand_file_paths",
     "extract_best_score_from_checkpoint",
     "setup_seed_everything",
+    "compute_tta_passes",
+    "tta_cache_suffix",
+    "is_tta_cache_suffix",
 ]
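To make the pass arithmetic concrete, a small worked example using throwaway SimpleNamespace stand-ins for the real Config objects (the import path is the one this commit adds to):

from types import SimpleNamespace

from connectomics.training.lightning.utils import compute_tta_passes, tta_cache_suffix

# Hypothetical TTA settings: one flip axis plus 90-degree rotations in all planes.
tta = SimpleNamespace(enabled=True, flip_axes=["z"], rotation90_axes="all")
cfg = SimpleNamespace(inference=SimpleNamespace(test_time_augmentation=tta))

# flip_axes=["z"]        -> flip_variants = 1 + 1 = 2 (identity + one flip)
# rotation90_axes="all"  -> 3 planes in 3D, 4 rotations each = 12 passes per flip
print(compute_tta_passes(cfg))  # 2 * 12 = 24
print(tta_cache_suffix(cfg))    # _tta_x24_prediction.h5

# TTA disabled collapses to a single pass.
cfg_off = SimpleNamespace(inference=SimpleNamespace(test_time_augmentation=None))
print(tta_cache_suffix(cfg_off))  # _tta_x1_prediction.h5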

scripts/main.py

Lines changed: 19 additions & 15 deletions
@@ -74,11 +74,13 @@
     cleanup_run_directory,
     create_datamodule,
     create_trainer,
+    is_tta_cache_suffix,
     modify_checkpoint_state,
     parse_args,
     setup_config,
     setup_run_directory,
     setup_seed_everything,
+    tta_cache_suffix,
 )

 # Setup seed_everything helper

@@ -211,20 +213,21 @@ def _resolve_cached_prediction_files(
         pred_file = output_dir / f"{filename}{cache_suffix}"
         current_suffix = cache_suffix

-        if not pred_file.exists() and cache_suffix != "_tta_prediction.h5":
-            tta_pred_file = output_dir / f"{filename}_tta_prediction.h5"
-            if tta_pred_file.exists():
-                pred_file = tta_pred_file
-                current_suffix = "_tta_prediction.h5"
+        # Fallback: search for any _tta_x*_prediction.h5 if exact suffix not found.
+        if not pred_file.exists() and not is_tta_cache_suffix(cache_suffix):
+            tta_matches = sorted(output_dir.glob(f"{filename}_tta_x*_prediction.h5"))
+            if tta_matches:
+                pred_file = tta_matches[-1]  # latest / highest augmentation count
+                current_suffix = pred_file.name[len(filename):]

         if not pred_file.exists():
             return False, None, []

         if not _is_valid_hdf5_prediction_file(pred_file):
             return False, None, []

-        if current_suffix == "_tta_prediction.h5":
-            loaded_suffix = "_tta_prediction.h5"
+        if is_tta_cache_suffix(current_suffix):
+            loaded_suffix = current_suffix
         resolved_files.append(pred_file)

     return True, loaded_suffix, resolved_files

@@ -272,7 +275,7 @@ def _has_tta_prediction_file(cfg: Config) -> bool:


 def _has_cached_predictions_in_output_dir(cfg: Config, mode: str) -> bool:
-    """Return True if all expected _tta_prediction.h5 files exist in the output directory."""
+    """Return True if all expected TTA prediction files exist in the output directory."""
     save_pred_cfg = getattr(cfg.inference, "save_prediction", None)
     if save_pred_cfg is None:
         return False

@@ -285,9 +288,10 @@ def _has_cached_predictions_in_output_dir(cfg: Config, mode: str) -> bool:
     if not test_image_paths:
         return False

+    suffix = tta_cache_suffix(cfg)
     output_path = Path(output_dir)
     for image_path in test_image_paths:
-        pred_file = output_path / f"{Path(image_path).stem}_tta_prediction.h5"
+        pred_file = output_path / f"{Path(image_path).stem}{suffix}"
         if not pred_file.exists():
             return False
         if not _is_valid_hdf5_prediction_file(pred_file):

@@ -309,7 +313,7 @@ def preflight_test_cache_hit(cfg: Config, datamodule) -> tuple[bool, str | None,

     # If explicit intermediate prediction exists, skip TTA inference and ckpt restore.
     if pred_file.exists() and _is_valid_hdf5_prediction_file(pred_file):
-        return True, "_tta_prediction.h5", 1
+        return True, tta_cache_suffix(cfg), 1

     print(
         " WARNING: inference.tta_result_path file missing or unreadable "

@@ -691,7 +695,7 @@ def try_cache_only_test_execution(
         )
         return False

-    if loaded_suffix != "_tta_prediction.h5":
+    if not is_tta_cache_suffix(loaded_suffix):
         if _is_test_evaluation_enabled(cfg):
             print(
                 " [OK]Loaded final predictions from disk, skipping "

@@ -766,7 +770,7 @@ def _configure_checkpoint_output_paths(args, cfg: Config) -> tuple[Path | None,

     save_pred_cfg = cfg.inference.save_prediction
     save_pred_cfg.output_path = str(output_base / results_folder_name)
-    save_pred_cfg.cache_suffix = "_tta_prediction.h5"
+    save_pred_cfg.cache_suffix = tta_cache_suffix(cfg)

     if args.mode == "tune-test":
         print(f"Test output: {save_pred_cfg.output_path}")

@@ -812,7 +816,7 @@
     ckpt_path: str | None,
 ) -> tuple[bool, None]:
     """Print cache-hit status and return whether the test loop can be skipped."""
-    if cached_suffix == "_tta_prediction.h5":
+    if is_tta_cache_suffix(cached_suffix):
         print(" [OK]Loaded intermediate predictions from disk, skipping inference")
     else:
         print(

@@ -828,7 +832,7 @@

     should_skip_test_loop = (
         args.mode == "test"
-        and cached_suffix != "_tta_prediction.h5"
+        and not is_tta_cache_suffix(cached_suffix)
         and not _is_test_evaluation_enabled(cfg)
     )
     if should_skip_test_loop:

@@ -876,7 +880,7 @@ def main():

     # Tuning expects cached intermediate predictions by default.
     if args.mode in ["tune", "tune-test"]:
-        cfg.inference.save_prediction.cache_suffix = "_tta_prediction.h5"
+        cfg.inference.save_prediction.cache_suffix = tta_cache_suffix(cfg)

     # Run preflight checks for training mode
     if args.mode == "train":
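End to end, the cache probe in _has_cached_predictions_in_output_dir now builds its expected filenames as {stem}{suffix}; a quick illustration with hypothetical paths:

from pathlib import Path

suffix = "_tta_x8_prediction.h5"  # what tta_cache_suffix(cfg) would return for 8 passes
for image_path in ["data/volume_A.h5", "data/volume_B.h5"]:
    pred_file = Path("outputs/results") / f"{Path(image_path).stem}{suffix}"
    print(pred_file)
# outputs/results/volume_A_tta_x8_prediction.h5
# outputs/results/volume_B_tta_x8_prediction.h5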
