Skip to content

Add Primus knowledge-distillation support (v1.2.4)#7

Merged
77even merged 1 commit into
mainfrom
feature/primus-distillation
May 12, 2026
Merged

Add Primus knowledge-distillation support (v1.2.4)#7
77even merged 1 commit into
mainfrom
feature/primus-distillation

Conversation

@77even
Copy link
Copy Markdown
Owner

@77even 77even commented May 12, 2026

Summary

Adds end-to-end distillation support for the upstream nnunetv2 Primus transformer family (sizes S/B/M/L). Same multi-teacher KL pipeline as the standard / ResEnc distillation entry points, with a Primus student instead of a CNN U-Net student.

What's new

New module: distillation/primus_distillation_trainer.py

  • reduce_primus_dims(embed_dim, depth, num_heads, factor) — shrinks Primus hyperparameters while preserving the rotary-positional-embedding constraint. The 3D RoPE in Primus needs head_dim % 6 == 0, so the helper holds head_dim constant at the teacher's value and shrinks num_heads/depth by the requested factor; embed_dim follows.
  • LitePrimusStudent — thin wrapper around dynamic_network_architectures.architectures.primus.Primus that accepts the reduced (embed_dim, depth, num_heads) and asserts patch_size % 8 == 0.
  • nnUNetDistillationPrimusTrainer / nnUNetDistillationPrimusTrainerDA5 — inherit the KL/multi-teacher/fold-rotation pipeline from nnUNetDistillationTrainer, override build_network_architecture to return a Primus student, disable deep supervision (Primus is single-resolution), and use AdamW + linear-warmup → polynomial schedule that mirrors upstream AbstractPrimus.

Entry-point scripts

  • nnUNetv2_primus_distillation_train (fast_nnunet_primus_distillation_train.py) — mirrors the resenc training script and adds -ts/--teacher_size {S,B,M,L} plus -w/--warmup_epochs. Default teacher folder auto-derived from -ts:
    {nnUNet_results}/{Dataset}/nnUNet_Primus_{S|B|M|L}_Trainer__nnUNetPlans__{configuration}/
  • nnUNetv2_primus_distillation_export_onnx (fast_nnunet_primus_distillation_export_onnx.py) — single-tensor output (no deep-supervision wrapping), 3D-only, opset 17, dynamic batch + spatial axes by default. Reconstructs the student architecture from -ts + -r so weights load.

Packaging

  • setup.py: bump 1.2.31.2.4, register the two new console_scripts (nnUNetv2_primus_distillation_train, nnUNetv2_primus_distillation_export_onnx) and the three new py_modules.

Docs

  • docs/Distillation.md: new Primus training section (teacher prep, training command examples for all sizes, multi-teacher ensemble, DA5) and a Primus ONNX export section.

Teacher → student size table (with default -r 2)

Teacher embed_dim depth heads Student (r=2)
S 396 12 6 198 / 6 / 3
B 792 12 12 396 / 6 / 6
M (default) 864 16 12 432 / 8 / 6
L 1056 24 16 528 / 12 / 8

head_dim is preserved in every case (66 for S/B/L, 72 for M).

Test plan

CPU smoke tests (local-only, not committed) — 11/11 pass on macOS:

  • 3D / 2D LiteNNUNetStudent forward / backward
  • 3D / 2D LiteResEncStudent forward / backward
  • 3D cascade-shaped input forward
  • distillation_loss_fn finite + differentiable
  • 2D ONNX export round-trip (max diff ~3e-06)
  • reduce_primus_dims preserves the divisibility invariants across all (S/B/M/L) × (r=1,2,4)
  • LitePrimusStudent (M, r=2) forward + backward; shape (1, 3, 32, 32, 32), ~18.6M params
  • LitePrimusStudent rejects patch_size not divisible by 8
  • Primus ONNX export + round-trip on CPU

Still needed on the GPU box before tagging:

  • Train an upstream Primus teacher (nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNet_Primus_M_Trainer) and a few distillation steps with nnUNetv2_primus_distillation_train -d DATASET_ID -ts M -r 2.
  • Same for at least one other teacher size (S or L) to exercise the size-routing logic.
  • One round of nnUNetv2_primus_distillation_train --use_da5 to exercise the DA5 path.
  • ONNX export end-to-end against a real Primus checkpoint, including --simplify.

Follow-up

After merge: tag v1.2.4 + release. The refactor branch (PR #4) will then rebase onto the new main so the Primus support ships in v1.3.0 too.

Adds support for distilling upstream nnunetv2 Primus transformer teachers (sizes
S/B/M/L) into smaller Primus students. Mirrors the existing standard/ResEnc
distillation entry points and reuses the multi-teacher KL pipeline.

Changes:
- distillation/primus_distillation_trainer.py: new module exposing
  reduce_primus_dims, LitePrimusStudent, nnUNetDistillationPrimusTrainer, and
  nnUNetDistillationPrimusTrainerDA5. The student keeps head_dim constant
  (Primus' 3D rotary positional embedding needs head_dim divisible by 6) and
  shrinks num_heads/depth by the reduction factor; embed_dim follows. The
  trainer reuses the parent KL pipeline, disables deep supervision (Primus
  is single-resolution), and swaps in AdamW + linear-warmup -> polynomial
  schedule that mirrors upstream AbstractPrimus.
- distillation/fast_nnunet_primus_distillation_train.py: new training entry
  point with -ts/--teacher_size {S,B,M,L} and the usual -tf / --use_da5 /
  -rotate_folds / etc. Teacher folder auto-derived from teacher_size.
- distillation/fast_nnunet_primus_distillation_export_onnx.py: new ONNX
  export entry point. Single-tensor output (no deep-supervision wrapping),
  3D-only, opset 17, dynamic batch+spatial axes by default.
- setup.py: bump 1.2.3 -> 1.2.4, register the two new console_scripts and
  the three new py_modules.
- docs/Distillation.md: new Primus sections covering teacher prep, training
  command examples, and ONNX export.

Tested locally on macOS CPU: 11/11 smoke tests pass, including LitePrimusStudent
forward+backward at multiple sizes and a Primus ONNX round-trip. Full GPU
end-to-end training + ONNX export against real upstream Primus checkpoints
needs to be run on the GPU box.
@77even 77even added the enhancement New feature or request label May 12, 2026
@77even 77even merged commit d4f1ed0 into main May 12, 2026
@77even 77even deleted the feature/primus-distillation branch May 12, 2026 17:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant