Skip to content

Commit 43a8722

Browse files
author
Donglai Wei
committed
Fix stale prediction cache reuse in tune mode
1 parent 304c1ac commit 43a8722

2 files changed

Lines changed: 247 additions & 71 deletions

File tree

connectomics/decoding/tuning/optuna_tuner.py

Lines changed: 158 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,47 @@ def _maybe_crop_affinity_array(
7373
return crop_spatial_by_pad(data, crop_pad, item_name="tuning array")
7474

7575

76+
def _expand_tuning_paths(path_or_pattern: Any, *, field_name: str) -> list[str]:
77+
"""Expand string/list path inputs used by the tuning loader."""
78+
import glob
79+
80+
if path_or_pattern is None:
81+
return []
82+
83+
if isinstance(path_or_pattern, (str, Path)):
84+
pattern = str(path_or_pattern)
85+
if "*" in pattern or "?" in pattern:
86+
return sorted(glob.glob(pattern))
87+
return [pattern]
88+
89+
if isinstance(path_or_pattern, list):
90+
expanded: list[str] = []
91+
for entry in path_or_pattern:
92+
expanded.extend(_expand_tuning_paths(entry, field_name=field_name))
93+
return expanded
94+
95+
raise TypeError(f"{field_name} must be string or list, got {type(path_or_pattern)}")
96+
97+
98+
def _resolve_tuning_prediction_files(
    cfg,
    predictions_dir: Path,
    cache_suffix: str,
) -> tuple[list[str], list[str]]:
    """Resolve cached prediction files for the current tune dataset only.

    Returns ``(existing, expected)``: ``expected`` is every cache path derived
    from ``data.val.image`` (one per tune image, ``<stem><cache_suffix>`` under
    ``predictions_dir``), and ``existing`` is the subset already on disk.
    """
    val_cfg = getattr(cfg.data, "val", None)
    image_pattern = getattr(val_cfg, "image", None)
    if image_pattern is None:
        raise ValueError("Missing data.val.image in configuration")

    images = _expand_tuning_paths(image_pattern, field_name="data.val.image")
    if not images:
        raise FileNotFoundError(f"No image files found matching pattern: {image_pattern}")

    expected: list[str] = []
    existing: list[str] = []
    for image_path in images:
        cache_path = predictions_dir / f"{Path(str(image_path)).stem}{cache_suffix}"
        expected.append(str(cache_path))
        if cache_path.exists():
            existing.append(str(cache_path))
    return existing, expected
115+
116+
76117
@contextmanager
77118
def _temporary_tuning_inference_overrides(*cfg_objects: Any):
78119
"""Force the pre-Optuna inference pass to cache raw predictions only."""
@@ -380,6 +421,13 @@ def optimize(self) -> optuna.Study:
380421
direction=direction,
381422
)
382423

424+
# Seed the first trial with known-good defaults so TPE has a strong
425+
# baseline from the start instead of wasting early trials on random configs.
426+
default_params = self._build_default_trial_params()
427+
if default_params:
428+
study.enqueue_trial(default_params)
429+
logger.info("Seeded first trial with default parameters: %s", default_params)
430+
383431
# Run optimization
384432
n_trials = self.tune_cfg.n_trials
385433
timeout = self.tune_cfg.timeout
@@ -776,6 +824,70 @@ def _sample_parameters(self, trial: optuna.Trial) -> Dict[str, Any]:
776824

777825
return params
778826

827+
def _build_default_trial_params(self) -> Optional[Dict[str, Any]]:
    """Build a param dict from config defaults to seed the first Optuna trial.

    Maps ``parameter_space.decoding.defaults`` (and postprocessing defaults)
    back to the flat Optuna parameter names used by ``_suggest_param``.
    Returns ``None`` when no defaults could be resolved.
    """
    seeded: Dict[str, Any] = {}

    def _collect(defaults: Any, param_defs: Any, *, skip_ws_merge: bool) -> None:
        # Shared walk over one (defaults, parameters) pair from the config.
        if not (defaults and param_defs):
            return
        for pname, pdef in param_defs.items():
            if skip_ws_merge and pname == "ws_merge_threshold":
                continue
            resolved = self._lookup_default(defaults, pname, pdef)
            if resolved is not None:
                seeded[pname] = resolved

    # --- decoding defaults ---
    decoding_cfg = getattr(self.param_space_cfg, "decoding", None)
    _collect(
        getattr(decoding_cfg, "defaults", None) if decoding_cfg else None,
        getattr(decoding_cfg, "parameters", None) if decoding_cfg else None,
        # ws_merge_threshold is handled by the ABISS batch path, not Optuna.
        skip_ws_merge=self._abiss_batch_enabled,
    )

    # --- postprocessing defaults ---
    postproc_cfg = getattr(self.param_space_cfg, "postprocessing", None)
    if postproc_cfg and getattr(postproc_cfg, "enabled", False):
        _collect(
            getattr(postproc_cfg, "defaults", None),
            getattr(postproc_cfg, "parameters", None),
            skip_ws_merge=False,
        )

    return seeded or None
860+
861+
@staticmethod
862+
def _lookup_default(defaults: Any, name: str, pcfg: Any) -> Any:
863+
"""Resolve a single default value from the defaults block.
864+
865+
Handles three layouts:
866+
- ``nest_under``: ``defaults.<nest_under>.<name>``
867+
- ``param_group`` + ``tuple_index``: ``defaults.<param_group>[tuple_index]``
868+
- direct: ``defaults.<name>``
869+
"""
870+
# nested (e.g. cli_args.ws_high_threshold)
871+
nest_under = pcfg.get("nest_under", None) if hasattr(pcfg, "get") else getattr(pcfg, "nest_under", None)
872+
if nest_under:
873+
nested = getattr(defaults, nest_under, None) if not isinstance(defaults, dict) else defaults.get(nest_under)
874+
if nested is not None:
875+
val = nested.get(name, None) if isinstance(nested, dict) else getattr(nested, name, None)
876+
if val is not None:
877+
return val
878+
879+
# tuple param (e.g. binary_threshold[0])
880+
param_group = pcfg.get("param_group", None) if hasattr(pcfg, "get") else getattr(pcfg, "param_group", None)
881+
tuple_index = pcfg.get("tuple_index", None) if hasattr(pcfg, "get") else getattr(pcfg, "tuple_index", None)
882+
if param_group is not None and tuple_index is not None:
883+
group_val = getattr(defaults, param_group, None) if not isinstance(defaults, dict) else defaults.get(param_group)
884+
if isinstance(group_val, (list, tuple)) and int(tuple_index) < len(group_val):
885+
return group_val[int(tuple_index)]
886+
887+
# direct
888+
val = getattr(defaults, name, None) if not isinstance(defaults, dict) else defaults.get(name)
889+
return val
890+
779891
def _reconstruct_decoding_params(self, sampled_params: Dict[str, Any]) -> Dict[str, Any]:
780892
"""
781893
Reconstruct decoding function parameters from sampled values.
@@ -1051,45 +1163,65 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
10511163
logger.info("STARTING PARAMETER TUNING | Output directory: %s", output_dir)
10521164

10531165
# Step 1: Run inference on tune dataset
1054-
import glob
1055-
10561166
from connectomics.data.io import read_volume
10571167
from connectomics.training.lightning import create_datamodule
10581168

10591169
logger.info("[1/4] Running inference on tuning dataset...")
10601170

10611171
tune_data = cfg.data
1172+
cache_suffix = "_tta_prediction.h5"
10621173

1063-
# Create datamodule with tune mode using merged runtime cfg.data/cfg.inference.
1064-
datamodule = create_datamodule(cfg, mode="tune")
1174+
output_pred_dir = cfg.inference.save_prediction.output_path
1175+
predictions_dir = Path(output_pred_dir)
1176+
pred_files, expected_pred_files = _resolve_tuning_prediction_files(cfg, predictions_dir, cache_suffix)
10651177

1066-
logger.info("Using intermediate-only cache generation (decoding/evaluation disabled)")
1178+
if len(pred_files) == len(expected_pred_files):
1179+
logger.info(
1180+
"Found %d existing tuning prediction file(s) for the current tune dataset in %s "
1181+
"— skipping inference.",
1182+
len(pred_files),
1183+
predictions_dir,
1184+
)
1185+
else:
1186+
if pred_files:
1187+
logger.info(
1188+
"Found %d/%d matching tuning prediction file(s); rerunning inference for missing "
1189+
"volumes instead of mixing partial caches.",
1190+
len(pred_files),
1191+
len(expected_pred_files),
1192+
)
1193+
else:
1194+
logger.info("No matching tuning prediction files found in %s.", predictions_dir)
1195+
1196+
# Create datamodule with tune mode using merged runtime cfg.data/cfg.inference.
1197+
datamodule = create_datamodule(cfg, mode="tune")
10671198

1068-
# Run test to populate/load raw prediction caches only. Optuna applies its own
1069-
# decoding sweep afterward, so the tune inference pass must not decode with the
1070-
# default config first.
1071-
with _temporary_tuning_inference_overrides(cfg, getattr(model, "cfg", None)) as cache_suffix:
1072-
model._tune_mode = True
1073-
try:
1074-
results = trainer.test(model, datamodule=datamodule, ckpt_path=checkpoint_path)
1075-
finally:
1076-
model._tune_mode = False
1199+
logger.info("Using intermediate-only cache generation (decoding/evaluation disabled)")
10771200

1078-
logger.info("Test completed. Results: %s", results)
1201+
# Run test to populate/load raw prediction caches only. Optuna applies its own
1202+
# decoding sweep afterward, so the tune inference pass must not decode with the
1203+
# default config first.
1204+
with _temporary_tuning_inference_overrides(cfg, getattr(model, "cfg", None)) as cache_suffix:
1205+
model._tune_mode = True
1206+
try:
1207+
results = trainer.test(model, datamodule=datamodule, ckpt_path=checkpoint_path)
1208+
finally:
1209+
model._tune_mode = False
1210+
1211+
logger.info("Test completed. Results: %s", results)
1212+
pred_files, expected_pred_files = _resolve_tuning_prediction_files(
1213+
cfg, predictions_dir, cache_suffix
1214+
)
10791215

10801216
# Step 2: Load predictions from saved files
10811217
logger.info("[2/4] Loading predictions from saved files...")
1082-
output_pred_dir = cfg.inference.save_prediction.output_path
1083-
predictions_dir = Path(output_pred_dir)
1084-
1085-
# Find all prediction files using cache_suffix from config
1086-
pred_pattern = f"*{cache_suffix}"
1087-
pred_files = sorted(glob.glob(str(predictions_dir / pred_pattern)))
10881218

1089-
if not pred_files:
1219+
if len(pred_files) != len(expected_pred_files):
1220+
missing = sorted(set(expected_pred_files) - set(pred_files))
10901221
raise FileNotFoundError(
1091-
f"No prediction files found in: {predictions_dir}\n"
1092-
f"Expected files matching pattern: {pred_pattern}"
1222+
"Missing tuning prediction files for the current tune dataset.\n"
1223+
f"Found: {len(pred_files)}/{len(expected_pred_files)} in {predictions_dir}\n"
1224+
f"Missing: {missing}"
10931225
)
10941226

10951227
logger.info("Found %d prediction file(s)", len(pred_files))
@@ -1116,14 +1248,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
11161248
raise ValueError("Missing data.val.label in configuration")
11171249

11181250
# Handle both string patterns and pre-resolved lists
1119-
if isinstance(tune_label_pattern, list):
1120-
# Already resolved to list of files
1121-
label_files = sorted(tune_label_pattern)
1122-
elif isinstance(tune_label_pattern, str):
1123-
# Glob pattern - expand it
1124-
label_files = sorted(glob.glob(tune_label_pattern))
1125-
else:
1126-
raise TypeError(f"data.val.label must be string or list, got {type(tune_label_pattern)}")
1251+
label_files = _expand_tuning_paths(tune_label_pattern, field_name="data.val.label")
11271252

11281253
if not label_files:
11291254
raise FileNotFoundError(f"No label files found matching pattern: {tune_label_pattern}")
@@ -1148,13 +1273,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
11481273
all_masks = None
11491274
tune_mask_pattern = getattr(getattr(tune_data, "val", None), "mask", None)
11501275
if tune_mask_pattern:
1151-
# Handle both string patterns and pre-resolved lists
1152-
if isinstance(tune_mask_pattern, list):
1153-
mask_files = sorted(tune_mask_pattern)
1154-
elif isinstance(tune_mask_pattern, str):
1155-
mask_files = sorted(glob.glob(tune_mask_pattern))
1156-
else:
1157-
raise TypeError(f"data.val.mask must be string or list, got {type(tune_mask_pattern)}")
1276+
mask_files = _expand_tuning_paths(tune_mask_pattern, field_name="data.val.mask")
11581277

11591278
if not mask_files:
11601279
logger.warning("No mask files found matching pattern: %s", tune_mask_pattern)

0 commit comments

Comments
 (0)