Skip to content

Commit 69bfa31

Browse files
author
Donglai Wei
committed
Improve prediction naming for experiment logging
1 parent 7344a5e commit 69bfa31

9 files changed

Lines changed: 626 additions & 81 deletions

File tree

connectomics/decoding/tuning/optuna_tuner.py

Lines changed: 35 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,27 @@ def _resolve_tuning_prediction_files(
8989
return existing_files, [str(path) for path in expected_files]
9090

9191

92+
def _print_best_params_yaml(best_params_file: Path) -> None:
93+
"""Print the current best-params YAML to stdout for interactive tune runs."""
94+
try:
95+
best_params = OmegaConf.load(best_params_file)
96+
yaml_text = OmegaConf.to_yaml(best_params).rstrip()
97+
except Exception:
98+
yaml_text = best_params_file.read_text().rstrip()
99+
100+
print("\n" + "=" * 80)
101+
print(f"BEST PARAMETERS | {best_params_file}")
102+
print("=" * 80)
103+
if yaml_text:
104+
print(yaml_text)
105+
else:
106+
print("[empty]")
107+
108+
92109
@contextmanager
93-
def _temporary_tuning_inference_overrides(*cfg_objects: Any):
110+
def _temporary_tuning_inference_overrides(
111+
*cfg_objects: Any, checkpoint_path: str | None = None
112+
):
94113
"""Force the pre-Optuna inference pass to cache raw predictions only."""
95114
inference_cfgs = []
96115
seen_inference_cfgs: set[int] = set()
@@ -109,7 +128,11 @@ def _temporary_tuning_inference_overrides(*cfg_objects: Any):
109128
if not inference_cfgs:
110129
raise ValueError("Missing runtime cfg.inference configuration required for tuning")
111130

112-
suffix = tta_cache_suffix(primary_cfg) if primary_cfg is not None else "_tta_x1_prediction.h5"
131+
suffix = (
132+
tta_cache_suffix(primary_cfg, checkpoint_path=checkpoint_path)
133+
if primary_cfg is not None
134+
else "_tta_x1_prediction.h5"
135+
)
113136

114137
backups = []
115138
for inference_cfg in inference_cfgs:
@@ -1286,6 +1309,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
12861309
"Delete this file to re-run tuning.",
12871310
best_params_file,
12881311
)
1312+
_print_best_params_yaml(best_params_file)
12891313
return
12901314

12911315
logger.info("STARTING PARAMETER TUNING | Output directory: %s", output_dir)
@@ -1297,7 +1321,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
12971321
logger.info("[1/4] Running inference on tuning dataset...")
12981322

12991323
tune_data = cfg.data
1300-
cache_suffix = tta_cache_suffix(cfg)
1324+
prediction_checkpoint_path = getattr(model, "_prediction_checkpoint_path", None) or checkpoint_path
1325+
cache_suffix = tta_cache_suffix(cfg, checkpoint_path=prediction_checkpoint_path)
13011326

13021327
output_pred_dir = cfg.inference.save_prediction.output_path
13031328
predictions_dir = Path(output_pred_dir)
@@ -1329,7 +1354,11 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
13291354
# Run test to populate/load raw prediction caches only. Optuna applies its own
13301355
# decoding sweep afterward, so the tune inference pass must not decode with the
13311356
# default config first.
1332-
with _temporary_tuning_inference_overrides(cfg, getattr(model, "cfg", None)) as cache_suffix:
1357+
with _temporary_tuning_inference_overrides(
1358+
cfg,
1359+
getattr(model, "cfg", None),
1360+
checkpoint_path=prediction_checkpoint_path,
1361+
) as cache_suffix:
13331362
model._tune_mode = True
13341363
try:
13351364
results = trainer.test(model, datamodule=datamodule, ckpt_path=checkpoint_path)
@@ -1456,6 +1485,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
14561485
study.best_value,
14571486
study.best_params,
14581487
)
1488+
if best_params_file.exists():
1489+
_print_best_params_yaml(best_params_file)
14591490

14601491

14611492
def load_and_apply_best_params(cfg):

connectomics/training/lightning/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,14 @@
1919
from .utils import (
2020
compute_tta_passes,
2121
extract_best_score_from_checkpoint,
22+
final_prediction_output_tag,
2223
is_tta_cache_suffix,
2324
parse_args,
2425
resolve_prediction_cache_suffix,
2526
setup_config,
2627
setup_seed_everything,
2728
tta_cache_suffix,
29+
tta_cache_suffix_candidates,
2830
)
2931

3032
__all__ = [
@@ -46,7 +48,9 @@
4648
"setup_config",
4749
"extract_best_score_from_checkpoint",
4850
"compute_tta_passes",
51+
"final_prediction_output_tag",
4952
"tta_cache_suffix",
53+
"tta_cache_suffix_candidates",
5054
"resolve_prediction_cache_suffix",
5155
"is_tta_cache_suffix",
5256
]

connectomics/training/lightning/model.py

Lines changed: 97 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from __future__ import annotations
1717

1818
import logging
19+
import os
1920
import warnings
2021
from pathlib import Path
2122
from typing import Any, Dict, List, Optional, Union
@@ -44,7 +45,14 @@
4445
from ...models import build_model
4546
from ...models.loss import create_loss, get_loss_metadata_for_module
4647
from ..debugging import DebugManager
47-
from .utils import is_tta_cache_suffix, resolve_prediction_cache_suffix, tta_cache_suffix
48+
from .utils import (
49+
final_prediction_output_tag,
50+
format_checkpoint_name_tag,
51+
is_tta_cache_suffix,
52+
resolve_prediction_cache_suffix,
53+
tta_cache_suffix,
54+
tta_cache_suffix_candidates,
55+
)
4856

4957
# Import training/inference components
5058
from ..loss import LossOrchestrator, build_loss_weighter, infer_num_loss_tasks_from_config
@@ -411,11 +419,36 @@ def _resolve_test_output_config(
411419
mode = "test"
412420
save_pred_cfg = self._get_runtime_inference_config().save_prediction
413421
output_dir_value = getattr(save_pred_cfg, "output_path", None)
414-
cache_suffix = resolve_prediction_cache_suffix(self.cfg, mode=mode)
422+
cache_suffix = resolve_prediction_cache_suffix(
423+
self.cfg,
424+
mode=mode,
425+
checkpoint_path=self._get_prediction_checkpoint_path(),
426+
)
415427

416428
filenames = resolve_output_filenames(self.cfg, batch, global_step=self.global_step)
417429
return mode, output_dir_value, cache_suffix, filenames
418430

431+
def _get_prediction_checkpoint_path(self) -> str:
432+
"""Return the checkpoint/weights path whose stem should tag prediction caches."""
433+
explicit_path = getattr(self, "_prediction_checkpoint_path", None)
434+
if explicit_path is not None:
435+
path_value = str(explicit_path).strip()
436+
if path_value:
437+
return path_value
438+
439+
trainer = getattr(self, "_trainer", None)
440+
trainer_ckpt_path = getattr(trainer, "ckpt_path", None) if trainer is not None else None
441+
if trainer_ckpt_path is not None:
442+
path_value = str(trainer_ckpt_path).strip()
443+
if path_value:
444+
return path_value
445+
446+
external_weights_path = getattr(getattr(self.cfg, "model", None), "external_weights_path", None)
447+
if isinstance(external_weights_path, str) and external_weights_path.strip():
448+
return external_weights_path.strip()
449+
450+
return ""
451+
419452
def _resolve_tta_result_path_override(self) -> str:
420453
"""Return explicit intermediate prediction file from inference.tta_result_path."""
421454
inference_cfg = self._get_runtime_inference_config()
@@ -434,7 +467,7 @@ def _load_cached_predictions(
434467
if not pred_file.is_absolute():
435468
pred_file = Path.cwd() / pred_file
436469

437-
if pred_file.exists():
470+
if os.path.exists(pred_file):
438471
try:
439472
logger.info(f"Using explicit inference.tta_result_path file: {pred_file}")
440473
pred = read_volume(str(pred_file), dataset="main")
@@ -446,7 +479,14 @@ def _load_cached_predictions(
446479
f"{len(filenames)} filenames; decoding will use the explicit file only."
447480
)
448481
# Treat explicit file as intermediate prediction so decoding still runs.
449-
return pred, True, tta_cache_suffix(self.cfg)
482+
return (
483+
pred,
484+
True,
485+
tta_cache_suffix(
486+
self.cfg,
487+
checkpoint_path=self._get_prediction_checkpoint_path(),
488+
),
489+
)
450490
except Exception as e:
451491
logger.warning(
452492
f"Failed to load explicit inference.tta_result_path file {pred_file}: {e}. "
@@ -462,14 +502,22 @@ def _load_cached_predictions(
462502
return None, False, cache_suffix
463503

464504
output_dir = Path(output_dir_value)
505+
checkpoint_tag = format_checkpoint_name_tag(self._get_prediction_checkpoint_path())
465506

466507
# Build ordered list of suffixes to try: final prediction first, then
467508
# intermediate TTA, then glob fallback.
468509
suffixes_to_try: list[str] = []
469510
if is_tta_cache_suffix(cache_suffix):
470511
# Prefer the final decoded file (e.g. _x16_prediction.h5) over
471512
# the intermediate TTA file (e.g. _tta_x16_prediction.h5).
472-
final_suffix = cache_suffix.replace("_tta_x", "_x")
513+
final_suffix = (
514+
"_"
515+
+ final_prediction_output_tag(
516+
self.cfg,
517+
checkpoint_path=self._get_prediction_checkpoint_path(),
518+
)
519+
+ ".h5"
520+
)
473521
suffixes_to_try.append(final_suffix)
474522
suffixes_to_try.append(cache_suffix)
475523

@@ -478,7 +526,7 @@ def _load_cached_predictions(
478526
all_exist = True
479527
for filename in filenames:
480528
pred_file = output_dir / f"{filename}{try_suffix}"
481-
if pred_file.exists():
529+
if os.path.exists(pred_file):
482530
try:
483531
pred = read_volume(str(pred_file), dataset="main")
484532
existing_predictions.append(pred)
@@ -507,21 +555,43 @@ def _load_cached_predictions(
507555
)
508556
return predictions_np, True, try_suffix
509557

510-
# Glob fallback: look for any TTA intermediate file.
558+
# Targeted fallback: look for the exact TTA intermediate cache suffix
559+
# matching the current config rather than any arbitrary TTA file.
511560
if mode == "test" and not is_tta_cache_suffix(cache_suffix):
512-
for filename in filenames:
513-
tta_matches = sorted(output_dir.glob(f"{filename}_tta_x*_prediction.h5"))
514-
if tta_matches:
515-
pred_file = tta_matches[-1]
516-
loaded_suffix = pred_file.name[len(filename):]
561+
fallback_suffixes = tta_cache_suffix_candidates(
562+
self.cfg,
563+
checkpoint_path=self._get_prediction_checkpoint_path(),
564+
)
565+
for try_suffix in fallback_suffixes:
566+
existing_predictions = []
567+
all_exist = True
568+
for filename in filenames:
569+
pred_file = output_dir / f"{filename}{try_suffix}"
570+
if not os.path.exists(pred_file):
571+
all_exist = False
572+
break
517573
try:
518574
pred = read_volume(str(pred_file), dataset="main")
519-
if pred.ndim < 4:
520-
pred = pred[np.newaxis, ...]
521-
logger.info("Loaded fallback TTA prediction: %s", pred_file.name)
522-
return pred, True, loaded_suffix
575+
existing_predictions.append(pred)
523576
except Exception as e:
524577
logger.warning(f"Failed to load {pred_file}: {e}")
578+
all_exist = False
579+
break
580+
if all_exist and len(existing_predictions) == len(filenames):
581+
logger.info(
582+
"Loaded fallback TTA prediction(s) using exact suffix %s",
583+
try_suffix,
584+
)
585+
if len(existing_predictions) == 1:
586+
predictions_np = existing_predictions[0]
587+
if predictions_np.ndim < 4:
588+
predictions_np = predictions_np[np.newaxis, ...]
589+
else:
590+
predictions_np = np.stack(
591+
[p[np.newaxis, ...] if p.ndim < 4 else p for p in existing_predictions],
592+
axis=0,
593+
)
594+
return predictions_np, True, try_suffix
525595

526596
return None, False, cache_suffix
527597

@@ -554,9 +624,10 @@ def _save_metrics_to_file(self, metrics_dict: Dict[str, Any]):
554624
# Create filename with volume name and TTA pass tag
555625
volume_name = metrics_dict.get("volume_name", "unknown")
556626
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
557-
cache_suffix = resolve_prediction_cache_suffix(self.cfg, mode=mode)
558-
# Extract tag like "tta_x16" or "x1" from suffix "_tta_x16_prediction.h5"
559-
tag = cache_suffix.lstrip("_").replace("_prediction.h5", "")
627+
tag = final_prediction_output_tag(
628+
self.cfg,
629+
checkpoint_path=self._get_prediction_checkpoint_path(),
630+
)
560631
metrics_file = output_dir / f"evaluation_metrics_{volume_name}_{tag}.txt"
561632

562633
# Write metrics to file
@@ -663,11 +734,15 @@ def _log_decode_experiment(
663734
decode_params["decoder"] = step.name
664735
decode_params.update(step.kwargs)
665736

737+
input_tta_prediction_name = (
738+
f"{volume_name}{tta_cache_suffix(self.cfg, checkpoint_path=self._get_prediction_checkpoint_path())}"
739+
)
740+
666741
# Columns: timestamp, volume, decoder params..., metrics...
667742
# Use a fixed column order for readability
668743
param_keys = [
669744
"decoder", "thresholds", "merge_function", "aff_threshold",
670-
"channel_order", "dust_merge_size", "dust_merge_affinity",
745+
"channel_order", "dust_merge", "dust_merge_size", "dust_merge_affinity",
671746
"dust_remove_size",
672747
]
673748
metric_keys = [
@@ -676,8 +751,8 @@ def _log_decode_experiment(
676751
"instance_f1_detail",
677752
]
678753

679-
header_cols = ["timestamp", "volume"] + param_keys + metric_keys
680-
row_vals = [timestamp, volume_name]
754+
header_cols = ["timestamp", "volume", "input_tta_prediction_name"] + param_keys + metric_keys
755+
row_vals = [timestamp, volume_name, input_tta_prediction_name]
681756
for k in param_keys:
682757
row_vals.append(str(decode_params.get(k, "")))
683758
for k in metric_keys:

connectomics/training/lightning/test_pipeline.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,14 @@
2828
)
2929
from ...metrics.metrics_seg import AdaptedRandError
3030
from ...metrics.segmentation_numpy import instance_matching, instance_matching_simple, voi
31-
from .utils import compute_tta_passes, is_tta_cache_suffix
31+
from .utils import (
32+
compute_tta_passes,
33+
final_prediction_output_tag,
34+
format_checkpoint_name_tag,
35+
format_decode_tag,
36+
format_select_channel_tag,
37+
is_tta_cache_suffix,
38+
)
3239

3340
logger = logging.getLogger(__name__)
3441

@@ -790,12 +797,14 @@ def _process_decoding_postprocessing(
790797
if save_final_predictions:
791798
logger.info("[STAGE: Saving Final Predictions]")
792799
save_start = time.time()
793-
final_tta_passes = compute_tta_passes(module.cfg)
794800
write_outputs(
795801
module.cfg,
796802
postprocessed_predictions,
797803
filenames,
798-
suffix=f"x{final_tta_passes}_prediction",
804+
suffix=final_prediction_output_tag(
805+
module.cfg,
806+
checkpoint_path=module._get_prediction_checkpoint_path(),
807+
),
799808
mode=mode,
800809
batch_meta=batch_meta,
801810
)
@@ -1058,12 +1067,14 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
10581067
logger.info("[STAGE: Saving Intermediate Predictions]")
10591068
save_start = time.time()
10601069
tta_passes = compute_tta_passes(module.cfg)
1070+
ch_tag = format_select_channel_tag(module.cfg)
1071+
checkpoint_tag = format_checkpoint_name_tag(module._get_prediction_checkpoint_path())
10611072
predictions_to_save = apply_save_prediction_transform(module.cfg, predictions_np)
10621073
write_outputs(
10631074
module.cfg,
10641075
predictions_to_save,
10651076
filenames,
1066-
suffix=f"tta_x{tta_passes}_prediction",
1077+
suffix=f"tta_x{tta_passes}{ch_tag}{checkpoint_tag}_prediction",
10671078
mode=mode,
10681079
batch_meta=batch.get("image_meta_dict"),
10691080
)

0 commit comments

Comments
 (0)