Commit 0f3877a

yeyu-nvidia and claude committed
Init lora_B with small random values and restore logits gradient path
B=0 initialization creates a saddle point where the preservation gradient is exactly zero at init, allowing the EAGLE logits gradient to dominate unopposed before preservation can react. Initialize lora_B with N(0, 0.01) (std=0.01, matching the code) so the preservation loss is active from step 0 and constrains LoRA from the start.

With preservation active at init, restore the direct logits gradient path (remove detach on base_outputs.logits in the EAGLE loss) to give LoRA a strong training signal while relying on the preservation loss to prevent collapse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent ec61f24 commit 0f3877a
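
The zero-gradient claim in the commit message can be checked on a scalar toy model. The sketch below is illustrative only: it assumes a squared-error preservation loss between the LoRA-adapted output and the frozen output, which this commit does not specify, and the names `preservation_grads`, `a`, `b`, `w`, and `x` are hypothetical. The same conclusion holds for any smooth loss that is minimized when the adapted output equals the frozen one.

```python
# Scalar LoRA toy (hypothetical, not the repository's code):
#   frozen output  = w * x
#   adapted output = (w + b * a) * x
# Assumed preservation loss: L = (adapted - frozen)^2 = (b * a * x)^2


def preservation_grads(a, b, x=1.5):
    """Return (loss, dL/da, dL/db) using the analytic derivatives of L."""
    deviation = b * a * x  # the frozen weight w cancels out
    loss = deviation ** 2
    grad_a = 2.0 * deviation * b * x
    grad_b = 2.0 * deviation * a * x
    return loss, grad_a, grad_b


# b = 0 mirrors the default zero init of lora_B: the deviation is zero,
# so the loss and both gradients vanish.
zero_init = preservation_grads(a=0.7, b=0.0)

# b = 0.01 mirrors a small non-zero lora_B: both gradients are non-zero,
# so the preservation term pushes back from the first step.
small_init = preservation_grads(a=0.7, b=0.01)

print("B = 0   :", zero_init)
print("B = 0.01:", small_init)
```

With `b = 0` every factor of the gradient carries the zero deviation, which is why the preservation term cannot oppose another loss until the adapter has already moved.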

1 file changed

Lines changed: 7 additions & 5 deletions

modelopt/torch/speculative/plugins/transformers.py

@@ -560,10 +560,15 @@ def _inject_base_lora(self):
             bias="none",
         )
         inject_adapter_in_model(lora_config, self._base_model, adapter_name="default")
-        # Unfreeze only the LoRA parameters
+        # Unfreeze only the LoRA parameters and initialize lora_B with small random values
+        # instead of the default zeros. B=0 creates a saddle point where the preservation
+        # gradient is zero at init, allowing the EAGLE gradient to dominate unopposed.
+        # A small non-zero B ensures preservation loss is active from step 0.
         for name, param in self._base_model.named_parameters():
             if "lora_" in name:
                 param.requires_grad = True
+                if "lora_B" in name:
+                    torch.nn.init.normal_(param, std=0.01)
 
     def _set_base_lora_enabled(self, enabled: bool) -> None:
         """Enable or disable LoRA adapters in the base model."""
@@ -1017,10 +1022,7 @@ def forward(
                 # base model predict +1 tok, while eagle predict +2
                 # so we shift base model outputs compared to eagle outputs
                 # additionally, we mask the first n tok of eagle outputs at nth TTT step
-                # Detach so the EAGLE loss treats base logits as fixed soft labels and does
-                # not backprop into the base model through this path. LoRA still receives
-                # EAGLE gradients via the hidden-state path (out_hiddens -> eagle_input_hiddens).
-                base_outputs.logits.detach()[:, 1 + i + ttt_step :],
+                base_outputs.logits[:, 1 + i + ttt_step :],
                 eagle_logit[:, ttt_step : -(1 + i)],
                 loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
             )
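
The hunk above removes `detach()` from the base logits, so the EAGLE loss can now backpropagate into them. The sketch below illustrates why that path carries a real signal. It assumes a soft-label cross-entropy of the form `-sum(softmax(base) * log_softmax(draft))`, which is a guess based on the removed "fixed soft labels" comment rather than the repository's actual loss, and it uses central finite differences instead of autograd. The helper names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def soft_label_ce(base_logits, draft_logits):
    """Assumed loss: cross-entropy of the draft against softmax(base) as soft labels."""
    p = softmax(base_logits)
    q = softmax(draft_logits)
    return -sum(pk * math.log(qk) for pk, qk in zip(p, q))


def grad_wrt_base(base_logits, draft_logits, eps=1e-6):
    """Central finite-difference gradient of the loss w.r.t. the base logits."""
    grads = []
    for j in range(len(base_logits)):
        up = list(base_logits)
        down = list(base_logits)
        up[j] += eps
        down[j] -= eps
        diff = soft_label_ce(up, draft_logits) - soft_label_ce(down, draft_logits)
        grads.append(diff / (2.0 * eps))
    return grads


base = [2.0, 0.5, -1.0]   # stand-in for base_outputs.logits at one position
draft = [0.0, 1.0, 0.0]   # stand-in for eagle_logit at the same position

# Without detach, this gradient is what would flow into the base logits.
# With detach, autograd treats the base logits as constants and this path
# contributes nothing.
g = grad_wrt_base(base, draft)
print(g)
```

Under this assumed loss the gradient on the base logits is generally non-zero, so the base model can lower the loss by changing its own targets. That is consistent with the commit message's reliance on the preservation loss to prevent collapse once this path is restored.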

0 commit comments