Skip to content

Commit ca851cd

Browse files
Add modality_list parameter to nnUNetExecutor and related functions
1 parent 1a30a0b commit ca851cd

3 files changed

Lines changed: 15 additions & 1 deletion

File tree

monai/nvflare/nnunet_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class nnUNetExecutor(Executor):
7777
Extra configurations for training.
7878
exclude_vars : list, optional
7979
List of variables to exclude.
80+
modality_list : list, optional
81+
List of modalities.
8082
8183
Methods
8284
-------
@@ -119,6 +121,7 @@ def __init__(
119121
tracking_uri=None,
120122
mlflow_token=None,
121123
bundle_root=None,
124+
modality_list=None,
122125
train_extra_configs=None,
123126
exclude_vars=None,
124127
):
@@ -143,6 +146,7 @@ def __init__(
143146
self.prepare_bundle_name = prepare_bundle_name
144147
self.bundle_root = bundle_root
145148
self.train_extra_configs = train_extra_configs
149+
self.modality_list = modality_list
146150

147151
def handle_event(self, event_type: str, fl_ctx: FLContext):
148152
if event_type == EventType.START_RUN:
@@ -211,6 +215,7 @@ def prepare_dataset(self) -> Shareable:
211215
mlflow_token=self.mlflow_token,
212216
subfolder_suffix=self.subfolder_suffix,
213217
trainer_class_name=nnunet_trainer_name,
218+
modality_list=self.modality_list,
214219
)
215220

216221
outgoing_dxo = DXO(data_kind=DataKind.COLLECTION, data=data_list, meta={})

monai/nvflare/nvflare_generate_job_configs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def prepare_config(clients, experiment, root_dir, script_dir, nvflare_exec):
135135
],
136136
}
137137

138+
if "modality_list" in experiment["modality_list"]:
139+
client["executors"][0]["executor"]["args"]["modality_list"] = experiment["modality_list"]
140+
138141
if "subfolder_suffix" in clients[client_id]:
139142
client["executors"][0]["executor"]["args"]["subfolder_suffix"] = clients[client_id]["subfolder_suffix"]
140143
if "mlflow_token" in experiment:

monai/nvflare/nvflare_nnunet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def prepare_data_folder(
305305
experiment_name,
306306
client_name,
307307
dataset_format,
308+
modality_list = None,
308309
tracking_uri=None,
309310
mlflow_token=None,
310311
subfolder_suffix=None,
@@ -332,6 +333,8 @@ def prepare_data_folder(
332333
Format of the dataset. Supported formats are "subfolders", "decathlon", and "nnunet".
333334
tracking_uri : str, optional
334335
URI for MLflow tracking server.
336+
modality_list : list, optional
337+
List of modalities. Default is None.
335338
mlflow_token : str, optional
336339
Token for MLflow authentication.
337340
subfolder_suffix : str, optional
@@ -438,9 +441,12 @@ def prepare_data_folder(
438441

439442
os.makedirs(nnunet_root_dir, exist_ok=True)
440443

444+
if modality_list is None:
445+
modality_list = [k for k in modality_dict.keys() if k != "label"]
446+
441447
data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml")
442448
data_src = {
443-
"modality": [k for k in modality_dict.keys() if k != "label"],
449+
"modality": modality_list,
444450
"dataset_name_or_id": dataset_name_or_id,
445451
"datalist": str(datalist_file),
446452
"dataroot": str(data_dir),

0 commit comments

Comments
 (0)