@@ -73,6 +73,47 @@ def _maybe_crop_affinity_array(
7373 return crop_spatial_by_pad (data , crop_pad , item_name = "tuning array" )
7474
7575
76+ def _expand_tuning_paths (path_or_pattern : Any , * , field_name : str ) -> list [str ]:
77+ """Expand string/list path inputs used by the tuning loader."""
78+ import glob
79+
80+ if path_or_pattern is None :
81+ return []
82+
83+ if isinstance (path_or_pattern , (str , Path )):
84+ pattern = str (path_or_pattern )
85+ if "*" in pattern or "?" in pattern :
86+ return sorted (glob .glob (pattern ))
87+ return [pattern ]
88+
89+ if isinstance (path_or_pattern , list ):
90+ expanded : list [str ] = []
91+ for entry in path_or_pattern :
92+ expanded .extend (_expand_tuning_paths (entry , field_name = field_name ))
93+ return expanded
94+
95+ raise TypeError (f"{ field_name } must be string or list, got { type (path_or_pattern )} " )
96+
97+
def _resolve_tuning_prediction_files(
    cfg,
    predictions_dir: Path,
    cache_suffix: str,
) -> tuple[list[str], list[str]]:
    """Resolve cached prediction files for the current tune dataset only.

    Derives the expected cache file name for every tune image
    (``<image stem><cache_suffix>`` inside ``predictions_dir``) and reports
    which of those files already exist on disk.

    Returns:
        ``(existing, expected)`` — both lists of path strings; ``existing``
        is the subset of ``expected`` currently present on disk.

    Raises:
        ValueError: If ``cfg.data.val.image`` is missing.
        FileNotFoundError: If the image pattern expands to no files.
    """
    val_cfg = getattr(cfg.data, "val", None)
    tune_image_pattern = getattr(val_cfg, "image", None)
    if tune_image_pattern is None:
        raise ValueError("Missing data.val.image in configuration")

    image_files = _expand_tuning_paths(tune_image_pattern, field_name="data.val.image")
    if not image_files:
        raise FileNotFoundError(f"No image files found matching pattern: {tune_image_pattern}")

    expected_paths = []
    for image_path in image_files:
        stem = Path(str(image_path)).stem
        expected_paths.append(predictions_dir / f"{stem}{cache_suffix}")

    existing = [str(candidate) for candidate in expected_paths if candidate.exists()]
    expected = [str(candidate) for candidate in expected_paths]
    return existing, expected
115+
116+
76117@contextmanager
77118def _temporary_tuning_inference_overrides (* cfg_objects : Any ):
78119 """Force the pre-Optuna inference pass to cache raw predictions only."""
@@ -380,6 +421,13 @@ def optimize(self) -> optuna.Study:
380421 direction = direction ,
381422 )
382423
424+ # Seed the first trial with known-good defaults so TPE has a strong
425+ # baseline from the start instead of wasting early trials on random configs.
426+ default_params = self ._build_default_trial_params ()
427+ if default_params :
428+ study .enqueue_trial (default_params )
429+ logger .info ("Seeded first trial with default parameters: %s" , default_params )
430+
383431 # Run optimization
384432 n_trials = self .tune_cfg .n_trials
385433 timeout = self .tune_cfg .timeout
@@ -776,6 +824,70 @@ def _sample_parameters(self, trial: optuna.Trial) -> Dict[str, Any]:
776824
777825 return params
778826
827+ def _build_default_trial_params (self ) -> Optional [Dict [str , Any ]]:
828+ """Build a param dict from config defaults to seed the first Optuna trial.
829+
830+ Maps ``parameter_space.decoding.defaults`` (and postprocessing defaults)
831+ back to the flat Optuna parameter names used by ``_suggest_param``.
832+ """
833+ params : Dict [str , Any ] = {}
834+
835+ # --- decoding defaults ---
836+ decoding_cfg = getattr (self .param_space_cfg , "decoding" , None )
837+ defaults = getattr (decoding_cfg , "defaults" , None ) if decoding_cfg else None
838+ param_defs = getattr (decoding_cfg , "parameters" , None ) if decoding_cfg else None
839+
840+ if defaults and param_defs :
841+ for name , pcfg in param_defs .items ():
842+ if self ._abiss_batch_enabled and name == "ws_merge_threshold" :
843+ continue
844+ val = self ._lookup_default (defaults , name , pcfg )
845+ if val is not None :
846+ params [name ] = val
847+
848+ # --- postprocessing defaults ---
849+ postproc_cfg = getattr (self .param_space_cfg , "postprocessing" , None )
850+ if postproc_cfg and getattr (postproc_cfg , "enabled" , False ):
851+ pp_defaults = getattr (postproc_cfg , "defaults" , None )
852+ pp_params = getattr (postproc_cfg , "parameters" , None )
853+ if pp_defaults and pp_params :
854+ for name , pcfg in pp_params .items ():
855+ val = self ._lookup_default (pp_defaults , name , pcfg )
856+ if val is not None :
857+ params [name ] = val
858+
859+ return params if params else None
860+
861+ @staticmethod
862+ def _lookup_default (defaults : Any , name : str , pcfg : Any ) -> Any :
863+ """Resolve a single default value from the defaults block.
864+
865+ Handles three layouts:
866+ - ``nest_under``: ``defaults.<nest_under>.<name>``
867+ - ``param_group`` + ``tuple_index``: ``defaults.<param_group>[tuple_index]``
868+ - direct: ``defaults.<name>``
869+ """
870+ # nested (e.g. cli_args.ws_high_threshold)
871+ nest_under = pcfg .get ("nest_under" , None ) if hasattr (pcfg , "get" ) else getattr (pcfg , "nest_under" , None )
872+ if nest_under :
873+ nested = getattr (defaults , nest_under , None ) if not isinstance (defaults , dict ) else defaults .get (nest_under )
874+ if nested is not None :
875+ val = nested .get (name , None ) if isinstance (nested , dict ) else getattr (nested , name , None )
876+ if val is not None :
877+ return val
878+
879+ # tuple param (e.g. binary_threshold[0])
880+ param_group = pcfg .get ("param_group" , None ) if hasattr (pcfg , "get" ) else getattr (pcfg , "param_group" , None )
881+ tuple_index = pcfg .get ("tuple_index" , None ) if hasattr (pcfg , "get" ) else getattr (pcfg , "tuple_index" , None )
882+ if param_group is not None and tuple_index is not None :
883+ group_val = getattr (defaults , param_group , None ) if not isinstance (defaults , dict ) else defaults .get (param_group )
884+ if isinstance (group_val , (list , tuple )) and int (tuple_index ) < len (group_val ):
885+ return group_val [int (tuple_index )]
886+
887+ # direct
888+ val = getattr (defaults , name , None ) if not isinstance (defaults , dict ) else defaults .get (name )
889+ return val
890+
779891 def _reconstruct_decoding_params (self , sampled_params : Dict [str , Any ]) -> Dict [str , Any ]:
780892 """
781893 Reconstruct decoding function parameters from sampled values.
@@ -1051,45 +1163,65 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
10511163 logger .info ("STARTING PARAMETER TUNING | Output directory: %s" , output_dir )
10521164
10531165 # Step 1: Run inference on tune dataset
1054- import glob
1055-
10561166 from connectomics .data .io import read_volume
10571167 from connectomics .training .lightning import create_datamodule
10581168
10591169 logger .info ("[1/4] Running inference on tuning dataset..." )
10601170
10611171 tune_data = cfg .data
1172+ cache_suffix = "_tta_prediction.h5"
10621173
1063- # Create datamodule with tune mode using merged runtime cfg.data/cfg.inference.
1064- datamodule = create_datamodule (cfg , mode = "tune" )
1174+ output_pred_dir = cfg .inference .save_prediction .output_path
1175+ predictions_dir = Path (output_pred_dir )
1176+ pred_files , expected_pred_files = _resolve_tuning_prediction_files (cfg , predictions_dir , cache_suffix )
10651177
1066- logger .info ("Using intermediate-only cache generation (decoding/evaluation disabled)" )
1178+ if len (pred_files ) == len (expected_pred_files ):
1179+ logger .info (
1180+ "Found %d existing tuning prediction file(s) for the current tune dataset in %s "
1181+ "— skipping inference." ,
1182+ len (pred_files ),
1183+ predictions_dir ,
1184+ )
1185+ else :
1186+ if pred_files :
1187+ logger .info (
1188+ "Found %d/%d matching tuning prediction file(s); rerunning inference for missing "
1189+ "volumes instead of mixing partial caches." ,
1190+ len (pred_files ),
1191+ len (expected_pred_files ),
1192+ )
1193+ else :
1194+ logger .info ("No matching tuning prediction files found in %s." , predictions_dir )
1195+
1196+ # Create datamodule with tune mode using merged runtime cfg.data/cfg.inference.
1197+ datamodule = create_datamodule (cfg , mode = "tune" )
10671198
1068- # Run test to populate/load raw prediction caches only. Optuna applies its own
1069- # decoding sweep afterward, so the tune inference pass must not decode with the
1070- # default config first.
1071- with _temporary_tuning_inference_overrides (cfg , getattr (model , "cfg" , None )) as cache_suffix :
1072- model ._tune_mode = True
1073- try :
1074- results = trainer .test (model , datamodule = datamodule , ckpt_path = checkpoint_path )
1075- finally :
1076- model ._tune_mode = False
1199+ logger .info ("Using intermediate-only cache generation (decoding/evaluation disabled)" )
10771200
1078- logger .info ("Test completed. Results: %s" , results )
1201+ # Run test to populate/load raw prediction caches only. Optuna applies its own
1202+ # decoding sweep afterward, so the tune inference pass must not decode with the
1203+ # default config first.
1204+ with _temporary_tuning_inference_overrides (cfg , getattr (model , "cfg" , None )) as cache_suffix :
1205+ model ._tune_mode = True
1206+ try :
1207+ results = trainer .test (model , datamodule = datamodule , ckpt_path = checkpoint_path )
1208+ finally :
1209+ model ._tune_mode = False
1210+
1211+ logger .info ("Test completed. Results: %s" , results )
1212+ pred_files , expected_pred_files = _resolve_tuning_prediction_files (
1213+ cfg , predictions_dir , cache_suffix
1214+ )
10791215
10801216 # Step 2: Load predictions from saved files
10811217 logger .info ("[2/4] Loading predictions from saved files..." )
1082- output_pred_dir = cfg .inference .save_prediction .output_path
1083- predictions_dir = Path (output_pred_dir )
1084-
1085- # Find all prediction files using cache_suffix from config
1086- pred_pattern = f"*{ cache_suffix } "
1087- pred_files = sorted (glob .glob (str (predictions_dir / pred_pattern )))
10881218
1089- if not pred_files :
1219+ if len (pred_files ) != len (expected_pred_files ):
1220+ missing = sorted (set (expected_pred_files ) - set (pred_files ))
10901221 raise FileNotFoundError (
1091- f"No prediction files found in: { predictions_dir } \n "
1092- f"Expected files matching pattern: { pred_pattern } "
1222+ "Missing tuning prediction files for the current tune dataset.\n "
1223+ f"Found: { len (pred_files )} /{ len (expected_pred_files )} in { predictions_dir } \n "
1224+ f"Missing: { missing } "
10931225 )
10941226
10951227 logger .info ("Found %d prediction file(s)" , len (pred_files ))
@@ -1116,14 +1248,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
11161248 raise ValueError ("Missing data.val.label in configuration" )
11171249
11181250 # Handle both string patterns and pre-resolved lists
1119- if isinstance (tune_label_pattern , list ):
1120- # Already resolved to list of files
1121- label_files = sorted (tune_label_pattern )
1122- elif isinstance (tune_label_pattern , str ):
1123- # Glob pattern - expand it
1124- label_files = sorted (glob .glob (tune_label_pattern ))
1125- else :
1126- raise TypeError (f"data.val.label must be string or list, got { type (tune_label_pattern )} " )
1251+ label_files = _expand_tuning_paths (tune_label_pattern , field_name = "data.val.label" )
11271252
11281253 if not label_files :
11291254 raise FileNotFoundError (f"No label files found matching pattern: { tune_label_pattern } " )
@@ -1148,13 +1273,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
11481273 all_masks = None
11491274 tune_mask_pattern = getattr (getattr (tune_data , "val" , None ), "mask" , None )
11501275 if tune_mask_pattern :
1151- # Handle both string patterns and pre-resolved lists
1152- if isinstance (tune_mask_pattern , list ):
1153- mask_files = sorted (tune_mask_pattern )
1154- elif isinstance (tune_mask_pattern , str ):
1155- mask_files = sorted (glob .glob (tune_mask_pattern ))
1156- else :
1157- raise TypeError (f"data.val.mask must be string or list, got { type (tune_mask_pattern )} " )
1276+ mask_files = _expand_tuning_paths (tune_mask_pattern , field_name = "data.val.mask" )
11581277
11591278 if not mask_files :
11601279 logger .warning ("No mask files found matching pattern: %s" , tune_mask_pattern )
0 commit comments