Skip to content

Commit 1972504

Browse files
Rename nnUNetMONAIModelWrapper to ModelnnUNetWrapper for consistency
1 parent fee1bb0 commit 1972504

2 files changed

Lines changed: 108 additions & 74 deletions

File tree

monai/apps/nnunet/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
get_network_from_nnunet_plans,
1818
get_nnunet_monai_predictor,
1919
get_nnunet_trainer,
20-
nnUNetMONAIModelWrapper,
20+
ModelnnUNetWrapper
2121
)
2222
from .nnunetv2_runner import nnUNetV2Runner
2323
from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 107 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy as np
1818
import torch
19-
from torch._dynamo import OptimizedModule
2019
from torch.backends import cudnn
2120

2221
from monai.data.meta_tensor import MetaTensor
@@ -25,7 +24,7 @@
2524
join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
2625
load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
2726

28-
__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"]
27+
__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet", "ModelnnUNetWrapper"]
2928

3029

3130
def get_nnunet_trainer(
@@ -42,33 +41,33 @@ def get_nnunet_trainer(
4241
only_run_validation=False,
4342
disable_checkpointing=False,
4443
val_with_best=False,
45-
device=torch.device("cuda"),
44+
device="cuda",
4645
pretrained_model=None,
4746
):
4847
"""
4948
Get the nnUNet trainer instance based on the provided configuration.
5049
The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
5150
optimizer, loss function, DataLoader, etc.
5251
53-
```python
54-
from monai.apps import SupervisedTrainer
55-
from monai.bundle.nnunet import get_nnunet_trainer
52+
Example::
5653
57-
dataset_name_or_id = 'Task101_PROSTATE'
58-
fold = 0
59-
configuration = '3d_fullres'
60-
nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
54+
from monai.apps import SupervisedTrainer
55+
from monai.bundle.nnunet import get_nnunet_trainer
6156
62-
trainer = SupervisedTrainer(
63-
device=nnunet_trainer.device,
64-
max_epochs=nnunet_trainer.num_epochs,
65-
train_data_loader=nnunet_trainer.dataloader_train,
66-
network=nnunet_trainer.network,
67-
optimizer=nnunet_trainer.optimizer,
68-
loss_function=nnunet_trainer.loss_function,
69-
epoch_length=nnunet_trainer.num_iterations_per_epoch,
57+
dataset_name_or_id = 'Task009_Spleen'
58+
fold = 0
59+
configuration = '3d_fullres'
60+
nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
7061
71-
```
62+
trainer = SupervisedTrainer(
63+
device=nnunet_trainer.device,
64+
max_epochs=nnunet_trainer.num_epochs,
65+
train_data_loader=nnunet_trainer.dataloader_train,
66+
network=nnunet_trainer.network,
67+
optimizer=nnunet_trainer.optimizer,
68+
loss_function=nnunet_trainer.loss_function,
69+
epoch_length=nnunet_trainer.num_iterations_per_epoch,
70+
)
7271
7372
Parameters
7473
----------
@@ -98,7 +97,7 @@ def get_nnunet_trainer(
9897
Whether to disable checkpointing. Default is False.
9998
val_with_best : bool, optional
10099
Whether to validate with the best model. Default is False.
101-
device : torch.device, optional
100+
device : str, optional
102101
The device to be used for training. Default is 'cuda'.
103102
pretrained_model : str, optional
104103
Path to the pretrained model file.
@@ -130,7 +129,7 @@ def get_nnunet_trainer(
130129
trainer_class_name,
131130
plans_identifier,
132131
use_compressed_data,
133-
device=device,
132+
device=torch.device(device),
134133
)
135134
if disable_checkpointing:
136135
nnunet_trainer.disable_checkpointing = disable_checkpointing
@@ -150,7 +149,7 @@ def get_nnunet_trainer(
150149
return nnunet_trainer
151150

152151

153-
class nnUNetMONAIModelWrapper(torch.nn.Module):
152+
class ModelnnUNetWrapper(torch.nn.Module):
154153
"""
155154
A wrapper class for nnUNet model integration with MONAI framework.
156155
The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference.
@@ -163,16 +162,14 @@ class nnUNetMONAIModelWrapper(torch.nn.Module):
163162
The folder path where the model and related files are stored.
164163
model_name : str, optional
165164
The name of the model file, by default "model.pt".
165+
166166
Attributes
167167
----------
168-
predictor : object
169-
The predictor object used for inference.
168+
predictor : nnUNetPredictor
169+
The nnUNet predictor object used for inference.
170170
network_weights : torch.nn.Module
171171
The network weights of the model.
172-
Methods
173-
-------
174-
forward(x)
175-
Perform forward pass and prediction on the input data.
172+
176173
Notes
177174
-----
178175
This class integrates nnUNet model with MONAI framework by loading necessary configurations,
@@ -184,23 +181,23 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
184181
self.predictor = predictor
185182

186183
model_training_output_dir = model_folder
187-
use_folds = "0"
184+
use_folds = ["0"]
188185

189186
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
190187

191-
## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
192-
dataset_json = load_json(join(model_training_output_dir, "dataset.json"))
193-
plans = load_json(join(model_training_output_dir, "plans.json"))
188+
# Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
189+
dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
190+
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
194191
plans_manager = PlansManager(plans)
195192

196193
if isinstance(use_folds, str):
197194
use_folds = [use_folds]
198195

199196
parameters = []
200197
for i, f in enumerate(use_folds):
201-
f = int(f) if f != "all" else f
198+
f = str(f) if f != "all" else f
202199
checkpoint = torch.load(
203-
join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
200+
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
204201
)
205202
if i == 0:
206203
trainer_name = checkpoint["trainer_name"]
@@ -254,32 +251,67 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
254251
if (
255252
("nnUNet_compile" in os.environ.keys())
256253
and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t"))
257-
and not isinstance(predictor.network, OptimizedModule)
254+
# and not isinstance(predictor.network, OptimizedModule)
258255
):
259256
print("Using torch.compile")
260-
predictor.network = torch.compile(self.network)
261-
## End Block
257+
# predictor.network = torch.compile(self.network)
258+
# End Block
262259
self.network_weights = self.predictor.network
263260

264-
def forward(self, x):
265-
if type(x) is tuple: # if batch is decollated (list of tensors)
266-
input_files = [img.meta["filename_or_obj"][0] for img in x]
267-
else: # if batch is collated
268-
input_files = x.meta["filename_or_obj"]
269-
if type(input_files) is str:
270-
input_files = [input_files]
261+
def forward(self, x: MetaTensor) -> MetaTensor:
262+
"""
263+
Forward pass for the nnUNet model.
264+
265+
:no-index:
266+
267+
Args:
268+
x (MetaTensor): Input tensor. If the input is a tuple,
269+
it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.
270+
271+
Returns:
272+
MetaTensor: The output tensor with the same metadata as the input.
273+
274+
Raises:
275+
TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors.
276+
277+
Notes:
278+
- If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.
279+
- If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.
280+
- The filenames are used to generate predictions using the nnUNet predictor.
281+
- The predictions are converted to torch tensors, with added batch and channel dimensions.
282+
- The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
283+
"""
284+
# if isinstance(x, tuple): # if batch is decollated (list of tensors)
285+
# properties_or_list_of_properties = []
286+
# image_or_list_of_images = []
287+
288+
# for img in x:
289+
# if isinstance(img, MetaTensor):
290+
# properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()})
291+
# image_or_list_of_images.append(img.cpu().numpy()[0,:])
292+
# else:
293+
# raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
294+
295+
# else: # if batch is collated
296+
if isinstance(x, MetaTensor):
297+
if "pixdim" in x.meta:
298+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
299+
else:
300+
properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
301+
else:
302+
raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
303+
304+
image_or_list_of_images = x.cpu().numpy()[0, :]
271305

272306
# input_files should be a list of file paths, one per modality
273-
prediction_output = self.predictor.predict_from_files(
274-
[input_files],
307+
prediction_output = self.predictor.predict_from_list_of_npy_arrays(
308+
image_or_list_of_images,
275309
None,
310+
properties_or_list_of_properties,
311+
truncated_ofname=None,
276312
save_probabilities=False,
277-
overwrite=True,
278-
num_processes_preprocessing=2,
313+
num_processes=2,
279314
num_processes_segmentation_export=2,
280-
folder_with_segs_from_prev_stage=None,
281-
num_parts=1,
282-
part_id=0,
283315
)
284316
# prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
285317

@@ -288,35 +320,36 @@ def forward(self, x):
288320
out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
289321
out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension
290322

291-
if type(x) is tuple:
292-
return MetaTensor(out_tensor, meta=x[0].meta)
293-
else:
294-
return MetaTensor(out_tensor, meta=x.meta)
323+
# if type(x) is tuple:
324+
# return MetaTensor(out_tensor, meta=x[0].meta)
325+
# else:
326+
return MetaTensor(out_tensor, meta=x.meta)
295327

296328

297329
def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
298330
"""
299-
Initializes and returns a nnUNetMONAIModelWrapper with a nnUNetPredictor.
331+
Initializes and returns a `ModelnnUNetWrapper` containing the corresponding `nnUNetPredictor`.
300332
The model folder should contain the following files, created during training:
301-
- dataset.json: from the nnUNet results folder.
302-
- plans.json: from the nnUNet results folder.
303-
- nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
304-
(`init_kwargs`, `trainer_name`, `inference_allowed_mirroring_axes`).
305-
- model.pt: The checkpoint file containing the model weights.
333+
334+
- dataset.json: from the nnUNet results folder
335+
- plans.json: from the nnUNet results folder
336+
- nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
337+
- model.pt: The checkpoint file containing the model weights.
306338
307339
The returned wrapper object can be used for inference with MONAI framework:
308-
```python
309-
from monai.bundle.nnunet import get_nnunet_monai_predictor
310340
311-
model_folder = 'path/to/monai_bundle/model'
312-
model_name = 'model.pt'
313-
wrapper = get_nnunet_monai_predictor(model_folder, model_name)
341+
Example::
342+
343+
from monai.bundle.nnunet import get_nnunet_monai_predictor
344+
345+
model_folder = 'path/to/monai_bundle/model'
346+
model_name = 'model.pt'
347+
wrapper = get_nnunet_monai_predictor(model_folder, model_name)
314348
315-
# Perform inference
316-
input_data = ...
317-
output = wrapper(input_data)
349+
# Perform inference
350+
input_data = ...
351+
output = wrapper(input_data)
318352
319-
```
320353
321354
Parameters
322355
----------
@@ -343,7 +376,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
343376
allow_tqdm=True,
344377
)
345378
# initializes the network architecture, loads the checkpoint
346-
wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name)
379+
wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
347380
return wrapper
348381

349382

@@ -396,13 +429,14 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
396429

397430
torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth"))
398431

432+
Path(bundle_root_folder).joinpath("models", f"fold_{fold}").mkdir(parents=True, exist_ok=True)
399433
monai_last_checkpoint = {}
400434
monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"]
401-
torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt"))
435+
torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "model.pt"))
402436

403437
monai_best_checkpoint = {}
404438
monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"]
405-
torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt"))
439+
torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", f"fold_{fold}", "best_model.pt"))
406440

407441
if not os.path.exists(os.path.join(bundle_root_folder, "models", "plans.json")):
408442
shutil.copy(

0 commit comments

Comments
 (0)