diff --git a/distillation/fast_nnunet_distillation_export_onnx.py b/distillation/fast_nnunet_distillation_export_onnx.py index 0faac96..519163e 100755 --- a/distillation/fast_nnunet_distillation_export_onnx.py +++ b/distillation/fast_nnunet_distillation_export_onnx.py @@ -146,16 +146,28 @@ def export_to_onnx(dataset_id, device=device ) - # Initialize network architecture (full trainer initialization not needed) + # Initialize network architecture (full trainer initialization not needed). + # determine_num_input_channels handles cascade (previous stage adds extra channels). num_input_channels = determine_num_input_channels(trainer.plans_manager, trainer.configuration_manager, dataset_json) num_output_channels = trainer.label_manager.num_segmentation_heads + # Pick conv/norm ops based on patch_size dim so 2d configurations work + dim = len(trainer.configuration_manager.patch_size) + if dim == 2: + conv_op_cls = torch.nn.Conv2d + norm_op_cls = torch.nn.InstanceNorm2d + elif dim == 3: + conv_op_cls = torch.nn.Conv3d + norm_op_cls = torch.nn.InstanceNorm3d + else: + raise ValueError(f"Unsupported patch_size dimensionality: {dim} (expected 2 or 3)") + # if nnunet_style is True, force using single channel input if nnunet_style: num_input_channels = 1 print(f"Using single channel fixed size mode, force using 1 input channel") - - print(f"Number of input channels: {num_input_channels}, Number of output channels: {num_output_channels}") + + print(f"Configuration: {configuration} ({dim}D). Input channels: {num_input_channels}, output channels: {num_output_channels}") # Get network architecture parameters if configuration in plans['configurations']: @@ -171,42 +183,50 @@ def export_to_onnx(dataset_id, if verbose: print(f"Found architecture info: {arch_info}") + # Per-dim defaults used only when plans omit kernel_sizes/strides (rare). 
+ default_kernel = [3] * dim + default_stride_unit = [2] * dim + default_stride_first = [1] * dim + # Extract parameters from architecture if isinstance(arch_info, dict) and 'arch_kwargs' in arch_info: arch_kwargs = arch_info['arch_kwargs'] n_stages = arch_kwargs.get('n_stages', 6) features_per_stage = arch_kwargs.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages]) - strides = arch_kwargs.get('strides', [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1)) - kernel_sizes = arch_kwargs.get('kernel_sizes', [[3, 3, 3]] * n_stages) - + strides = arch_kwargs.get('strides', [default_stride_first] + [default_stride_unit] * (n_stages - 1)) + kernel_sizes = arch_kwargs.get('kernel_sizes', [default_kernel] * n_stages) + if verbose: print(f"Extracted from architecture: n_stages={n_stages}, features={features_per_stage}") else: # Fall back to defaults n_stages = 6 features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages] - strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1) - kernel_sizes = [[3, 3, 3]] * n_stages + strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1) + kernel_sizes = [default_kernel] * n_stages else: + default_kernel = [3] * dim + default_stride_unit = [2] * dim + default_stride_first = [1] * dim # Get number of stages and features - handle different plans formats if 'pool_op_kernel_sizes' in config: n_stages = len(config['pool_op_kernel_sizes']) features_per_stage = config.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages]) strides = config['pool_op_kernel_sizes'] - kernel_sizes = config.get('conv_kernel_sizes', [[3, 3, 3]] * n_stages) + kernel_sizes = config.get('conv_kernel_sizes', [default_kernel] * n_stages) elif 'architecture_kwargs' in config and 'arch_kwargs' in config['architecture_kwargs']: arch_kwargs = config['architecture_kwargs']['arch_kwargs'] n_stages = arch_kwargs.get('n_stages', 6) features_per_stage = arch_kwargs.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages]) - strides = 
arch_kwargs.get('strides', [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1)) - kernel_sizes = arch_kwargs.get('kernel_sizes', [[3, 3, 3]] * n_stages) + strides = arch_kwargs.get('strides', [default_stride_first] + [default_stride_unit] * (n_stages - 1)) + kernel_sizes = arch_kwargs.get('kernel_sizes', [default_kernel] * n_stages) else: # If configuration not found, use default values print("Warning: Cannot find complete network architecture configuration in plans, using defaults") n_stages = 6 features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages] - strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1) - kernel_sizes = [[3, 3, 3]] * n_stages + strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1) + kernel_sizes = [default_kernel] * n_stages # Apply feature reduction factor lite_features_per_stage = [max(f // feature_reduction_factor, 8) for f in features_per_stage] @@ -251,13 +271,13 @@ def export_to_onnx(dataset_id, num_classes=num_output_channels, n_stages=n_stages, features_per_stage=lite_features_per_stage, - conv_op=torch.nn.Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes], strides=[tuple(p) if isinstance(p, list) else p for p in strides], n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage n_conv_per_stage_decoder=[2] * (n_stages - 1), conv_bias=True, - norm_op=torch.nn.InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs={"eps": 1e-5, "affine": True}, nonlin=torch.nn.LeakyReLU, nonlin_kwargs={"inplace": True}, @@ -276,13 +296,13 @@ def export_to_onnx(dataset_id, num_classes=num_output_channels, n_stages=n_stages, features_per_stage=lite_features_per_stage, - conv_op=torch.nn.Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes], strides=[tuple(p) if isinstance(p, list) else p for p in strides], n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage n_conv_per_stage_decoder=[2] * (n_stages 
- 1), conv_bias=True, - norm_op=torch.nn.InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs={"eps": 1e-5, "affine": True}, nonlin=torch.nn.LeakyReLU, nonlin_kwargs={"inplace": True}, @@ -296,13 +316,13 @@ def export_to_onnx(dataset_id, num_classes=num_output_channels, n_stages=n_stages, features_per_stage=lite_features_per_stage, - conv_op=torch.nn.Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes], strides=[tuple(p) if isinstance(p, list) else p for p in strides], n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage n_conv_per_stage_decoder=[2] * (n_stages - 1), conv_bias=True, - norm_op=torch.nn.InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs={"eps": 1e-5, "affine": True}, nonlin=torch.nn.LeakyReLU, nonlin_kwargs={"inplace": True}, @@ -384,7 +404,7 @@ def forward(self, x): wrapped_model.eval() model = wrapped_model # Replace model with wrapped version - # Prepare export path + # Prepare export path (include configuration so 2d/3d_lowres/cascade don't overwrite each other) if output_path is None: output_dir = join(model_folder_fold, "exported_models") os.makedirs(output_dir, exist_ok=True) @@ -392,50 +412,52 @@ def forward(self, x): # add special identifier for single channel fixed size model da5_suffix = "_da5" if use_da5 else "" if nnunet_style: - output_path = join(output_dir, f"model_fold{fold}_{feature_reduction_factor}x{da5_suffix}_nnunet_format.onnx") + output_path = join(output_dir, f"model_{configuration}_fold{fold}_{feature_reduction_factor}x{da5_suffix}_nnunet_format.onnx") else: - output_path = join(output_dir, f"model_fold{fold}_{feature_reduction_factor}x{da5_suffix}.onnx") - - # Create input sample - use typical 3D medical image shape or custom shape + output_path = join(output_dir, f"model_{configuration}_fold{fold}_{feature_reduction_factor}x{da5_suffix}.onnx") + + # Create input sample - shape follows patch_size dimensionality if input_shape is None: - # Default 
size estimation (assuming input patch_size is the same as during training) if 'patch_size' in config: patch_size = config['patch_size'] input_shape = (1, num_input_channels, *patch_size) else: - # Get patch_size from trainer try: patch_size = trainer.configuration_manager.patch_size input_shape = (1, num_input_channels, *patch_size) - except: - # Default shape - input_shape = (1, num_input_channels, 128, 128, 128) + except Exception: + # Default shape per dim + input_shape = (1, num_input_channels, 256, 256) if dim == 2 else (1, num_input_channels, 128, 128, 128) - # if nnunet_style is True, force using single channel fixed size input + # if nnunet_style is True, force single channel input with fixed spatial size matching dim if nnunet_style: - # force using single channel and fixed size - if input_shape is None or len(input_shape) != 5: - # use default fixed size - input_shape = (1, 1, 128, 128, 128) + expected_len = 2 + dim # batch + channel + spatial + if input_shape is None or len(input_shape) != expected_len: + input_shape = (1, 1, 256, 256) if dim == 2 else (1, 1, 128, 128, 128) else: - # keep original spatial dimensions, but set channel number to 1 - input_shape = (input_shape[0], 1, input_shape[2], input_shape[3], input_shape[4]) - + spatial = tuple(input_shape[2:]) + input_shape = (input_shape[0], 1, *spatial) + print(f"Using single channel fixed size mode, input shape: {input_shape}") - + # Use randn instead of zeros for better InstanceNorm behavior torch.manual_seed(42) # For reproducibility dummy_input = torch.randn(input_shape, dtype=torch.float32).to(device) - - print(f"🔄 Exporting to ONNX (input: {input_shape})...") - - # Set dynamic axes for ONNX export + + print(f"Exporting to ONNX (input: {input_shape}, {dim}D)...") + + # Set dynamic axes for ONNX export — spatial axis count depends on dim if dynamic_axes and not nnunet_style: - # Batch size and spatial dimensions are dynamic - dynamic_axes_dict = { - 'input': {0: 'batch_size', 2: 'height', 3: 
'width', 4: 'depth'}, - 'output': {0: 'batch_size', 2: 'height', 3: 'width', 4: 'depth'} - } + if dim == 2: + dynamic_axes_dict = { + 'input': {0: 'batch_size', 2: 'height', 3: 'width'}, + 'output': {0: 'batch_size', 2: 'height', 3: 'width'}, + } + else: + dynamic_axes_dict = { + 'input': {0: 'batch_size', 2: 'depth', 3: 'height', 4: 'width'}, + 'output': {0: 'batch_size', 2: 'depth', 3: 'height', 4: 'width'}, + } elif dynamic_axes and nnunet_style: # only batch size is dynamic, spatial dimensions are fixed dynamic_axes_dict = { @@ -561,7 +583,7 @@ def forward(self, x): def main(): parser = argparse.ArgumentParser(description='Export Fast nnUNet distillation student model to ONNX format') parser.add_argument('-d', '--dataset_id', type=str, required=True, help='Dataset ID (e.g., 776)') - parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration (default: 3d_fullres)') + parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)') parser.add_argument('-f', '--fold', type=int, default=0, help='Model training fold number (default: 0)') parser.add_argument('-r', '--reduction_factor', type=int, default=2, help='Feature reduction factor (default: 2)') parser.add_argument('-cp', '--checkpoint', type=str, default='checkpoint_final.pth', help='Checkpoint filename (default: checkpoint_final.pth)') diff --git a/distillation/fast_nnunet_distillation_train.py b/distillation/fast_nnunet_distillation_train.py index abb7fb1..0c2f43a 100755 --- a/distillation/fast_nnunet_distillation_train.py +++ b/distillation/fast_nnunet_distillation_train.py @@ -255,7 +255,7 @@ def main(): # Create command line argument parser parser = argparse.ArgumentParser(description='nnUNetv2 Knowledge Distillation Training') parser.add_argument('-d', '--dataset_id', type=str, required=True, help='Dataset ID (e.g., 776)') - 
parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration (default: 3d_fullres)') + parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)') parser.add_argument('-f', '--start_fold', type=int, default=0, help='Start fold number for training (default: 0)') parser.add_argument('-t', '--teacher_model_folder', type=str, help='Path to teacher model folder (if not provided, will be auto-constructed)') parser.add_argument('-tf', '--teacher_folds', type=int, nargs='+', diff --git a/distillation/fast_nnunet_resenc_distillation_export_onnx.py b/distillation/fast_nnunet_resenc_distillation_export_onnx.py index 6f448cd..e4e4282 100644 --- a/distillation/fast_nnunet_resenc_distillation_export_onnx.py +++ b/distillation/fast_nnunet_resenc_distillation_export_onnx.py @@ -246,17 +246,35 @@ def load_model_from_checkpoint(checkpoint_path, plans, dataset_json, configurati trainer.student_plans_identifier = student_plans_identifier # Get network architecture parameters + # determine_num_input_channels handles cascade (extra channels from previous stage) num_input_channels = determine_num_input_channels(trainer.plans_manager, trainer.configuration_manager, dataset_json) num_output_channels = trainer.label_manager.num_segmentation_heads - + + # Pick conv/norm ops based on patch_size dim so 2d configurations work + dim = len(trainer.configuration_manager.patch_size) + if dim == 2: + conv_op_cls = torch.nn.Conv2d + norm_op_cls = torch.nn.InstanceNorm2d + elif dim == 3: + conv_op_cls = torch.nn.Conv3d + norm_op_cls = torch.nn.InstanceNorm3d + else: + raise ValueError(f"Unsupported patch_size dimensionality: {dim} (expected 2 or 3)") + if verbose: + print(f"Spatial dim: {dim}D (configuration={configuration})") print(f"Number of input channels: {num_input_channels}") print(f"Number of output channels: {num_output_channels}") - 
+ # Get architecture configuration if configuration in plans['configurations']: config = plans['configurations'][configuration] - + + # Per-dim defaults used only when plans omit kernel_sizes/strides + default_kernel = [3] * dim + default_stride_unit = [2] * dim + default_stride_first = [1] * dim + # Extract network parameters if 'architecture' in config: arch_info = config['architecture'] @@ -264,15 +282,15 @@ def load_model_from_checkpoint(checkpoint_path, plans, dataset_json, configurati arch_kwargs = arch_info['arch_kwargs'] n_stages = arch_kwargs.get('n_stages', 6) features_per_stage = arch_kwargs.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages]) - strides = arch_kwargs.get('strides', [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1)) - kernel_sizes = arch_kwargs.get('kernel_sizes', [[3, 3, 3]] * n_stages) + strides = arch_kwargs.get('strides', [default_stride_first] + [default_stride_unit] * (n_stages - 1)) + kernel_sizes = arch_kwargs.get('kernel_sizes', [default_kernel] * n_stages) n_blocks_per_stage = arch_kwargs.get('n_blocks_per_stage', [1, 3, 4, 6, 6, 6][:n_stages]) else: # Fall back to defaults n_stages = 6 features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages] - strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1) - kernel_sizes = [[3, 3, 3]] * n_stages + strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1) + kernel_sizes = [default_kernel] * n_stages n_blocks_per_stage = [1, 3, 4, 6, 6, 6][:n_stages] else: # Get from other configuration keys @@ -280,15 +298,15 @@ def load_model_from_checkpoint(checkpoint_path, plans, dataset_json, configurati n_stages = len(config['pool_op_kernel_sizes']) features_per_stage = config.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages]) strides = config['pool_op_kernel_sizes'] - kernel_sizes = config.get('conv_kernel_sizes', [[3, 3, 3]] * n_stages) + kernel_sizes = config.get('conv_kernel_sizes', [default_kernel] * n_stages) n_blocks_per_stage = [1, 3, 4, 6, 6, 
6][:n_stages] else: # Use default values print("Warning: Cannot find complete network architecture configuration in plans, using defaults") n_stages = 6 features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages] - strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1) - kernel_sizes = [[3, 3, 3]] * n_stages + strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1) + kernel_sizes = [default_kernel] * n_stages n_blocks_per_stage = [1, 3, 4, 6, 6, 6][:n_stages] # Apply feature reduction factor @@ -346,13 +364,13 @@ def load_model_from_checkpoint(checkpoint_path, plans, dataset_json, configurati num_classes=num_output_channels, n_stages=n_stages, features_per_stage=lite_features_per_stage, - conv_op=torch.nn.Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes], strides=[tuple(p) if isinstance(p, list) else p for p in strides], n_blocks_per_stage=lite_n_blocks_per_stage, n_conv_per_stage_decoder=[1] * (n_stages - 1), conv_bias=True, - norm_op=torch.nn.InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs={"eps": 1e-5, "affine": True}, nonlin=torch.nn.LeakyReLU, nonlin_kwargs={"inplace": True}, @@ -365,13 +383,13 @@ def load_model_from_checkpoint(checkpoint_path, plans, dataset_json, configurati num_classes=num_output_channels, n_stages=n_stages, features_per_stage=lite_features_per_stage, - conv_op=torch.nn.Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes], strides=[tuple(p) if isinstance(p, list) else p for p in strides], n_conv_per_stage=[2] * n_stages, n_conv_per_stage_decoder=[2] * (n_stages - 1), conv_bias=True, - norm_op=torch.nn.InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs={"eps": 1e-5, "affine": True}, nonlin=torch.nn.LeakyReLU, nonlin_kwargs={"inplace": True}, @@ -556,20 +574,22 @@ def export_resenc_distillation_to_onnx(dataset_id, maybe_mkdir_p(output_dir) - # Get input shape + # Get input shape — fall back to 
dim-appropriate default if patch_size missing if 'patch_size' in config: patch_size = config['patch_size'] + dim = len(patch_size) input_shape = (1, num_input_channels, *patch_size) else: - # Default shape + # No patch_size found; default to 3D, as before (the most common use case) + dim = 3 input_shape = (1, num_input_channels, 128, 128, 128) - - # Determine model type for filename + + # Determine model type for filename (include configuration so 2d/3d_lowres/cascade don't overwrite) is_resenc_student = 'ResEnc' in student_plans_identifier model_type = "resenc" if is_resenc_student else "unet" if use_da5: model_type += "_da5" - + # Create dummy input # IMPORTANT: Use randn (normal distribution) instead of zeros for better InstanceNorm behavior # InstanceNorm computes statistics from the input, so using realistic data distribution helps @@ -578,7 +598,7 @@ def export_resenc_distillation_to_onnx(dataset_id, # Fixed batch size dummy_input = torch.randn((batch_size, num_input_channels, *input_shape[2:]), dtype=torch.float32).to(device) dynamic_axes = None - onnx_filename = f"{model_type}_distillation_fold{fold}_batch{batch_size}_r{feature_reduction_factor}.onnx" + onnx_filename = f"{model_type}_distillation_{configuration}_fold{fold}_batch{batch_size}_r{feature_reduction_factor}.onnx" else: # Dynamic batch size dummy_input = torch.randn(input_shape, dtype=torch.float32).to(device) @@ -586,7 +606,7 @@ def export_resenc_distillation_to_onnx(dataset_id, 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'} } - onnx_filename = f"{model_type}_distillation_fold{fold}_dynamic_r{feature_reduction_factor}.onnx" + onnx_filename = f"{model_type}_distillation_{configuration}_fold{fold}_dynamic_r{feature_reduction_factor}.onnx" output_path = join(output_dir, onnx_filename) @@ -740,7 +760,7 @@ def main(): parser = argparse.ArgumentParser(description='Export Fast nnUNet ResEnc distillation model to ONNX format') parser.add_argument('-d', '--dataset_id', type=str, required=True, 
help='Dataset ID') parser.add_argument('-o', '--output_dir', type=str, help='Output directory') - parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='Configuration name') + parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)') parser.add_argument('-f', '--fold', type=int, default=0, help='Fold number') parser.add_argument('-b', '--batch_size', type=int, default=0, help='Batch size, 0 means dynamic') parser.add_argument('-cp', '--checkpoint', type=str, default='checkpoint_final.pth', help='Checkpoint filename') diff --git a/distillation/fast_nnunet_resenc_distillation_train.py b/distillation/fast_nnunet_resenc_distillation_train.py index c9da5d6..f329760 100644 --- a/distillation/fast_nnunet_resenc_distillation_train.py +++ b/distillation/fast_nnunet_resenc_distillation_train.py @@ -282,7 +282,7 @@ def main(): # Create command line argument parser parser = argparse.ArgumentParser(description='nnUNetv2 ResEnc Knowledge Distillation Training') parser.add_argument('-d', '--dataset_id', type=str, required=True, help='Dataset ID (e.g., 793)') - parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration (default: 3d_fullres)') + parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)') parser.add_argument('-f', '--fold', type=int, default=0, help='Fold number for training (default: 0)') parser.add_argument('-t', '--teacher_model_folder', type=str, help='ResEnc teacher model folder path (if not provided, will be auto-constructed)') parser.add_argument('-tf', '--teacher_folds', type=int, nargs='+', diff --git a/distillation/nnunetv2/training/nnUNetTrainer/variants/nnUNetDistillationTrainer.py 
b/distillation/nnunetv2/training/nnUNetTrainer/variants/nnUNetDistillationTrainer.py index 4825880..8e64c39 100644 --- a/distillation/nnunetv2/training/nnUNetTrainer/variants/nnUNetDistillationTrainer.py +++ b/distillation/nnunetv2/training/nnUNetTrainer/variants/nnUNetDistillationTrainer.py @@ -24,7 +24,32 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn import Conv3d, InstanceNorm3d, LeakyReLU, ConvTranspose3d +from torch.nn import ( + Conv2d, Conv3d, + InstanceNorm2d, InstanceNorm3d, + LeakyReLU, + ConvTranspose2d, ConvTranspose3d, +) + + +def _dim_from_conv_op(conv_op): + """Return spatial dim (2 or 3) corresponding to a torch conv module class.""" + if conv_op is Conv2d: + return 2 + if conv_op is Conv3d: + return 3 + raise ValueError( + f"Unsupported conv_op: {conv_op}; expected torch.nn.Conv2d or torch.nn.Conv3d" + ) + + +def _default_norm_op_for(conv_op): + """Pick the matching InstanceNorm class for the given conv_op.""" + if conv_op is Conv2d: + return InstanceNorm2d + if conv_op is Conv3d: + return InstanceNorm3d + raise ValueError(f"Unsupported conv_op: {conv_op}") from torch import GradScaler from collections import OrderedDict import numpy as np @@ -86,7 +111,7 @@ def __init__(self, n_conv_per_stage: list = None, n_conv_per_stage_decoder: list = None, conv_bias: bool = True, - norm_op: type = InstanceNorm3d, + norm_op: type = None, norm_op_kwargs: dict = None, dropout_op: type = None, dropout_op_kwargs: dict = None, @@ -95,7 +120,12 @@ def __init__(self, deep_supervision: bool = True ): super().__init__() - + + # Defaults that depend on the spatial dim (2D vs 3D) inferred from conv_op + dim = _dim_from_conv_op(conv_op) + if norm_op is None: + norm_op = _default_norm_op_for(conv_op) + # Parameter settings if norm_op_kwargs is None: norm_op_kwargs = {'eps': 1e-5, 'affine': True} @@ -104,13 +134,13 @@ def __init__(self, if features_per_stage is None: features_per_stage = [32, 64, 128, 256, 320, 320] if kernel_sizes is 
None: - kernel_sizes = [(3, 3, 3)] * n_stages + kernel_sizes = [(3,) * dim] * n_stages if n_conv_per_stage is None: n_conv_per_stage = [2] * n_stages if n_conv_per_stage_decoder is None: n_conv_per_stage_decoder = [2] * (n_stages - 1) if strides is None: - strides = [(1, 1, 1)] + [(2, 2, 2)] * (n_stages - 1) + strides = [(1,) * dim] + [(2,) * dim] * (n_stages - 1) # Check if parameter lengths match if not (len(features_per_stage) == n_stages and len(kernel_sizes) == n_stages and @@ -192,7 +222,7 @@ def __init__(self, n_blocks_per_stage: list = None, n_conv_per_stage_decoder: list = None, conv_bias: bool = True, - norm_op: type = InstanceNorm3d, + norm_op: type = None, norm_op_kwargs: dict = None, dropout_op: type = None, dropout_op_kwargs: dict = None, @@ -201,7 +231,12 @@ def __init__(self, deep_supervision: bool = True ): super().__init__() - + + # Defaults that depend on the spatial dim (2D vs 3D) inferred from conv_op + dim = _dim_from_conv_op(conv_op) + if norm_op is None: + norm_op = _default_norm_op_for(conv_op) + # Parameter settings if norm_op_kwargs is None: norm_op_kwargs = {'eps': 1e-5, 'affine': True} @@ -210,14 +245,14 @@ def __init__(self, if features_per_stage is None: features_per_stage = [32, 64, 128, 256, 320, 320] if isinstance(kernel_sizes, int): - kernel_sizes = [kernel_sizes] * n_stages + kernel_sizes = [(kernel_sizes,) * dim] * n_stages if n_blocks_per_stage is None: # Reduced from ResEnc's (1, 3, 4, 6, 6, 6) to lighter version n_blocks_per_stage = [1, 2, 2, 3, 3, 3][:n_stages] if n_conv_per_stage_decoder is None: n_conv_per_stage_decoder = [1] * (n_stages - 1) if strides is None: - strides = [(1, 1, 1)] + [(2, 2, 2)] * (n_stages - 1) + strides = [(1,) * dim] + [(2,) * dim] * (n_stages - 1) # Check if parameter lengths match if not (len(features_per_stage) == n_stages and len(kernel_sizes) == n_stages and @@ -619,12 +654,26 @@ def build_network_architecture(self, else: self.print_to_log_file("Building lightweight standard UNet student 
model...") - # If input/output channel numbers are not provided, obtain them from plans + # If input/output channel numbers are not provided, obtain them from plans. + # determine_num_input_channels already accounts for cascade (previous stage adds + # num_classes extra input channels), so 3d_cascade_fullres works without changes. if num_input_channels is None: num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json) if num_output_channels is None: num_output_channels = self.label_manager.num_segmentation_heads - + + # Pick conv/norm ops based on patch_size dimensionality so 2d configurations work + dim = len(self.configuration_manager.patch_size) + if dim == 2: + conv_op_cls = Conv2d + norm_op_cls = InstanceNorm2d + elif dim == 3: + conv_op_cls = Conv3d + norm_op_cls = InstanceNorm3d + else: + raise ValueError(f"Unsupported patch_size dimensionality: {dim} (expected 2 or 3)") + self.print_to_log_file(f"Spatial dimensionality: {dim}D (configuration={self.configuration_name})") + # Check if there is a new format architecture field, if not, derive from configuration if 'architecture' in self.configuration_manager.configuration and 'arch_kwargs' in self.configuration_manager.configuration['architecture']: # New format plans @@ -635,20 +684,19 @@ def build_network_architecture(self, self.print_to_log_file("Detected old version plans format, manually build network parameters") # Get necessary parameters from configuration - dim = len(self.configuration_manager.patch_size) n_stages = len(self.configuration_manager.pool_op_kernel_sizes) + 1 - + # Base feature number and per stage feature number unet_max_num_features = self.plans_manager.plans.get('unet_max_num_features', 320) base_num_features = self.configuration_manager.configuration.get('UNet_base_num_features', 32) - + # Calculate feature number for each stage - features_per_stage = [min(base_num_features * 2 ** i, unet_max_num_features) + features_per_stage 
= [min(base_num_features * 2 ** i, unet_max_num_features) for i in range(n_stages)] - - # Get other network parameters + + # Get other network parameters (kernel default shape follows dim) conv_kernel_sizes = self.configuration_manager.configuration.get( - 'conv_kernel_sizes', [[3,3,3]] * n_stages) + 'conv_kernel_sizes', [[3] * dim] * n_stages) # Build pool kernel size list, need to add a starting (1,1,1) pool_op_kernel_sizes = [(1,)*dim] @@ -716,13 +764,13 @@ def build_network_architecture(self, num_classes=num_output_channels, n_stages=plan_arch["n_stages"], features_per_stage=lite_features_per_stage, - conv_op=Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(ks) if not isinstance(ks[0], (list, tuple)) else tuple(ks[0]) for ks in plan_arch["kernel_sizes"]], strides=[tuple(st) for st in plan_arch["strides"]], n_blocks_per_stage=lite_n_blocks_per_stage, n_conv_per_stage_decoder=plan_arch["n_conv_per_stage_decoder"], conv_bias=plan_arch["conv_bias"], - norm_op=InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs=plan_arch["norm_op_kwargs"], nonlin=LeakyReLU, nonlin_kwargs=plan_arch["nonlin_kwargs"], @@ -735,13 +783,13 @@ def build_network_architecture(self, num_classes=num_output_channels, n_stages=plan_arch["n_stages"], features_per_stage=lite_features_per_stage, - conv_op=Conv3d, + conv_op=conv_op_cls, kernel_sizes=[tuple(ks) if not isinstance(ks[0], (list, tuple)) else tuple(ks[0]) for ks in plan_arch["kernel_sizes"]], strides=[tuple(st) for st in plan_arch["strides"]], n_conv_per_stage=plan_arch["n_conv_per_stage"], n_conv_per_stage_decoder=plan_arch["n_conv_per_stage_decoder"], conv_bias=plan_arch["conv_bias"], - norm_op=InstanceNorm3d, + norm_op=norm_op_cls, norm_op_kwargs=plan_arch["norm_op_kwargs"], nonlin=LeakyReLU, nonlin_kwargs=plan_arch["nonlin_kwargs"], diff --git a/distillation/setup.py b/distillation/setup.py index be46694..4df881d 100644 --- a/distillation/setup.py +++ b/distillation/setup.py @@ -2,7 +2,7 @@ setup( name="nnunetv2_distillation", - 
version="1.2.2", + version="1.2.3", packages=find_packages(), install_requires=[ "torch>=1.6.0", diff --git a/docs/Distillation.md b/docs/Distillation.md index 8952ce3..0f84b50 100644 --- a/docs/Distillation.md +++ b/docs/Distillation.md @@ -105,7 +105,35 @@ nnUNetv2_train DATASET_ID 3d_fullres 4 -p nnUNetResEncUNetMPlans/nnUNetResEncUNe ### 2. Knowledge Distillation Training -Use the trained teacher models for knowledge distillation: +Use the trained teacher models for knowledge distillation. + +#### Supported nnUNet Configurations + +All four standard nnUNetv2 configurations are supported via `-c / --configuration`: + +| Configuration | Description | Notes | +| --- | --- | --- | +| `2d` | 2D slice-based U-Net | Uses Conv2d / InstanceNorm2d automatically | +| `3d_lowres` | Low-resolution 3D U-Net | Independent stage, also used as cascade stage 1 | +| `3d_fullres` | Full-resolution 3D U-Net (default) | Most common choice | +| `3d_cascade_fullres` | High-res second stage of the cascade | Reads `predicted_next_stage/` from the corresponding `3d_lowres` model | + +The student network's spatial dimensionality (2D vs 3D) and the cascade input-channel count are detected automatically from the configuration's `patch_size` and the upstream plans — no extra flags are needed. + +**Cascade workflow:** train and predict the lowres stage with stock nnUNetv2 first so the cascade inputs exist, then run distillation on each stage: + +```bash +# (1) Train and predict the upstream lowres teacher with stock nnUNetv2, +# which populates predicted_next_stage/3d_cascade_fullres/. +nnUNetv2_train DATASET_ID 3d_lowres 0 +nnUNetv2_predict_from_modelfolder ... # see nnUNetv2 docs for cascade prep + +# (2) Distill the lowres stage (optional; you can also keep the upstream lowres model). +nnUNetv2_distillation_train -d DATASET_ID -c 3d_lowres -f 0 -a 0.3 -temp 3.0 -r 2 + +# (3) Distill the cascade fullres stage — extra input channels are picked up automatically. 
+nnUNetv2_distillation_train -d DATASET_ID -c 3d_cascade_fullres -f 0 -a 0.3 -temp 3.0 -r 2 +``` #### Standard Knowledge Distillation @@ -133,6 +161,11 @@ nnUNetv2_distillation_train -d DATASET_ID -f 0 -a 0.3 -temp 3.0 -r 2 --use_da5 # Combine DA5 with other options nnUNetv2_distillation_train -d DATASET_ID -f 0 -tf 0 1 2 3 4 -a 0.3 -temp 3.0 -r 2 --use_da5 -c_continue + +# 2D / 3d_lowres / 3d_cascade_fullres examples (default is 3d_fullres) +nnUNetv2_distillation_train -d DATASET_ID -c 2d -f 0 -a 0.3 -temp 3.0 -r 2 +nnUNetv2_distillation_train -d DATASET_ID -c 3d_lowres -f 0 -a 0.3 -temp 3.0 -r 2 +nnUNetv2_distillation_train -d DATASET_ID -c 3d_cascade_fullres -f 0 -a 0.3 -temp 3.0 -r 2 ``` #### ResEnc Knowledge Distillation (Enhanced Performance) @@ -220,6 +253,12 @@ nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -da5 -v # Export with simplified ONNX nnUNetv2_distillation_export_onnx -d DATASET_ID -f 0 -r 2 -sim + +# Export 2D / 3d_lowres / 3d_cascade_fullres models — the 2D export produces an ONNX model with N×C×H×W input; +# 3d_cascade_fullres automatically uses the input channel count that includes the extra lowres-stage channels. +nnUNetv2_distillation_export_onnx -d DATASET_ID -c 2d -f 0 -r 2 +nnUNetv2_distillation_export_onnx -d DATASET_ID -c 3d_lowres -f 0 -r 2 +nnUNetv2_distillation_export_onnx -d DATASET_ID -c 3d_cascade_fullres -f 0 -r 2 ``` #### ResEnc Distillation Model Export