1616from __future__ import annotations
1717
1818import logging
19+ import os
1920import warnings
2021from pathlib import Path
2122from typing import Any , Dict , List , Optional , Union
4445from ...models import build_model
4546from ...models .loss import create_loss , get_loss_metadata_for_module
4647from ..debugging import DebugManager
47- from .utils import is_tta_cache_suffix , resolve_prediction_cache_suffix , tta_cache_suffix
48+ from .utils import (
49+ final_prediction_output_tag ,
50+ format_checkpoint_name_tag ,
51+ is_tta_cache_suffix ,
52+ resolve_prediction_cache_suffix ,
53+ tta_cache_suffix ,
54+ tta_cache_suffix_candidates ,
55+ )
4856
4957# Import training/inference components
5058from ..loss import LossOrchestrator , build_loss_weighter , infer_num_loss_tasks_from_config
@@ -411,11 +419,36 @@ def _resolve_test_output_config(
411419 mode = "test"
412420 save_pred_cfg = self ._get_runtime_inference_config ().save_prediction
413421 output_dir_value = getattr (save_pred_cfg , "output_path" , None )
414- cache_suffix = resolve_prediction_cache_suffix (self .cfg , mode = mode )
422+ cache_suffix = resolve_prediction_cache_suffix (
423+ self .cfg ,
424+ mode = mode ,
425+ checkpoint_path = self ._get_prediction_checkpoint_path (),
426+ )
415427
416428 filenames = resolve_output_filenames (self .cfg , batch , global_step = self .global_step )
417429 return mode , output_dir_value , cache_suffix , filenames
418430
431+ def _get_prediction_checkpoint_path (self ) -> str :
432+ """Return the checkpoint/weights path whose stem should tag prediction caches."""
433+ explicit_path = getattr (self , "_prediction_checkpoint_path" , None )
434+ if explicit_path is not None :
435+ path_value = str (explicit_path ).strip ()
436+ if path_value :
437+ return path_value
438+
439+ trainer = getattr (self , "_trainer" , None )
440+ trainer_ckpt_path = getattr (trainer , "ckpt_path" , None ) if trainer is not None else None
441+ if trainer_ckpt_path is not None :
442+ path_value = str (trainer_ckpt_path ).strip ()
443+ if path_value :
444+ return path_value
445+
446+ external_weights_path = getattr (getattr (self .cfg , "model" , None ), "external_weights_path" , None )
447+ if isinstance (external_weights_path , str ) and external_weights_path .strip ():
448+ return external_weights_path .strip ()
449+
450+ return ""
451+
419452 def _resolve_tta_result_path_override (self ) -> str :
420453 """Return explicit intermediate prediction file from inference.tta_result_path."""
421454 inference_cfg = self ._get_runtime_inference_config ()
@@ -434,7 +467,7 @@ def _load_cached_predictions(
434467 if not pred_file .is_absolute ():
435468 pred_file = Path .cwd () / pred_file
436469
437- if pred_file . exists ():
470+ if os . path . exists (pred_file ):
438471 try :
439472 logger .info (f"Using explicit inference.tta_result_path file: { pred_file } " )
440473 pred = read_volume (str (pred_file ), dataset = "main" )
@@ -446,7 +479,14 @@ def _load_cached_predictions(
446479 f"{ len (filenames )} filenames; decoding will use the explicit file only."
447480 )
448481 # Treat explicit file as intermediate prediction so decoding still runs.
449- return pred , True , tta_cache_suffix (self .cfg )
482+ return (
483+ pred ,
484+ True ,
485+ tta_cache_suffix (
486+ self .cfg ,
487+ checkpoint_path = self ._get_prediction_checkpoint_path (),
488+ ),
489+ )
450490 except Exception as e :
451491 logger .warning (
452492 f"Failed to load explicit inference.tta_result_path file { pred_file } : { e } . "
@@ -462,14 +502,22 @@ def _load_cached_predictions(
462502 return None , False , cache_suffix
463503
464504 output_dir = Path (output_dir_value )
505+ checkpoint_tag = format_checkpoint_name_tag (self ._get_prediction_checkpoint_path ())
465506
466507 # Build ordered list of suffixes to try: final prediction first, then
467508 # intermediate TTA, then glob fallback.
468509 suffixes_to_try : list [str ] = []
469510 if is_tta_cache_suffix (cache_suffix ):
470511 # Prefer the final decoded file (e.g. _x16_prediction.h5) over
471512 # the intermediate TTA file (e.g. _tta_x16_prediction.h5).
472- final_suffix = cache_suffix .replace ("_tta_x" , "_x" )
513+ final_suffix = (
514+ "_"
515+ + final_prediction_output_tag (
516+ self .cfg ,
517+ checkpoint_path = self ._get_prediction_checkpoint_path (),
518+ )
519+ + ".h5"
520+ )
473521 suffixes_to_try .append (final_suffix )
474522 suffixes_to_try .append (cache_suffix )
475523
@@ -478,7 +526,7 @@ def _load_cached_predictions(
478526 all_exist = True
479527 for filename in filenames :
480528 pred_file = output_dir / f"{ filename } { try_suffix } "
481- if pred_file . exists ():
529+ if os . path . exists (pred_file ):
482530 try :
483531 pred = read_volume (str (pred_file ), dataset = "main" )
484532 existing_predictions .append (pred )
@@ -507,21 +555,43 @@ def _load_cached_predictions(
507555 )
508556 return predictions_np , True , try_suffix
509557
510- # Glob fallback: look for any TTA intermediate file.
558+ # Targeted fallback: look for the exact TTA intermediate cache suffix
559+ # matching the current config rather than any arbitrary TTA file.
511560 if mode == "test" and not is_tta_cache_suffix (cache_suffix ):
512- for filename in filenames :
513- tta_matches = sorted (output_dir .glob (f"{ filename } _tta_x*_prediction.h5" ))
514- if tta_matches :
515- pred_file = tta_matches [- 1 ]
516- loaded_suffix = pred_file .name [len (filename ):]
561+ fallback_suffixes = tta_cache_suffix_candidates (
562+ self .cfg ,
563+ checkpoint_path = self ._get_prediction_checkpoint_path (),
564+ )
565+ for try_suffix in fallback_suffixes :
566+ existing_predictions = []
567+ all_exist = True
568+ for filename in filenames :
569+ pred_file = output_dir / f"{ filename } { try_suffix } "
570+ if not os .path .exists (pred_file ):
571+ all_exist = False
572+ break
517573 try :
518574 pred = read_volume (str (pred_file ), dataset = "main" )
519- if pred .ndim < 4 :
520- pred = pred [np .newaxis , ...]
521- logger .info ("Loaded fallback TTA prediction: %s" , pred_file .name )
522- return pred , True , loaded_suffix
575+ existing_predictions .append (pred )
523576 except Exception as e :
524577 logger .warning (f"Failed to load { pred_file } : { e } " )
578+ all_exist = False
579+ break
580+ if all_exist and len (existing_predictions ) == len (filenames ):
581+ logger .info (
582+ "Loaded fallback TTA prediction(s) using exact suffix %s" ,
583+ try_suffix ,
584+ )
585+ if len (existing_predictions ) == 1 :
586+ predictions_np = existing_predictions [0 ]
587+ if predictions_np .ndim < 4 :
588+ predictions_np = predictions_np [np .newaxis , ...]
589+ else :
590+ predictions_np = np .stack (
591+ [p [np .newaxis , ...] if p .ndim < 4 else p for p in existing_predictions ],
592+ axis = 0 ,
593+ )
594+ return predictions_np , True , try_suffix
525595
526596 return None , False , cache_suffix
527597
@@ -554,9 +624,10 @@ def _save_metrics_to_file(self, metrics_dict: Dict[str, Any]):
554624 # Create filename with volume name and TTA pass tag
555625 volume_name = metrics_dict .get ("volume_name" , "unknown" )
556626 timestamp = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
557- cache_suffix = resolve_prediction_cache_suffix (self .cfg , mode = mode )
558- # Extract tag like "tta_x16" or "x1" from suffix "_tta_x16_prediction.h5"
559- tag = cache_suffix .lstrip ("_" ).replace ("_prediction.h5" , "" )
627+ tag = final_prediction_output_tag (
628+ self .cfg ,
629+ checkpoint_path = self ._get_prediction_checkpoint_path (),
630+ )
560631 metrics_file = output_dir / f"evaluation_metrics_{ volume_name } _{ tag } .txt"
561632
562633 # Write metrics to file
@@ -663,11 +734,15 @@ def _log_decode_experiment(
663734 decode_params ["decoder" ] = step .name
664735 decode_params .update (step .kwargs )
665736
737+ input_tta_prediction_name = (
738+ f"{ volume_name } { tta_cache_suffix (self .cfg , checkpoint_path = self ._get_prediction_checkpoint_path ())} "
739+ )
740+
666741 # Columns: timestamp, volume, decoder params..., metrics...
667742 # Use a fixed column order for readability
668743 param_keys = [
669744 "decoder" , "thresholds" , "merge_function" , "aff_threshold" ,
670- "channel_order" , "dust_merge_size" , "dust_merge_affinity" ,
745+ "channel_order" , "dust_merge" , " dust_merge_size" , "dust_merge_affinity" ,
671746 "dust_remove_size" ,
672747 ]
673748 metric_keys = [
@@ -676,8 +751,8 @@ def _log_decode_experiment(
676751 "instance_f1_detail" ,
677752 ]
678753
679- header_cols = ["timestamp" , "volume" ] + param_keys + metric_keys
680- row_vals = [timestamp , volume_name ]
754+ header_cols = ["timestamp" , "volume" , "input_tta_prediction_name" ] + param_keys + metric_keys
755+ row_vals = [timestamp , volume_name , input_tta_prediction_name ]
681756 for k in param_keys :
682757 row_vals .append (str (decode_params .get (k , "" )))
683758 for k in metric_keys :
0 commit comments