Skip to content

Commit 11d8808

Browse files
committed
Sync nnUNet robustness updates to cchmc app from seg_metrics_op
1 parent 1f1292f commit 11d8808

3 files changed

Lines changed: 369 additions & 105 deletions

File tree

examples/apps/cchmc_nnunet_fifteen_ckpt_app/convert_nnunet_ckpts.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,24 @@
2323
if current_dir not in sys.path:
2424
sys.path.insert(0, current_dir)
2525

26-
# Try importing from local apps.nnunet_bundle instead of from MONAI
27-
try:
28-
from my_app.nnunet_bundle import convert_best_nnunet_to_monai_bundle
29-
except ImportError:
30-
# If local import fails, try to find the module in alternate locations
26+
def _import_converter():
27+
"""Deferred import so nnunetv2 is loaded AFTER nnUNet_results env var is set,
28+
preventing nnunetv2.paths from caching a None value for nnUNet_results."""
29+
try:
30+
from my_app.nnunet_bundle import convert_best_nnunet_to_monai_bundle
31+
return convert_best_nnunet_to_monai_bundle
32+
except ImportError:
33+
pass
3134
try:
3235
from monai.apps.nnunet_bundle import convert_best_nnunet_to_monai_bundle
36+
return convert_best_nnunet_to_monai_bundle
3337
except ImportError:
34-
print(
35-
"Error: Could not import convert_best_nnunet_to_monai_bundle from my_app.nnunet_bundle or apps.nnunet_bundle"
36-
)
37-
print("Please ensure that nnunet_bundle.py is properly installed in your project.")
38-
sys.exit(1)
38+
pass
39+
print(
40+
"Error: Could not import convert_best_nnunet_to_monai_bundle from my_app.nnunet_bundle or apps.nnunet_bundle"
41+
)
42+
print("Please ensure that nnunet_bundle.py is properly installed in your project.")
43+
sys.exit(1)
3944

4045

4146
def parse_args():
@@ -57,6 +62,15 @@ def parse_args():
5762
default=None,
5863
help="Path to nnUNet results directory with trained models.",
5964
)
65+
parser.add_argument(
66+
"--checkpoint_type",
67+
type=str,
68+
default="final",
69+
choices=["final", "best", "both"],
70+
help="Which nnUNet checkpoint(s) to convert: 'final' (default) saves checkpoint_final.pth weights as "
71+
"final_model.pt; 'best' saves checkpoint_best.pth weights as best_model.pt; "
72+
"'both' saves checkpoint_final.pth as final_model.pt and checkpoint_best.pth as best_model.pt.",
73+
)
6074
return parser.parse_args()
6175

6276

@@ -90,9 +104,12 @@ def main():
90104
print(f"MAP will be created at: {map_root}")
91105
print(f" nnUNet_results: {os.environ.get('nnUNet_results')}")
92106

107+
# Import AFTER env vars are set so nnunetv2.paths caches the correct nnUNet_results value
108+
convert_best_nnunet_to_monai_bundle = _import_converter()
109+
93110
# Convert the nnUNet checkpoints to MONAI bundle format
94111
try:
95-
convert_best_nnunet_to_monai_bundle(nnunet_config, map_root)
112+
convert_best_nnunet_to_monai_bundle(nnunet_config, map_root, checkpoint_type=args.checkpoint_type)
96113
print(f"Successfully converted nnUNet checkpoints to MONAI bundle at: {map_root}/models")
97114
except Exception as e:
98115
print(f"Error converting nnUNet checkpoints: {e}")

0 commit comments

Comments
 (0)