Skip to content

Commit 5c633f2

Browse files
1 parent 1972504 commit 5c633f2

1 file changed

Lines changed: 101 additions & 116 deletions

File tree

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 101 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,26 @@
2121
from monai.data.meta_tensor import MetaTensor
2222
from monai.utils import optional_import
2323

24+
from typing import Union, Optional
2425
join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
2526
load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
2627

2728
__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet","ModelnnUNetWrapper"]
2829

2930

3031
def get_nnunet_trainer(
31-
dataset_name_or_id,
32-
configuration,
33-
fold,
34-
trainer_class_name="nnUNetTrainer",
35-
plans_identifier="nnUNetPlans",
36-
pretrained_weights=None,
37-
num_gpus=1,
38-
use_compressed_data=False,
39-
export_validation_probabilities=False,
40-
continue_training=False,
41-
only_run_validation=False,
42-
disable_checkpointing=False,
43-
val_with_best=False,
44-
device="cuda",
45-
pretrained_model=None,
46-
):
32+
dataset_name_or_id: Union[str, int],
33+
configuration: str,
34+
fold: Union[int, str],
35+
trainer_class_name: str = "nnUNetTrainer",
36+
plans_identifier: str = "nnUNetPlans",
37+
use_compressed_data: bool = False,
38+
continue_training: bool = False,
39+
only_run_validation: bool = False,
40+
disable_checkpointing: bool = False,
41+
device: str = "cuda",
42+
pretrained_model: Optional[str] = None,
43+
) -> object:
4744
"""
4845
Get the nnUNet trainer instance based on the provided configuration.
4946
The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
@@ -81,29 +78,22 @@ def get_nnunet_trainer(
8178
The class name of the trainer to be used. Default is 'nnUNetTrainer'.
8279
plans_identifier : str, optional
8380
Identifier for the plans to be used. Default is 'nnUNetPlans'.
84-
pretrained_weights : str, optional
85-
Path to the pretrained weights file.
86-
num_gpus : int, optional
87-
Number of GPUs to be used. Default is 1.
8881
use_compressed_data : bool, optional
8982
Whether to use compressed data. Default is False.
90-
export_validation_probabilities : bool, optional
91-
Whether to export validation probabilities. Default is False.
9283
continue_training : bool, optional
9384
Whether to continue training from a checkpoint. Default is False.
9485
only_run_validation : bool, optional
9586
Whether to only run validation. Default is False.
9687
disable_checkpointing : bool, optional
9788
Whether to disable checkpointing. Default is False.
98-
val_with_best : bool, optional
99-
Whether to validate with the best model. Default is False.
10089
device : str, optional
10190
The device to be used for training. Default is 'cuda'.
102-
pretrained_model : str, optional
91+
pretrained_model : Optional[str], optional
10392
Path to the pretrained model file.
93+
10494
Returns
10595
-------
106-
nnunet_trainer
96+
nnunet_trainer : object
10797
The nnUNet trainer instance.
10898
"""
10999
# From nnUNet/nnunetv2/run/run_training.py#run_training
@@ -117,36 +107,34 @@ def get_nnunet_trainer(
117107
)
118108
raise e
119109

120-
if int(num_gpus) > 1:
121-
... # Disable for now
122-
else:
123-
from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
124-
125-
nnunet_trainer = get_trainer_from_args(
126-
str(dataset_name_or_id),
127-
configuration,
128-
fold,
129-
trainer_class_name,
130-
plans_identifier,
131-
use_compressed_data,
132-
device=torch.device(device),
133-
)
134-
if disable_checkpointing:
135-
nnunet_trainer.disable_checkpointing = disable_checkpointing
136110

137-
assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy."
111+
from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
112+
113+
nnunet_trainer = get_trainer_from_args(
114+
str(dataset_name_or_id),
115+
configuration,
116+
fold,
117+
trainer_class_name,
118+
plans_identifier,
119+
use_compressed_data,
120+
device=torch.device(device),
121+
)
122+
if disable_checkpointing:
123+
nnunet_trainer.disable_checkpointing = disable_checkpointing
124+
125+
assert not (continue_training and only_run_validation), "Cannot set --c and --val flag at the same time. Dummy."
138126

139-
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
140-
nnunet_trainer.on_train_start() # Added to Initialize Trainer
141-
if torch.cuda.is_available():
142-
cudnn.deterministic = False
143-
cudnn.benchmark = True
127+
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation)
128+
nnunet_trainer.on_train_start() # Added to Initialize Trainer
129+
if torch.cuda.is_available():
130+
cudnn.deterministic = False
131+
cudnn.benchmark = True
144132

145-
if pretrained_model is not None:
146-
state_dict = torch.load(pretrained_model)
147-
if "network_weights" in state_dict:
148-
nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
149-
return nnunet_trainer
133+
if pretrained_model is not None:
134+
state_dict = torch.load(pretrained_model)
135+
if "network_weights" in state_dict:
136+
nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
137+
return nnunet_trainer
150138

151139

152140
class ModelnnUNetWrapper(torch.nn.Module):
@@ -176,12 +164,11 @@ class ModelnnUNetWrapper(torch.nn.Module):
176164
restoring network architecture, and setting up the predictor for inference.
177165
"""
178166

179-
def __init__(self, predictor, model_folder, model_name="model.pt"):
167+
def __init__(self, predictor: object, model_folder: str, model_name: str = "model.pt"):
180168
super().__init__()
181169
self.predictor = predictor
182170

183171
model_training_output_dir = model_folder
184-
use_folds = ["0"]
185172

186173
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
187174

@@ -190,31 +177,26 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
190177
plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
191178
plans_manager = PlansManager(plans)
192179

193-
if isinstance(use_folds, str):
194-
use_folds = [use_folds]
195-
196180
parameters = []
197-
for i, f in enumerate(use_folds):
198-
f = str(f) if f != "all" else f
199-
checkpoint = torch.load(
200-
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
181+
182+
checkpoint = torch.load(
183+
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
184+
)
185+
trainer_name = checkpoint["trainer_name"]
186+
configuration_name = checkpoint["init_args"]["configuration"]
187+
inference_allowed_mirroring_axes = (
188+
checkpoint["inference_allowed_mirroring_axes"]
189+
if "inference_allowed_mirroring_axes" in checkpoint.keys()
190+
else None
191+
)
192+
if Path(model_training_output_dir).joinpath(model_name).is_file():
193+
monai_checkpoint = torch.load(
194+
join(model_training_output_dir, model_name), map_location=torch.device("cpu")
201195
)
202-
if i == 0:
203-
trainer_name = checkpoint["trainer_name"]
204-
configuration_name = checkpoint["init_args"]["configuration"]
205-
inference_allowed_mirroring_axes = (
206-
checkpoint["inference_allowed_mirroring_axes"]
207-
if "inference_allowed_mirroring_axes" in checkpoint.keys()
208-
else None
209-
)
210-
if Path(model_training_output_dir).joinpath(f"fold_{f}", model_name).is_file():
211-
monai_checkpoint = torch.load(
212-
join(model_training_output_dir, model_name), map_location=torch.device("cpu")
213-
)
214-
if "network_weights" in monai_checkpoint.keys():
215-
parameters.append(monai_checkpoint["network_weights"])
216-
else:
217-
parameters.append(monai_checkpoint)
196+
if "network_weights" in monai_checkpoint.keys():
197+
parameters.append(monai_checkpoint["network_weights"])
198+
else:
199+
parameters.append(monai_checkpoint)
218200

219201
configuration_manager = plans_manager.get_configuration(configuration_name)
220202
# restore network
@@ -251,10 +233,8 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
251233
if (
252234
("nnUNet_compile" in os.environ.keys())
253235
and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t"))
254-
# and not isinstance(predictor.network, OptimizedModule)
255236
):
256237
print("Using torch.compile")
257-
# predictor.network = torch.compile(self.network)
258238
# End Block
259239
self.network_weights = self.predictor.network
260240

@@ -281,21 +261,12 @@ def forward(self, x: MetaTensor) -> MetaTensor:
281261
- The predictions are converted to torch tensors, with added batch and channel dimensions.
282262
- The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
283263
"""
284-
# if isinstance(x, tuple): # if batch is decollated (list of tensors)
285-
# properties_or_list_of_properties = []
286-
# image_or_list_of_images = []
287-
288-
# for img in x:
289-
# if isinstance(img, MetaTensor):
290-
# properties_or_list_of_properties.append({"spacing": img.meta['pixdim'][0][1:4].numpy().tolist()})
291-
# image_or_list_of_images.append(img.cpu().numpy()[0,:])
292-
# else:
293-
# raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
294-
295-
# else: # if batch is collated
296264
if isinstance(x, MetaTensor):
297265
if "pixdim" in x.meta:
298266
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
267+
elif "affine" in x.meta:
268+
spacing = [abs(x.meta['affine'][0][0].item()), abs(x.meta['affine'][1][1].item()), abs(x.meta['affine'][2][2].item())]
269+
properties_or_list_of_properties = {"spacing": spacing}
299270
else:
300271
properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
301272
else:
@@ -320,13 +291,10 @@ def forward(self, x: MetaTensor) -> MetaTensor:
320291
out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
321292
out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension
322293

323-
# if type(x) is tuple:
324-
# return MetaTensor(out_tensor, meta=x[0].meta)
325-
# else:
326294
return MetaTensor(out_tensor, meta=x.meta)
327295

328296

329-
def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
297+
def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt") -> ModelnnUNetWrapper:
330298
"""
331299
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
332300
The model folder should contain the following files, created during training:
@@ -360,7 +328,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
360328
361329
Returns
362330
-------
363-
nnUNetMONAIModelWrapper
331+
ModelnnUNetWrapper
364332
A wrapper object that contains the nnUNetPredictor and the loaded model.
365333
"""
366334

@@ -380,7 +348,9 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
380348
return wrapper
381349

382350

383-
def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
351+
def convert_nnunet_to_monai_bundle(
352+
nnunet_config: dict, bundle_root_folder: str, fold: int = 0
353+
) -> None:
384354
"""
385355
Convert nnUNet model checkpoints and configuration to MONAI bundle format.
386356
@@ -450,7 +420,13 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
450420
)
451421

452422

453-
def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model_ckpt=None, model_key_in_ckpt="model"):
423+
def get_network_from_nnunet_plans(
424+
plans_file: str,
425+
dataset_file: str,
426+
configuration: str,
427+
model_ckpt: Optional[str] = None,
428+
model_key_in_ckpt: str = "model"
429+
) -> torch.nn.Module:
454430
"""
455431
Load and initialize a neural network based on nnUNet plans and configuration.
456432
@@ -462,7 +438,7 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model
462438
Path to the JSON file containing the dataset information.
463439
configuration : str
464440
The configuration name to be used from the plans.
465-
model_ckpt : str, optional
441+
model_ckpt : Optional[str], optional
466442
Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).
467443
model_key_in_ckpt : str, optional
468444
The key in the checkpoint file that contains the model state dictionary (default is "model").
@@ -505,7 +481,11 @@ def get_network_from_nnunet_plans(plans_file, dataset_file, configuration, model
505481
return network
506482

507483

508-
def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0):
484+
def convert_monai_bundle_to_nnunet(
485+
nnunet_config: dict,
486+
bundle_root_folder: str,
487+
fold: int = 0
488+
) -> None:
509489
"""
510490
Convert a MONAI bundle to nnU-Net format.
511491
@@ -527,8 +507,8 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0):
527507
"""
528508
from odict import odict
529509

530-
nnunet_trainer = "nnUNetTrainer"
531-
nnunet_plans = "nnUNetPlans"
510+
nnunet_trainer: str = "nnUNetTrainer"
511+
nnunet_plans: str = "nnUNetPlans"
532512

533513
if "nnunet_trainer" in nnunet_config:
534514
nnunet_trainer = nnunet_config["nnunet_trainer"]
@@ -539,8 +519,13 @@ def convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder, fold=0):
539519
from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
540520
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
541521

542-
def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True):
543-
522+
def subfiles(
523+
folder: str,
524+
join: bool = True,
525+
prefix: Optional[str] = None,
526+
suffix: Optional[str] = None,
527+
sort: bool = True
528+
) -> list[str]:
544529
if join:
545530
l = os.path.join # noqa: E741
546531
else:
@@ -556,42 +541,42 @@ def subfiles(folder, join: bool = True, prefix: str = None, suffix: str = None,
556541
res.sort()
557542
return res
558543

559-
nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath(
544+
nnunet_model_folder: Path = Path(os.environ["nnUNet_results"]).joinpath(
560545
maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"]),
561546
f"{nnunet_trainer}__{nnunet_plans}__3d_fullres",
562547
)
563548

564-
nnunet_preprocess_model_folder = Path(os.environ["nnUNet_preprocessed"]).joinpath(
549+
nnunet_preprocess_model_folder: Path = Path(os.environ["nnUNet_preprocessed"]).joinpath(
565550
maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
566551
)
567552

568553
Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
569554

570-
nnunet_checkpoint = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
571-
latest_checkpoints = subfiles(
555+
nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
556+
latest_checkpoints: list[str] = subfiles(
572557
Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True, join=False
573558
)
574-
epochs = []
559+
epochs: list[int] = []
575560
for latest_checkpoint in latest_checkpoints:
576561
epochs.append(int(latest_checkpoint[len("checkpoint_epoch=") : -len(".pt")]))
577562

578563
epochs.sort()
579-
final_epoch = epochs[-1]
580-
monai_last_checkpoint = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt")
564+
final_epoch: int = epochs[-1]
565+
monai_last_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt")
581566

582-
best_checkpoints = subfiles(
567+
best_checkpoints: list[str] = subfiles(
583568
Path(bundle_root_folder).joinpath("models", f"fold_{fold}"),
584569
prefix="checkpoint_key_metric",
585570
sort=True,
586571
join=False,
587572
)
588-
key_metrics = []
573+
key_metrics: list[str] = []
589574
for best_checkpoint in best_checkpoints:
590575
key_metrics.append(str(best_checkpoint[len("checkpoint_key_metric=") : -len(".pt")]))
591576

592577
key_metrics.sort()
593-
best_key_metric = key_metrics[-1]
594-
monai_best_checkpoint = torch.load(
578+
best_key_metric: str = key_metrics[-1]
579+
monai_best_checkpoint: dict = torch.load(
595580
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
596581
)
597582

0 commit comments

Comments
 (0)