1616
1717import numpy as np
1818import torch
19- from torch ._dynamo import OptimizedModule
2019from torch .backends import cudnn
2120
2221from monai .data .meta_tensor import MetaTensor
2524join , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "join" )
2625load_json , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "load_json" )
2726
28- __all__ = ["get_nnunet_trainer" , "get_nnunet_monai_predictor" , "nnUNetMONAIModelWrapper " ]
27+ __all__ = ["get_nnunet_trainer" , "get_nnunet_monai_predictor" , "convert_nnunet_to_monai_bundle" , "convert_monai_bundle_to_nnunet" , "ModelnnUNetWrapper " ]
2928
3029
3130def get_nnunet_trainer (
@@ -42,33 +41,33 @@ def get_nnunet_trainer(
4241 only_run_validation = False ,
4342 disable_checkpointing = False ,
4443 val_with_best = False ,
45- device = torch . device ( "cuda" ) ,
44+ device = "cuda" ,
4645 pretrained_model = None ,
4746):
4847 """
4948 Get the nnUNet trainer instance based on the provided configuration.
5049 The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
5150 optimizer, loss function, DataLoader, etc.
5251
53- ```python
54- from monai.apps import SupervisedTrainer
55- from monai.bundle.nnunet import get_nnunet_trainer
52+ Example::
5653
57- dataset_name_or_id = 'Task101_PROSTATE'
58- fold = 0
59- configuration = '3d_fullres'
60- nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
54+ from monai.apps import SupervisedTrainer
55+ from monai.bundle.nnunet import get_nnunet_trainer
6156
62- trainer = SupervisedTrainer(
63- device=nnunet_trainer.device,
64- max_epochs=nnunet_trainer.num_epochs,
65- train_data_loader=nnunet_trainer.dataloader_train,
66- network=nnunet_trainer.network,
67- optimizer=nnunet_trainer.optimizer,
68- loss_function=nnunet_trainer.loss_function,
69- epoch_length=nnunet_trainer.num_iterations_per_epoch,
57+ dataset_name_or_id = 'Task009_Spleen'
58+ fold = 0
59+ configuration = '3d_fullres'
60+ nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
7061
71- ```
62+ trainer = SupervisedTrainer(
63+ device=nnunet_trainer.device,
64+ max_epochs=nnunet_trainer.num_epochs,
65+ train_data_loader=nnunet_trainer.dataloader_train,
66+ network=nnunet_trainer.network,
67+ optimizer=nnunet_trainer.optimizer,
68+ loss_function=nnunet_trainer.loss_function,
69+ epoch_length=nnunet_trainer.num_iterations_per_epoch,
70+ )
7271
7372 Parameters
7473 ----------
@@ -98,7 +97,7 @@ def get_nnunet_trainer(
9897 Whether to disable checkpointing. Default is False.
9998 val_with_best : bool, optional
10099 Whether to validate with the best model. Default is False.
101- device : torch.device , optional
100+ device : str , optional
102101 The device to be used for training. Default is 'cuda'.
103102 pretrained_model : str, optional
104103 Path to the pretrained model file.
@@ -130,7 +129,7 @@ def get_nnunet_trainer(
130129 trainer_class_name ,
131130 plans_identifier ,
132131 use_compressed_data ,
133- device = device ,
132+ device = torch . device ( device ) ,
134133 )
135134 if disable_checkpointing :
136135 nnunet_trainer .disable_checkpointing = disable_checkpointing
@@ -150,7 +149,7 @@ def get_nnunet_trainer(
150149 return nnunet_trainer
151150
152151
153- class nnUNetMONAIModelWrapper (torch .nn .Module ):
152+ class ModelnnUNetWrapper (torch .nn .Module ):
154153 """
155154 A wrapper class for nnUNet model integration with MONAI framework.
156155 The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference.
@@ -163,16 +162,14 @@ class nnUNetMONAIModelWrapper(torch.nn.Module):
163162 The folder path where the model and related files are stored.
164163 model_name : str, optional
165164 The name of the model file, by default "model.pt".
165+
166166 Attributes
167167 ----------
168- predictor : object
169- The predictor object used for inference.
168+ predictor : nnUNetPredictor
169+ The nnUNet predictor object used for inference.
170170 network_weights : torch.nn.Module
171171 The network weights of the model.
172- Methods
173- -------
174- forward(x)
175- Perform forward pass and prediction on the input data.
172+
176173 Notes
177174 -----
178175 This class integrates nnUNet model with MONAI framework by loading necessary configurations,
@@ -184,23 +181,23 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
184181 self .predictor = predictor
185182
186183 model_training_output_dir = model_folder
187- use_folds = "0"
184+ use_folds = [ "0" ]
188185
189186 from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
190187
191- ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
192- dataset_json = load_json (join (model_training_output_dir , "dataset.json" ))
193- plans = load_json (join (model_training_output_dir , "plans.json" ))
188+ # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
189+ dataset_json = load_json (join (Path ( model_training_output_dir ). parent , "dataset.json" ))
190+ plans = load_json (join (Path ( model_training_output_dir ). parent , "plans.json" ))
194191 plans_manager = PlansManager (plans )
195192
196193 if isinstance (use_folds , str ):
197194 use_folds = [use_folds ]
198195
199196 parameters = []
200197 for i , f in enumerate (use_folds ):
201- f = int (f ) if f != "all" else f
198+ f = str (f ) if f != "all" else f
202199 checkpoint = torch .load (
203- join (model_training_output_dir , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
200+ join (Path ( model_training_output_dir ). parent , "nnunet_checkpoint.pth" ), map_location = torch .device ("cpu" )
204201 )
205202 if i == 0 :
206203 trainer_name = checkpoint ["trainer_name" ]
@@ -254,32 +251,67 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
254251 if (
255252 ("nnUNet_compile" in os .environ .keys ())
256253 and (os .environ ["nnUNet_compile" ].lower () in ("true" , "1" , "t" ))
257- and not isinstance (predictor .network , OptimizedModule )
254+ # and not isinstance(predictor.network, OptimizedModule)
258255 ):
259256 print ("Using torch.compile" )
260- predictor .network = torch .compile (self .network )
261- ## End Block
257+ # predictor.network = torch.compile(self.network)
258+ # End Block
262259 self .network_weights = self .predictor .network
263260
264- def forward (self , x ):
265- if type (x ) is tuple : # if batch is decollated (list of tensors)
266- input_files = [img .meta ["filename_or_obj" ][0 ] for img in x ]
267- else : # if batch is collated
268- input_files = x .meta ["filename_or_obj" ]
269- if type (input_files ) is str :
270- input_files = [input_files ]
261+ def forward (self , x : MetaTensor ) -> MetaTensor :
262+ """
263+ Forward pass for the nnUNet model.
264+
265+ :no-index:
266+
267+ Args:
268+ x (MetaTensor): Input tensor. If the input is a tuple,
269+ it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.
270+
271+ Returns:
272+ MetaTensor: The output tensor with the same metadata as the input.
273+
274+ Raises:
275+ TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors.
276+
277+ Notes:
278+ - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.
279+ - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.
280+ - The filenames are used to generate predictions using the nnUNet predictor.
281+ - The predictions are converted to torch tensors, with added batch and channel dimensions.
282+ - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
283+ """
284+ # if isinstance(x, tuple): # if batch is decollated (list of tensors)
285+ # properties_or_list_of_properties = []
286+ # image_or_list_of_images = []
287+
288+ # for img in x:
289+ # if isinstance(img, MetaTensor):
290+ # properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()})
291+ # image_or_list_of_images.append(img.cpu().numpy()[0,:])
292+ # else:
293+ # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
294+
295+ # else: # if batch is collated
296+ if isinstance (x , MetaTensor ):
297+ if "pixdim" in x .meta :
298+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
299+ else :
300+ properties_or_list_of_properties = {"spacing" : [1.0 , 1.0 , 1.0 ]}
301+ else :
302+ raise TypeError ("Input must be a MetaTensor or a tuple of MetaTensors." )
303+
304+ image_or_list_of_images = x .cpu ().numpy ()[0 , :]
271305
272306 # input_files should be a list of file paths, one per modality
273- prediction_output = self .predictor .predict_from_files (
274- [ input_files ] ,
307+ prediction_output = self .predictor .predict_from_list_of_npy_arrays (
308+ image_or_list_of_images ,
275309 None ,
310+ properties_or_list_of_properties ,
311+ truncated_ofname = None ,
276312 save_probabilities = False ,
277- overwrite = True ,
278- num_processes_preprocessing = 2 ,
313+ num_processes = 2 ,
279314 num_processes_segmentation_export = 2 ,
280- folder_with_segs_from_prev_stage = None ,
281- num_parts = 1 ,
282- part_id = 0 ,
283315 )
284316 # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
285317
@@ -288,35 +320,36 @@ def forward(self, x):
288320 out_tensors .append (torch .from_numpy (np .expand_dims (np .expand_dims (out , 0 ), 0 )))
289321 out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
290322
291- if type (x ) is tuple :
292- return MetaTensor (out_tensor , meta = x [0 ].meta )
293- else :
294- return MetaTensor (out_tensor , meta = x .meta )
323+ # if type(x) is tuple:
324+ # return MetaTensor(out_tensor, meta=x[0].meta)
325+ # else:
326+ return MetaTensor (out_tensor , meta = x .meta )
295327
296328
297329def get_nnunet_monai_predictor (model_folder , model_name = "model.pt" ):
298330 """
299- Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor.
331+ Initializes and returns a ` nnUNetMONAIModelWrapper` containing the corresponding ` nnUNetPredictor` .
300332 The model folder should contain the following files, created during training:
301- - dataset.json: from the nnUNet results folder.
302- - plans .json: from the nnUNet results folder.
303- - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
304- (`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`).
305- - model.pt: The checkpoint file containing the model weights.
333+
334+ - dataset .json: from the nnUNet results folder
335+ - plans.json: from the nnUNet results folder
336+ - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
337+ - model.pt: The checkpoint file containing the model weights.
306338
307339 The returned wrapper object can be used for inference with MONAI framework:
308- ```python
309- from monai.bundle.nnunet import get_nnunet_monai_predictor
310340
311- model_folder = 'path/to/monai_bundle/model'
312- model_name = 'model.pt'
313- wrapper = get_nnunet_monai_predictor(model_folder, model_name)
341+ Example::
342+
343+ from monai.bundle.nnunet import get_nnunet_monai_predictor
344+
345+ model_folder = 'path/to/monai_bundle/model'
346+ model_name = 'model.pt'
347+ wrapper = get_nnunet_monai_predictor(model_folder, model_name)
314348
315- # Perform inference
316- input_data = ...
317- output = wrapper(input_data)
349+ # Perform inference
350+ input_data = ...
351+ output = wrapper(input_data)
318352
319- ```
320353
321354 Parameters
322355 ----------
@@ -343,7 +376,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
343376 allow_tqdm = True ,
344377 )
345378 # initializes the network architecture, loads the checkpoint
346- wrapper = nnUNetMONAIModelWrapper (predictor , model_folder , model_name )
379+ wrapper = ModelnnUNetWrapper (predictor , model_folder , model_name )
347380 return wrapper
348381
349382
@@ -396,13 +429,14 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
396429
397430 torch .save (nnunet_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "nnunet_checkpoint.pth" ))
398431
432+ Path (bundle_root_folder ).joinpath ("models" , f"fold_{ fold } " ).mkdir (parents = True , exist_ok = True )
399433 monai_last_checkpoint = {}
400434 monai_last_checkpoint ["network_weights" ] = nnunet_checkpoint_final ["network_weights" ]
401- torch .save (monai_last_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "model.pt" ))
435+ torch .save (monai_last_checkpoint , Path (bundle_root_folder ).joinpath ("models" , f"fold_ { fold } " , "model.pt" ))
402436
403437 monai_best_checkpoint = {}
404438 monai_best_checkpoint ["network_weights" ] = nnunet_checkpoint_best ["network_weights" ]
405- torch .save (monai_best_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "best_model.pt" ))
439+ torch .save (monai_best_checkpoint , Path (bundle_root_folder ).joinpath ("models" , f"fold_ { fold } " , "best_model.pt" ))
406440
407441 if not os .path .exists (os .path .join (bundle_root_folder , "models" , "plans.json" )):
408442 shutil .copy (
0 commit comments