Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions examples/apps/cchmc_nnunet_fifteen_ckpt_app/convert_nnunet_ckpts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,25 @@
if current_dir not in sys.path:
sys.path.insert(0, current_dir)

# Try importing from local apps.nnunet_bundle instead of from MONAI
try:
from my_app.nnunet_bundle import convert_best_nnunet_to_monai_bundle
except ImportError:
# If local import fails, try to find the module in alternate locations

def _import_converter():
"""Deferred import so nnunetv2 is loaded AFTER nnUNet_results env var is set,
preventing nnunetv2.paths from caching a None value for nnUNet_results."""
try:
from my_app.nnunet_bundle import convert_best_nnunet_to_monai_bundle

return convert_best_nnunet_to_monai_bundle
except ImportError:
pass
try:
from monai.apps.nnunet_bundle import convert_best_nnunet_to_monai_bundle

return convert_best_nnunet_to_monai_bundle
except ImportError:
print(
"Error: Could not import convert_best_nnunet_to_monai_bundle from my_app.nnunet_bundle or apps.nnunet_bundle"
)
print("Please ensure that nnunet_bundle.py is properly installed in your project.")
sys.exit(1)
pass
print("Error: Could not import convert_best_nnunet_to_monai_bundle from my_app.nnunet_bundle or apps.nnunet_bundle")
print("Please ensure that nnunet_bundle.py is properly installed in your project.")
sys.exit(1)


def parse_args():
Expand All @@ -57,6 +63,15 @@ def parse_args():
default=None,
help="Path to nnUNet results directory with trained models.",
)
parser.add_argument(
"--checkpoint_type",
type=str,
default="final",
choices=["final", "best", "both"],
help="Which nnUNet checkpoint(s) to convert: 'final' (default) saves checkpoint_final.pth weights as "
"final_model.pt; 'best' saves checkpoint_best.pth weights as best_model.pt; "
"'both' saves checkpoint_final.pth as final_model.pt and checkpoint_best.pth as best_model.pt.",
)
return parser.parse_args()


Expand Down Expand Up @@ -90,9 +105,12 @@ def main():
print(f"MAP will be created at: {map_root}")
print(f" nnUNet_results: {os.environ.get('nnUNet_results')}")

# Import AFTER env vars are set so nnunetv2.paths caches the correct nnUNet_results value
convert_best_nnunet_to_monai_bundle = _import_converter()

# Convert the nnUNet checkpoints to MONAI bundle format
try:
convert_best_nnunet_to_monai_bundle(nnunet_config, map_root)
convert_best_nnunet_to_monai_bundle(nnunet_config, map_root, checkpoint_type=args.checkpoint_type)
print(f"Successfully converted nnUNet checkpoints to MONAI bundle at: {map_root}/models")
except Exception as e:
print(f"Error converting nnUNet checkpoints: {e}")
Expand Down
Loading
Loading