2121from monai .data .meta_tensor import MetaTensor
2222from monai .utils import optional_import
2323
24+ from typing import Union , Optional
2425join , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "join" )
2526load_json , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "load_json" )
2627
2728__all__ = ["get_nnunet_trainer" , "get_nnunet_monai_predictor" , "convert_nnunet_to_monai_bundle" , "convert_monai_bundle_to_nnunet" ,"ModelnnUNetWrapper" ]
2829
2930
3031def get_nnunet_trainer (
31- dataset_name_or_id ,
32- configuration ,
33- fold ,
34- trainer_class_name = "nnUNetTrainer" ,
35- plans_identifier = "nnUNetPlans" ,
36- pretrained_weights = None ,
37- num_gpus = 1 ,
38- use_compressed_data = False ,
39- export_validation_probabilities = False ,
40- continue_training = False ,
41- only_run_validation = False ,
42- disable_checkpointing = False ,
43- val_with_best = False ,
44- device = "cuda" ,
45- pretrained_model = None ,
46- ):
32+ dataset_name_or_id : Union [str , int ],
33+ configuration : str ,
34+ fold : Union [int , str ],
35+ trainer_class_name : str = "nnUNetTrainer" ,
36+ plans_identifier : str = "nnUNetPlans" ,
37+ use_compressed_data : bool = False ,
38+ continue_training : bool = False ,
39+ only_run_validation : bool = False ,
40+ disable_checkpointing : bool = False ,
41+ device : str = "cuda" ,
42+ pretrained_model : Optional [str ] = None ,
43+ ) -> object :
4744 """
4845 Get the nnUNet trainer instance based on the provided configuration.
4946 The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
@@ -81,29 +78,22 @@ def get_nnunet_trainer(
8178 The class name of the trainer to be used. Default is 'nnUNetTrainer'.
8279 plans_identifier : str, optional
8380 Identifier for the plans to be used. Default is 'nnUNetPlans'.
84- pretrained_weights : str, optional
85- Path to the pretrained weights file.
86- num_gpus : int, optional
87- Number of GPUs to be used. Default is 1.
8881 use_compressed_data : bool, optional
8982 Whether to use compressed data. Default is False.
90- export_validation_probabilities : bool, optional
91- Whether to export validation probabilities. Default is False.
9283 continue_training : bool, optional
9384 Whether to continue training from a checkpoint. Default is False.
9485 only_run_validation : bool, optional
9586 Whether to only run validation. Default is False.
9687 disable_checkpointing : bool, optional
9788 Whether to disable checkpointing. Default is False.
98- val_with_best : bool, optional
99- Whether to validate with the best model. Default is False.
10089 device : str, optional
10190 The device to be used for training. Default is 'cuda'.
102- pretrained_model : str, optional
91+ pretrained_model : Optional[str], optional
10392 Path to the pretrained model file.
93+
10494 Returns
10595 -------
106- nnunet_trainer
96+ nnunet_trainer : object
10797 The nnUNet trainer instance.
10898 """
10999 # From nnUNet/nnunetv2/run/run_training.py#run_training
@@ -117,36 +107,34 @@ def get_nnunet_trainer(
117107 )
118108 raise e
119109
120- if int (num_gpus ) > 1 :
121- ... # Disable for now
122- else :
123- from nnunetv2 .run .run_training import get_trainer_from_args , maybe_load_checkpoint
124-
125- nnunet_trainer = get_trainer_from_args (
126- str (dataset_name_or_id ),
127- configuration ,
128- fold ,
129- trainer_class_name ,
130- plans_identifier ,
131- use_compressed_data ,
132- device = torch .device (device ),
133- )
134- if disable_checkpointing :
135- nnunet_trainer .disable_checkpointing = disable_checkpointing
136110
137- assert not (continue_training and only_run_validation ), "Cannot set --c and --val flag at the same time. Dummy."
111+ from nnunetv2 .run .run_training import get_trainer_from_args , maybe_load_checkpoint
112+
113+ nnunet_trainer = get_trainer_from_args (
114+ str (dataset_name_or_id ),
115+ configuration ,
116+ fold ,
117+ trainer_class_name ,
118+ plans_identifier ,
119+ use_compressed_data ,
120+ device = torch .device (device ),
121+ )
122+ if disable_checkpointing :
123+ nnunet_trainer .disable_checkpointing = disable_checkpointing
124+
125+ assert not (continue_training and only_run_validation ), "Cannot set --c and --val flag at the same time. Dummy."
138126
139- maybe_load_checkpoint (nnunet_trainer , continue_training , only_run_validation , pretrained_weights )
140- nnunet_trainer .on_train_start () # Added to Initialize Trainer
141- if torch .cuda .is_available ():
142- cudnn .deterministic = False
143- cudnn .benchmark = True
127+ maybe_load_checkpoint (nnunet_trainer , continue_training , only_run_validation )
128+ nnunet_trainer .on_train_start () # Added to Initialize Trainer
129+ if torch .cuda .is_available ():
130+ cudnn .deterministic = False
131+ cudnn .benchmark = True
144132
145- if pretrained_model is not None :
146- state_dict = torch .load (pretrained_model )
147- if "network_weights" in state_dict :
148- nnunet_trainer .network ._orig_mod .load_state_dict (state_dict ["network_weights" ])
149- return nnunet_trainer
133+ if pretrained_model is not None :
134+ state_dict = torch .load (pretrained_model )
135+ if "network_weights" in state_dict :
136+ nnunet_trainer .network ._orig_mod .load_state_dict (state_dict ["network_weights" ])
137+ return nnunet_trainer
150138
151139
152140class ModelnnUNetWrapper (torch .nn .Module ):
@@ -176,12 +164,11 @@ class ModelnnUNetWrapper(torch.nn.Module):
176164 restoring network architecture, and setting up the predictor for inference.
177165 """
178166
179- def __init__ (self , predictor , model_folder , model_name = "model.pt" ):
167+ def __init__ (self , predictor : object , model_folder : str , model_name : str = "model.pt" ):
180168 super ().__init__ ()
181169 self .predictor = predictor
182170
183171 model_training_output_dir = model_folder
184- use_folds = ["0" ]
185172
186173 from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
187174
@@ -190,31 +177,26 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
190177 plans = load_json (join (Path (model_training_output_dir ).parent , "plans.json" ))
191178 plans_manager = PlansManager (plans )
192179
193- if isinstance (use_folds , str ):
194- use_folds = [use_folds ]
195-
196180 parameters = []
197- for i , f in enumerate (use_folds ):
198- f = str (f ) if f != "all" else f
199- checkpoint = torch .load (
200- join (Path (model_training_output_dir ).parent , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
181+
182+ checkpoint = torch .load (
183+ join (Path (model_training_output_dir ).parent , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
184+ )
185+ trainer_name = checkpoint ["trainer_name" ]
186+ configuration_name = checkpoint ["init_args" ]["configuration" ]
187+ inference_allowed_mirroring_axes = (
188+ checkpoint ["inference_allowed_mirroring_axes" ]
189+ if "inference_allowed_mirroring_axes" in checkpoint .keys ()
190+ else None
191+ )
192+ if Path (model_training_output_dir ).joinpath (model_name ).is_file ():
193+ monai_checkpoint = torch .load (
194+ join (model_training_output_dir , model_name ), map_location = torch .device ("cpu" )
201195 )
202- if i == 0 :
203- trainer_name = checkpoint ["trainer_name" ]
204- configuration_name = checkpoint ["init_args" ]["configuration" ]
205- inference_allowed_mirroring_axes = (
206- checkpoint ["inference_allowed_mirroring_axes" ]
207- if "inference_allowed_mirroring_axes" in checkpoint .keys ()
208- else None
209- )
210- if Path (model_training_output_dir ).joinpath (f"fold_{ f } " , model_name ).is_file ():
211- monai_checkpoint = torch .load (
212- join (model_training_output_dir , model_name ), map_location = torch .device ("cpu" )
213- )
214- if "network_weights" in monai_checkpoint .keys ():
215- parameters .append (monai_checkpoint ["network_weights" ])
216- else :
217- parameters .append (monai_checkpoint )
196+ if "network_weights" in monai_checkpoint .keys ():
197+ parameters .append (monai_checkpoint ["network_weights" ])
198+ else :
199+ parameters .append (monai_checkpoint )
218200
219201 configuration_manager = plans_manager .get_configuration (configuration_name )
220202 # restore network
@@ -251,10 +233,8 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
251233 if (
252234 ("nnUNet_compile" in os .environ .keys ())
253235 and (os .environ ["nnUNet_compile" ].lower () in ("true" , "1" , "t" ))
254- # and not isinstance(predictor.network, OptimizedModule)
255236 ):
256237 print ("Using torch.compile" )
257- # predictor.network = torch.compile(self.network)
258238 # End Block
259239 self .network_weights = self .predictor .network
260240
@@ -281,21 +261,12 @@ def forward(self, x: MetaTensor) -> MetaTensor:
281261 - The predictions are converted to torch tensors, with added batch and channel dimensions.
282262 - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
283263 """
284- # if isinstance(x, tuple): # if batch is decollated (list of tensors)
285- # properties_or_list_of_properties = []
286- # image_or_list_of_images = []
287-
288- # for img in x:
289- # if isinstance(img, MetaTensor):
290- # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()})
291- # image_or_list_of_images.append(img.cpu().numpy()[0,:])
292- # else:
293- # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
294-
295- # else: # if batch is collated
296264 if isinstance (x , MetaTensor ):
297265 if "pixdim" in x .meta :
298266 properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
267+ elif "affine" in x .meta :
268+ spacing = [abs (x .meta ['affine' ][0 ][0 ].item ()), abs (x .meta ['affine' ][1 ][1 ].item ()), abs (x .meta ['affine' ][2 ][2 ].item ())]
269+ properties_or_list_of_properties = {"spacing" : spacing }
299270 else :
300271 properties_or_list_of_properties = {"spacing" : [1.0 , 1.0 , 1.0 ]}
301272 else :
@@ -320,13 +291,10 @@ def forward(self, x: MetaTensor) -> MetaTensor:
320291 out_tensors .append (torch .from_numpy (np .expand_dims (np .expand_dims (out , 0 ), 0 )))
321292 out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
322293
323- # if type(x) is tuple:
324- # return MetaTensor(out_tensor, meta=x[0].meta)
325- # else:
326294 return MetaTensor (out_tensor , meta = x .meta )
327295
328296
329- def get_nnunet_monai_predictor (model_folder , model_name = "model.pt" ):
297+ def get_nnunet_monai_predictor (model_folder : str , model_name : str = "model.pt" ) -> ModelnnUNetWrapper :
330298 """
331299 Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
332300 The model folder should contain the following files, created during training:
@@ -360,7 +328,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
360328
361329 Returns
362330 -------
363- nnUNetMONAIModelWrapper
331+ ModelnnUNetWrapper
364332 A wrapper object that contains the nnUNetPredictor and the loaded model.
365333 """
366334
@@ -380,7 +348,9 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
380348 return wrapper
381349
382350
383- def convert_nnunet_to_monai_bundle (nnunet_config , bundle_root_folder , fold = 0 ):
351+ def convert_nnunet_to_monai_bundle (
352+ nnunet_config : dict , bundle_root_folder : str , fold : int = 0
353+ ) -> None :
384354 """
385355 Convert nnUNet model checkpoints and configuration to MONAI bundle format.
386356
@@ -450,7 +420,13 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
450420 )
451421
452422
453- def get_network_from_nnunet_plans (plans_file , dataset_file , configuration , model_ckpt = None , model_key_in_ckpt = "model" ):
423+ def get_network_from_nnunet_plans (
424+ plans_file : str ,
425+ dataset_file : str ,
426+ configuration : str ,
427+ model_ckpt : Optional [str ] = None ,
428+ model_key_in_ckpt : str = "model"
429+ ) -> torch .nn .Module :
454430 """
455431 Load and initialize a neural network based on nnUNet plans and configuration.
456432
@@ -462,7 +438,7 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model
462438 Path to the JSON file containing the dataset information.
463439 configuration : str
464440 The configuration name to be used from the plans.
465- model_ckpt : str, optional
441+ model_ckpt : Optional[str], optional
466442 Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).
467443 model_key_in_ckpt : str, optional
468444 The key in the checkpoint file that contains the model state dictionary (default is "model").
@@ -505,7 +481,11 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model
505481 return network
506482
507483
508- def convert_monai_bundle_to_nnunet (nnunet_config , bundle_root_folder , fold = 0 ):
484+ def convert_monai_bundle_to_nnunet (
485+ nnunet_config : dict ,
486+ bundle_root_folder : str ,
487+ fold : int = 0
488+ ) -> None :
509489 """
510490 Convert a MONAI bundle to nnU-Net format.
511491
@@ -527,8 +507,8 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0):
527507 """
528508 from odict import odict
529509
530- nnunet_trainer = "nnUNetTrainer"
531- nnunet_plans = "nnUNetPlans"
510+ nnunet_trainer : str = "nnUNetTrainer"
511+ nnunet_plans : str = "nnUNetPlans"
532512
533513 if "nnunet_trainer" in nnunet_config :
534514 nnunet_trainer = nnunet_config ["nnunet_trainer" ]
@@ -539,8 +519,13 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0):
539519 from nnunetv2 .training .logging .nnunet_logger import nnUNetLogger
540520 from nnunetv2 .utilities .dataset_name_id_conversion import maybe_convert_to_dataset_name
541521
542- def subfiles (folder , join : bool = True , prefix : str = None , suffix : str = None , sort : bool = True ):
543-
522+ def subfiles (
523+ folder : str ,
524+ join : bool = True ,
525+ prefix : Optional [str ] = None ,
526+ suffix : Optional [str ] = None ,
527+ sort : bool = True
528+ ) -> list [str ]:
544529 if join :
545530 l = os .path .join # noqa: E741
546531 else :
@@ -556,42 +541,42 @@ def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None,
556541 res .sort ()
557542 return res
558543
559- nnunet_model_folder = Path (os .environ ["nnUNet_results" ]).joinpath (
544+ nnunet_model_folder : Path = Path (os .environ ["nnUNet_results" ]).joinpath (
560545 maybe_convert_to_dataset_name (nnunet_config ["dataset_name_or_id" ]),
561546 f"{ nnunet_trainer } __{ nnunet_plans } __3d_fullres" ,
562547 )
563548
564- nnunet_preprocess_model_folder = Path (os .environ ["nnUNet_preprocessed" ]).joinpath (
549+ nnunet_preprocess_model_folder : Path = Path (os .environ ["nnUNet_preprocessed" ]).joinpath (
565550 maybe_convert_to_dataset_name (nnunet_config ["dataset_name_or_id" ])
566551 )
567552
568553 Path (nnunet_model_folder ).joinpath (f"fold_{ fold } " ).mkdir (parents = True , exist_ok = True )
569554
570- nnunet_checkpoint = torch .load (f"{ bundle_root_folder } /models/nnunet_checkpoint.pth" )
571- latest_checkpoints = subfiles (
555+ nnunet_checkpoint : dict = torch .load (f"{ bundle_root_folder } /models/nnunet_checkpoint.pth" )
556+ latest_checkpoints : list [ str ] = subfiles (
572557 Path (bundle_root_folder ).joinpath ("models" , f"fold_{ fold } " ), prefix = "checkpoint_epoch" , sort = True , join = False
573558 )
574- epochs = []
559+ epochs : list [ int ] = []
575560 for latest_checkpoint in latest_checkpoints :
576561 epochs .append (int (latest_checkpoint [len ("checkpoint_epoch=" ) : - len (".pt" )]))
577562
578563 epochs .sort ()
579- final_epoch = epochs [- 1 ]
580- monai_last_checkpoint = torch .load (f"{ bundle_root_folder } /models/fold_{ fold } /checkpoint_epoch={ final_epoch } .pt" )
564+ final_epoch : int = epochs [- 1 ]
565+ monai_last_checkpoint : dict = torch .load (f"{ bundle_root_folder } /models/fold_{ fold } /checkpoint_epoch={ final_epoch } .pt" )
581566
582- best_checkpoints = subfiles (
567+ best_checkpoints : list [ str ] = subfiles (
583568 Path (bundle_root_folder ).joinpath ("models" , f"fold_{ fold } " ),
584569 prefix = "checkpoint_key_metric" ,
585570 sort = True ,
586571 join = False ,
587572 )
588- key_metrics = []
573+ key_metrics : list [ str ] = []
589574 for best_checkpoint in best_checkpoints :
590575 key_metrics .append (str (best_checkpoint [len ("checkpoint_key_metric=" ) : - len (".pt" )]))
591576
592577 key_metrics .sort ()
593- best_key_metric = key_metrics [- 1 ]
594- monai_best_checkpoint = torch .load (
578+ best_key_metric : str = key_metrics [- 1 ]
579+ monai_best_checkpoint : dict = torch .load (
595580 f"{ bundle_root_folder } /models/fold_{ fold } /checkpoint_key_metric={ best_key_metric } .pt"
596581 )
597582
0 commit comments