
Add --freeze-base-for-mtp to train MTP heads on frozen quantized base#4785

Draft
yeyu-nvidia wants to merge 2 commits into NVIDIA:main from yeyu-nvidia:yeyu/mtp-freeze-base

Conversation

@yeyu-nvidia
Contributor

Summary

  • Add --freeze-base-for-mtp argument to freeze all base model parameters and train only MTP heads
  • Construct mtp_block_spec in modelopt_gpt_hybrid_builder() so MTP heads can be added to ModelOpt-quantized models
  • Add _freeze_base_for_mtp() helper that sets requires_grad=False on all non-MTP parameters

Use case: After NVFP4 QAD, load the quantized checkpoint with --mtp-num-layers 1 --freeze-base-for-mtp, which adds randomly-initialized MTP heads and trains them with lm_loss while keeping the quantized base frozen.

Test plan

  • Unit test: build model with freeze_base_for_mtp=True, verify only mtp.layers.* params have requires_grad=True
  • Dry run: load a small quantized checkpoint with --mtp-num-layers 1 --freeze-base-for-mtp, verify model builds and MTP params are randomly initialized
  • Short training run: verify MTP loss decreases while base model params stay unchanged
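The first check in the test plan can be sketched with a toy module. The module names here (`decoder`, `mtp.layers`) are assumptions chosen to mirror the `mtp.layers.*` pattern above, not the real Megatron model structure.

```python
import torch.nn as nn

# Toy stand-in for a quantized base model plus one MTP head.
class ToyGPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 8)  # plays the role of the frozen base
        self.mtp = nn.ModuleDict({"layers": nn.ModuleList([nn.Linear(8, 8)])})

model = ToyGPT()

# Freeze everything except parameters under mtp.layers.*.
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("mtp.layers.")

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
frozen = sorted(n for n, p in model.named_parameters() if not p.requires_grad)
```

A unit test would then assert that `trainable` contains only `mtp.layers.*` names and that every base (`decoder.*`) parameter landed in `frozen`.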

🤖 Generated with Claude Code

After NVFP4 QAD, this flag allows loading the quantized checkpoint,
adding randomly-initialized MTP heads, and training only MTP parameters
with lm_loss while keeping the base model frozen.

Changes:
- arguments.py: Add --freeze-base-for-mtp argument
- model_builder.py: Construct mtp_block_spec in modelopt builder,
  add _freeze_base_for_mtp() to freeze all non-MTP parameters

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented May 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


Tests verify that _freeze_base_for_mtp correctly freezes all base model
parameters while keeping MTP layer parameters trainable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
