3737except ImportError :
3838 OPTUNA_AVAILABLE = False
3939
40- from connectomics .data .process .affinity import (
41- affinity_deepem_crop_enabled ,
42- compute_affinity_crop_pad ,
43- crop_spatial_by_pad ,
44- resolve_affinity_channel_groups_from_cfg ,
45- )
4640from connectomics .metrics .metrics_seg import adapted_rand
41+ from connectomics .training .lightning .utils import tta_cache_suffix
4742
4843from ..registry import get_decoder
4944from ..utils import remove_small_instances
5348__all__ = ["OptunaDecodingTuner" , "run_tuning" , "load_and_apply_best_params" ]
5449
5550
56- def _maybe_crop_affinity_array (
57- data : np .ndarray ,
58- * ,
59- reference_spatial_shape : tuple [int , ...],
60- crop_pad : tuple [tuple [int , int ], ...],
61- ) -> np .ndarray :
62- if not crop_pad :
63- return data
64- expected_cropped_shape = tuple (
65- int (reference_spatial_shape [axis ]) - crop_pad [axis ][0 ] - crop_pad [axis ][1 ]
66- for axis in range (len (crop_pad ))
67- )
68- data_spatial_shape = tuple (int (v ) for v in data .shape [- len (crop_pad ) :])
69- if data_spatial_shape == expected_cropped_shape :
70- return data
71- if data_spatial_shape != reference_spatial_shape :
72- return data
73- return crop_spatial_by_pad (data , crop_pad , item_name = "tuning array" )
74-
75-
7651def _expand_tuning_paths (path_or_pattern : Any , * , field_name : str ) -> list [str ]:
7752 """Expand string/list path inputs used by the tuning loader."""
7853 import glob
@@ -119,9 +94,12 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
11994 """Force the pre-Optuna inference pass to cache raw predictions only."""
12095 inference_cfgs = []
12196 seen_inference_cfgs : set [int ] = set ()
97+ primary_cfg = None
12298 for cfg_obj in cfg_objects :
12399 if cfg_obj is None :
124100 continue
101+ if primary_cfg is None :
102+ primary_cfg = cfg_obj
125103 inference_cfg = getattr (cfg_obj , "inference" , None )
126104 if inference_cfg is None or id (inference_cfg ) in seen_inference_cfgs :
127105 continue
@@ -131,6 +109,8 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
131109 if not inference_cfgs :
132110 raise ValueError ("Missing runtime cfg.inference configuration required for tuning" )
133111
112+ suffix = tta_cache_suffix (primary_cfg ) if primary_cfg is not None else "_tta_x1_prediction.h5"
113+
134114 backups = []
135115 for inference_cfg in inference_cfgs :
136116 save_prediction_cfg = getattr (inference_cfg , "save_prediction" , None )
@@ -156,13 +136,13 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
156136 )
157137
158138 save_prediction_cfg .enabled = True
159- save_prediction_cfg .cache_suffix = "_tta_prediction.h5"
139+ save_prediction_cfg .cache_suffix = suffix
160140 inference_cfg .decoding = None
161141 if evaluation_cfg is not None :
162142 evaluation_cfg .enabled = False
163143
164144 try :
165- yield "_tta_prediction.h5"
145+ yield suffix
166146 finally :
167147 for backup in backups :
168148 inference_cfg = backup ["inference_cfg" ]
@@ -1164,7 +1144,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
11641144 logger .info ("[1/4] Running inference on tuning dataset..." )
11651145
11661146 tune_data = cfg .data
1167- cache_suffix = "_tta_prediction.h5"
1147+ cache_suffix = tta_cache_suffix ( cfg )
11681148
11691149 output_pred_dir = cfg .inference .save_prediction .output_path
11701150 predictions_dir = Path (output_pred_dir )
@@ -1294,45 +1274,19 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
12941274 f"Mismatch: { len (all_predictions )} prediction files vs " f"{ len (all_masks )} mask files"
12951275 )
12961276
1297- if affinity_deepem_crop_enabled (cfg ):
1298- groups = resolve_affinity_channel_groups_from_cfg (cfg )
1299- all_offsets = []
1300- for _ , offsets in groups :
1301- all_offsets .extend (offsets )
1302- crop_pad = compute_affinity_crop_pad (all_offsets )
1303- if crop_pad and any (before or after for before , after in crop_pad ):
1304- cropped_predictions = []
1305- cropped_labels = []
1306- cropped_masks = [] if all_masks is not None else None
1307- for idx , pred in enumerate (all_predictions ):
1308- reference_spatial_shape = tuple (
1309- int (v ) for v in all_labels [idx ].shape [- len (crop_pad ) :]
1310- )
1311- cropped_predictions .append (
1312- _maybe_crop_affinity_array (
1313- np .asarray (pred ),
1314- reference_spatial_shape = reference_spatial_shape ,
1315- crop_pad = crop_pad ,
1316- )
1317- )
1318- cropped_labels .append (
1319- _maybe_crop_affinity_array (
1320- np .asarray (all_labels [idx ]),
1321- reference_spatial_shape = reference_spatial_shape ,
1322- crop_pad = crop_pad ,
1323- )
1324- )
1325- if cropped_masks is not None :
1326- cropped_masks .append (
1327- _maybe_crop_affinity_array (
1328- np .asarray (all_masks [idx ]),
1329- reference_spatial_shape = reference_spatial_shape ,
1330- crop_pad = crop_pad ,
1331- )
1332- )
1333- all_predictions = cropped_predictions
1334- all_labels = cropped_labels
1335- all_masks = cropped_masks
1277+ # Validate that prediction and label spatial shapes match.
1278+ # Cached TTA prediction files are saved after crop_pad + affinity_crop
1279+ # in the test pipeline, so they should already align with the label volume.
1280+ for idx , pred in enumerate (all_predictions ):
1281+ pred_spatial = tuple (int (v ) for v in pred .shape [- 3 :])
1282+ label_spatial = tuple (int (v ) for v in all_labels [idx ].shape [- 3 :])
1283+ if pred_spatial != label_spatial :
1284+ raise ValueError (
1285+ f"Prediction/label spatial shape mismatch for volume { idx } : "
1286+ f"prediction { pred_spatial } vs label { label_spatial } . "
1287+ f"Cached predictions may be stale — regenerate TTA predictions "
1288+ f"by re-running inference with the real model checkpoint."
1289+ )
13361290
13371291 # Step 4: Create tuner and run optimization (per-volume evaluation)
13381292 logger .info ("[4/5] Creating Optuna tuner..." )
0 commit comments