Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 70 additions & 48 deletions distillation/fast_nnunet_distillation_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,28 @@ def export_to_onnx(dataset_id,
device=device
)

# Initialize network architecture (full trainer initialization not needed)
# Initialize network architecture (full trainer initialization not needed).
# determine_num_input_channels handles cascade (previous stage adds extra channels).
num_input_channels = determine_num_input_channels(trainer.plans_manager, trainer.configuration_manager, dataset_json)
num_output_channels = trainer.label_manager.num_segmentation_heads

# Pick conv/norm ops based on patch_size dim so 2d configurations work
dim = len(trainer.configuration_manager.patch_size)
if dim == 2:
conv_op_cls = torch.nn.Conv2d
norm_op_cls = torch.nn.InstanceNorm2d
elif dim == 3:
conv_op_cls = torch.nn.Conv3d
norm_op_cls = torch.nn.InstanceNorm3d
else:
raise ValueError(f"Unsupported patch_size dimensionality: {dim} (expected 2 or 3)")

# if nnunet_style is True, force using single channel input
if nnunet_style:
num_input_channels = 1
print(f"Using single channel fixed size mode, force using 1 input channel")
print(f"Number of input channels: {num_input_channels}, Number of output channels: {num_output_channels}")

print(f"Configuration: {configuration} ({dim}D). Input channels: {num_input_channels}, output channels: {num_output_channels}")

# Get network architecture parameters
if configuration in plans['configurations']:
Expand All @@ -171,42 +183,50 @@ def export_to_onnx(dataset_id,
if verbose:
print(f"Found architecture info: {arch_info}")

# Per-dim defaults used only when plans omit kernel_sizes/strides (rare).
default_kernel = [3] * dim
default_stride_unit = [2] * dim
default_stride_first = [1] * dim

# Extract parameters from architecture
if isinstance(arch_info, dict) and 'arch_kwargs' in arch_info:
arch_kwargs = arch_info['arch_kwargs']
n_stages = arch_kwargs.get('n_stages', 6)
features_per_stage = arch_kwargs.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages])
strides = arch_kwargs.get('strides', [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1))
kernel_sizes = arch_kwargs.get('kernel_sizes', [[3, 3, 3]] * n_stages)
strides = arch_kwargs.get('strides', [default_stride_first] + [default_stride_unit] * (n_stages - 1))
kernel_sizes = arch_kwargs.get('kernel_sizes', [default_kernel] * n_stages)

if verbose:
print(f"Extracted from architecture: n_stages={n_stages}, features={features_per_stage}")
else:
# Fall back to defaults
n_stages = 6
features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages]
strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1)
kernel_sizes = [[3, 3, 3]] * n_stages
strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1)
kernel_sizes = [default_kernel] * n_stages
else:
default_kernel = [3] * dim
default_stride_unit = [2] * dim
default_stride_first = [1] * dim
# Get number of stages and features - handle different plans formats
if 'pool_op_kernel_sizes' in config:
n_stages = len(config['pool_op_kernel_sizes'])
features_per_stage = config.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages])
strides = config['pool_op_kernel_sizes']
kernel_sizes = config.get('conv_kernel_sizes', [[3, 3, 3]] * n_stages)
kernel_sizes = config.get('conv_kernel_sizes', [default_kernel] * n_stages)
elif 'architecture_kwargs' in config and 'arch_kwargs' in config['architecture_kwargs']:
arch_kwargs = config['architecture_kwargs']['arch_kwargs']
n_stages = arch_kwargs.get('n_stages', 6)
features_per_stage = arch_kwargs.get('features_per_stage', [32, 64, 128, 256, 320, 320][:n_stages])
strides = arch_kwargs.get('strides', [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1))
kernel_sizes = arch_kwargs.get('kernel_sizes', [[3, 3, 3]] * n_stages)
strides = arch_kwargs.get('strides', [default_stride_first] + [default_stride_unit] * (n_stages - 1))
kernel_sizes = arch_kwargs.get('kernel_sizes', [default_kernel] * n_stages)
else:
# If configuration not found, use default values
print("Warning: Cannot find complete network architecture configuration in plans, using defaults")
n_stages = 6
features_per_stage = [32, 64, 128, 256, 320, 320][:n_stages]
strides = [[1, 1, 1]] + [[2, 2, 2]] * (n_stages - 1)
kernel_sizes = [[3, 3, 3]] * n_stages
strides = [default_stride_first] + [default_stride_unit] * (n_stages - 1)
kernel_sizes = [default_kernel] * n_stages

# Apply feature reduction factor
lite_features_per_stage = [max(f // feature_reduction_factor, 8) for f in features_per_stage]
Expand Down Expand Up @@ -251,13 +271,13 @@ def export_to_onnx(dataset_id,
num_classes=num_output_channels,
n_stages=n_stages,
features_per_stage=lite_features_per_stage,
conv_op=torch.nn.Conv3d,
conv_op=conv_op_cls,
kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes],
strides=[tuple(p) if isinstance(p, list) else p for p in strides],
n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage
n_conv_per_stage_decoder=[2] * (n_stages - 1),
conv_bias=True,
norm_op=torch.nn.InstanceNorm3d,
norm_op=norm_op_cls,
norm_op_kwargs={"eps": 1e-5, "affine": True},
nonlin=torch.nn.LeakyReLU,
nonlin_kwargs={"inplace": True},
Expand All @@ -276,13 +296,13 @@ def export_to_onnx(dataset_id,
num_classes=num_output_channels,
n_stages=n_stages,
features_per_stage=lite_features_per_stage,
conv_op=torch.nn.Conv3d,
conv_op=conv_op_cls,
kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes],
strides=[tuple(p) if isinstance(p, list) else p for p in strides],
n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage
n_conv_per_stage_decoder=[2] * (n_stages - 1),
conv_bias=True,
norm_op=torch.nn.InstanceNorm3d,
norm_op=norm_op_cls,
norm_op_kwargs={"eps": 1e-5, "affine": True},
nonlin=torch.nn.LeakyReLU,
nonlin_kwargs={"inplace": True},
Expand All @@ -296,13 +316,13 @@ def export_to_onnx(dataset_id,
num_classes=num_output_channels,
n_stages=n_stages,
features_per_stage=lite_features_per_stage,
conv_op=torch.nn.Conv3d,
conv_op=conv_op_cls,
kernel_sizes=[tuple(k) if isinstance(k, list) else k for k in kernel_sizes],
strides=[tuple(p) if isinstance(p, list) else p for p in strides],
n_conv_per_stage=[2] * n_stages, # Default 2 convolution layers per stage
n_conv_per_stage_decoder=[2] * (n_stages - 1),
conv_bias=True,
norm_op=torch.nn.InstanceNorm3d,
norm_op=norm_op_cls,
norm_op_kwargs={"eps": 1e-5, "affine": True},
nonlin=torch.nn.LeakyReLU,
nonlin_kwargs={"inplace": True},
Expand Down Expand Up @@ -384,58 +404,60 @@ def forward(self, x):
wrapped_model.eval()
model = wrapped_model # Replace model with wrapped version

# Prepare export path
# Prepare export path (include configuration so 2d/3d_lowres/cascade don't overwrite each other)
if output_path is None:
output_dir = join(model_folder_fold, "exported_models")
os.makedirs(output_dir, exist_ok=True)

# add special identifier for single channel fixed size model
da5_suffix = "_da5" if use_da5 else ""
if nnunet_style:
output_path = join(output_dir, f"model_fold{fold}_{feature_reduction_factor}x{da5_suffix}_nnunet_format.onnx")
output_path = join(output_dir, f"model_{configuration}_fold{fold}_{feature_reduction_factor}x{da5_suffix}_nnunet_format.onnx")
else:
output_path = join(output_dir, f"model_fold{fold}_{feature_reduction_factor}x{da5_suffix}.onnx")
# Create input sample - use typical 3D medical image shape or custom shape
output_path = join(output_dir, f"model_{configuration}_fold{fold}_{feature_reduction_factor}x{da5_suffix}.onnx")

# Create input sample - shape follows patch_size dimensionality
if input_shape is None:
# Default size estimation (assuming input patch_size is the same as during training)
if 'patch_size' in config:
patch_size = config['patch_size']
input_shape = (1, num_input_channels, *patch_size)
else:
# Get patch_size from trainer
try:
patch_size = trainer.configuration_manager.patch_size
input_shape = (1, num_input_channels, *patch_size)
except:
# Default shape
input_shape = (1, num_input_channels, 128, 128, 128)
except Exception:
# Default shape per dim
input_shape = (1, num_input_channels, 256, 256) if dim == 2 else (1, num_input_channels, 128, 128, 128)

# if nnunet_style is True, force using single channel fixed size input
# if nnunet_style is True, force single channel input with fixed spatial size matching dim
if nnunet_style:
# force using single channel and fixed size
if input_shape is None or len(input_shape) != 5:
# use default fixed size
input_shape = (1, 1, 128, 128, 128)
expected_len = 2 + dim # batch + channel + spatial
if input_shape is None or len(input_shape) != expected_len:
input_shape = (1, 1, 256, 256) if dim == 2 else (1, 1, 128, 128, 128)
else:
# keep original spatial dimensions, but set channel number to 1
input_shape = (input_shape[0], 1, input_shape[2], input_shape[3], input_shape[4])
spatial = tuple(input_shape[2:])
input_shape = (input_shape[0], 1, *spatial)

print(f"Using single channel fixed size mode, input shape: {input_shape}")

# Use randn instead of zeros for better InstanceNorm behavior
torch.manual_seed(42) # For reproducibility
dummy_input = torch.randn(input_shape, dtype=torch.float32).to(device)
print(f"🔄 Exporting to ONNX (input: {input_shape})...")
# Set dynamic axes for ONNX export

print(f"Exporting to ONNX (input: {input_shape}, {dim}D)...")

# Set dynamic axes for ONNX export — spatial axis count depends on dim
if dynamic_axes and not nnunet_style:
# Batch size and spatial dimensions are dynamic
dynamic_axes_dict = {
'input': {0: 'batch_size', 2: 'height', 3: 'width', 4: 'depth'},
'output': {0: 'batch_size', 2: 'height', 3: 'width', 4: 'depth'}
}
if dim == 2:
dynamic_axes_dict = {
'input': {0: 'batch_size', 2: 'height', 3: 'width'},
'output': {0: 'batch_size', 2: 'height', 3: 'width'},
}
else:
dynamic_axes_dict = {
'input': {0: 'batch_size', 2: 'depth', 3: 'height', 4: 'width'},
'output': {0: 'batch_size', 2: 'depth', 3: 'height', 4: 'width'},
}
elif dynamic_axes and nnunet_style:
# only batch size is dynamic, spatial dimensions are fixed
dynamic_axes_dict = {
Expand Down Expand Up @@ -561,7 +583,7 @@ def forward(self, x):
def main():
parser = argparse.ArgumentParser(description='Export Fast nnUNet distillation student model to ONNX format')
parser.add_argument('-d', '--dataset_id', type=str, required=True, help='Dataset ID (e.g., 776)')
parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration (default: 3d_fullres)')
parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)')
parser.add_argument('-f', '--fold', type=int, default=0, help='Model training fold number (default: 0)')
parser.add_argument('-r', '--reduction_factor', type=int, default=2, help='Feature reduction factor (default: 2)')
parser.add_argument('-cp', '--checkpoint', type=str, default='checkpoint_final.pth', help='Checkpoint filename (default: checkpoint_final.pth)')
Expand Down
2 changes: 1 addition & 1 deletion distillation/fast_nnunet_distillation_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def main():
# Create command line argument parser
parser = argparse.ArgumentParser(description='nnUNetv2 Knowledge Distillation Training')
parser.add_argument('-d', '--dataset_id', type=str, required=True, help='Dataset ID (e.g., 776)')
parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration (default: 3d_fullres)')
parser.add_argument('-c', '--configuration', type=str, default='3d_fullres', help='nnUNet configuration: 2d / 3d_lowres / 3d_fullres / 3d_cascade_fullres (default: 3d_fullres)')
parser.add_argument('-f', '--start_fold', type=int, default=0, help='Start fold number for training (default: 0)')
parser.add_argument('-t', '--teacher_model_folder', type=str, help='Path to teacher model folder (if not provided, will be auto-constructed)')
parser.add_argument('-tf', '--teacher_folds', type=int, nargs='+',
Expand Down
Loading