Skip to content

Add multi-configuration support: 2d / 3d_lowres / 3d_cascade_fullres (v1.2.3)#6

Merged
77even merged 1 commit into
mainfrom
feature/multi-config-support
May 12, 2026
Merged

Add multi-configuration support: 2d / 3d_lowres / 3d_cascade_fullres (v1.2.3)#6
77even merged 1 commit into
mainfrom
feature/multi-config-support

Conversation

@77even
Copy link
Copy Markdown
Owner

@77even 77even commented May 12, 2026

Summary

Previously distillation only worked for 3d_fullres. The student model and ONNX export paths hardcoded Conv3d / InstanceNorm3d and a 5D dummy input shape, so 2d, 3d_lowres, and 3d_cascade_fullres all crashed (issue raised in chat).

This PR makes all four standard nnUNetv2 configurations work end-to-end through the distillation + ONNX export path.

Changes

  • Trainer (nnUNetDistillationTrainer)
    • Picks conv_op / norm_op from len(patch_size): 2D configs use Conv2d / InstanceNorm2d, 3D configs use Conv3d / InstanceNorm3d.
    • Cascade (3d_cascade_fullres) input channels come from upstream determine_num_input_channels, which already adds num_classes from the previous stage — no extra plumbing needed.
  • Student models (LiteNNUNetStudent, LiteResEncStudent)
    • Default kernel_sizes / strides shape and norm_op class are picked from conv_op.
    • norm_op default is now None with auto-selection. Existing callers that pass it explicitly stay valid.
  • ONNX export (fast_nnunet_distillation_export_onnx, fast_nnunet_resenc_distillation_export_onnx)
    • Detects dim from patch_size, builds the right N+C+spatial dummy input.
    • dynamic_axes are emitted per-dim (3 spatial axes in 3D, 2 in 2D).
    • Output filename now includes the configuration (e.g. model_2d_fold0_2x.onnx) so 2d / 3d_lowres / cascade exports don't overwrite each other.
  • CLI help text on --configuration now lists all four supported values for the four entry-point scripts.
  • Docs (docs/Distillation.md): configurations table, example commands per configuration, cascade workflow note.
  • Version bump: 1.2.2 -> 1.2.3.

Test plan

CPU smoke tests (local-only, not committed) — 7/7 pass on macOS:

  • 3D LiteNNUNetStudent forward / backward
  • 2D LiteNNUNetStudent forward / backward (verifies InstanceNorm2d was chosen)
  • 3D LiteResEncStudent forward / backward
  • 2D LiteResEncStudent forward / backward
  • 3D cascade-shaped input (extra input channels) forward / backward
  • distillation_loss_fn KL loss + backward
  • 2D ONNX export round-trip (max_diff = 2.6e-06)

GPU validation still needed before tagging the release:

  • A few distillation training steps on a real dataset for -c 2d, -c 3d_lowres, and -c 3d_cascade_fullres (cascade requires upstream nnUNetv2_train ... 3d_lowres + nnUNetv2_predict --save_probabilities to be run first so predicted_next_stage/ exists).
  • ONNX export end-to-end against a real 2D checkpoint, verifying the file name and tensor shape match expectations.

Follow-ups

Previously distillation only worked for 3d_fullres. The student model and
ONNX export paths hardcoded Conv3d/InstanceNorm3d and a 5D dummy input shape,
so any other configuration crashed.

Changes:
- nnUNetDistillationTrainer: pick conv_op/norm_op from len(patch_size); 2D
  configs now use Conv2d/InstanceNorm2d. Cascade input channels are handled
  by upstream determine_num_input_channels (already adds num_classes from
  the previous stage), so 3d_cascade_fullres works without extra plumbing.
- LiteNNUNetStudent / LiteResEncStudent: default kernel/stride/norm pick
  the right shape and op class from conv_op. norm_op default is now None
  with auto-selection; existing callers that pass it explicitly stay valid.
- ONNX export scripts: detect dim from patch_size, build N+C+spatial dummy
  input, set dynamic_axes per-dim, and include the configuration in the
  generated filename so 2d/3d_lowres/cascade don't overwrite each other.
- Help text on --configuration now lists all four supported values.
- docs/Distillation.md: documents the supported configurations table and
  adds example commands plus a cascade workflow note.
- Bump version to 1.2.3.

CPU smoke tests cover 2D and 3D forward/backward for both student types,
cascade-shaped (extra input channels) forward, distillation loss, and a
2D ONNX export round-trip (max_diff = 2.6e-06). Full GPU end-to-end
training and ONNX export against real checkpoints still need to be run
on the GPU box.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@77even 77even added the enhancement New feature or request label May 12, 2026
@77even 77even merged commit 3b5aa8e into main May 12, 2026
@77even 77even deleted the feature/multi-config-support branch May 12, 2026 16:26
77even added a commit that referenced this pull request May 12, 2026
Add multi-configuration support: 2d / 3d_lowres / 3d_cascade_fullres (v1.2.3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant