Skip to content

Commit ddddb81

Browse files
yeyu-nvidiaclaude
andcommitted
Add LoRA LR multiplier and detach base logits in EAGLE loss
Detach base_outputs.logits when used as soft labels in the EAGLE loss so gradients do not flow back to LoRA through the label path (which causes circular collapse). LoRA still receives EAGLE gradients via the hidden- state path (out_hiddens -> eagle_input_hiddens). Add eagle_base_lora_lr_multiplier (default 10x) to compensate for the weaker hidden-state gradient signal: LoRA parameters are split into a separate optimizer param group with lr = base_lr * multiplier. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 7f9eb26 commit ddddb81

4 files changed

Lines changed: 52 additions & 2 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,38 @@ def make_eagle_supervised_data_module(
170170
class EagleTrainerWithAccLog(Trainer):
171171
"""Wrapper around Trainer that logs training accuracy."""
172172

173+
def __init__(self, *args, lora_lr_multiplier: float = 1.0, **kwargs):
174+
super().__init__(*args, **kwargs)
175+
self.lora_lr_multiplier = lora_lr_multiplier
176+
177+
def create_optimizer(self):
178+
"""Override to give LoRA parameters a higher learning rate."""
179+
super().create_optimizer()
180+
if self.lora_lr_multiplier != 1.0:
181+
lora_ids = {
182+
id(p)
183+
for n, p in self.model.named_parameters()
184+
if "lora_" in n and p.requires_grad
185+
}
186+
if lora_ids:
187+
new_groups = []
188+
for group in self.optimizer.param_groups:
189+
lora = [p for p in group["params"] if id(p) in lora_ids]
190+
others = [p for p in group["params"] if id(p) not in lora_ids]
191+
if lora and others:
192+
new_groups.append({**group, "params": others})
193+
new_groups.append(
194+
{**group, "params": lora, "lr": group["lr"] * self.lora_lr_multiplier}
195+
)
196+
elif lora:
197+
new_groups.append(
198+
{**group, "lr": group["lr"] * self.lora_lr_multiplier}
199+
)
200+
else:
201+
new_groups.append(group)
202+
self.optimizer.param_groups = new_groups
203+
return self.optimizer
204+
173205
def compute_loss(self, *args, **kwargs):
174206
"""Override compute_loss to save train accs in trainer state."""
175207
if not hasattr(self.state, "training_accs"):

examples/speculative_decoding/launch_train.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ while [ $# -gt 0 ]; do
134134
if [[ "$1" != *=* ]]; then shift; fi
135135
EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT="${1#*=}"
136136
;;
137+
--eagle_base_lora_lr_multiplier*)
138+
if [[ "$1" != *=* ]]; then shift; fi
139+
EAGLE_BASE_LORA_LR_MULTIPLIER="${1#*=}"
140+
;;
137141
--eagle_base_lora*)
138142
if [[ "$1" != *=* ]]; then shift; fi
139143
EAGLE_BASE_LORA="${1#*=}"
@@ -184,6 +188,7 @@ EAGLE_BASE_LORA_RANK=${EAGLE_BASE_LORA_RANK:-64}
184188
EAGLE_BASE_LORA_ALPHA=${EAGLE_BASE_LORA_ALPHA:-16.0}
185189
EAGLE_BASE_LORA_TARGET_MODULES=${EAGLE_BASE_LORA_TARGET_MODULES:-""}
186190
EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT=${EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT:-1.0}
191+
EAGLE_BASE_LORA_LR_MULTIPLIER=${EAGLE_BASE_LORA_LR_MULTIPLIER:-10.0}
187192

188193

189194
if [[ "$MODE" == "eagle3" ]]; then
@@ -219,7 +224,8 @@ if [[ "$EAGLE_BASE_LORA" == "True" ]]; then
219224
LORA_ARGS="--eagle_base_lora True \
220225
--eagle_base_lora_rank $EAGLE_BASE_LORA_RANK \
221226
--eagle_base_lora_alpha $EAGLE_BASE_LORA_ALPHA \
222-
--eagle_base_lora_preservation_loss_weight $EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT"
227+
--eagle_base_lora_preservation_loss_weight $EAGLE_BASE_LORA_PRESERVATION_LOSS_WEIGHT \
228+
--eagle_base_lora_lr_multiplier $EAGLE_BASE_LORA_LR_MULTIPLIER"
223229
if [[ "$EAGLE_BASE_LORA_TARGET_MODULES" != "" ]]; then
224230
LORA_ARGS="$LORA_ARGS --eagle_base_lora_target_modules $EAGLE_BASE_LORA_TARGET_MODULES"
225231
fi

examples/speculative_decoding/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ class EagleArguments:
169169
)
170170
},
171171
)
172+
eagle_base_lora_lr_multiplier: float = field(
173+
default=10.0,
174+
metadata={
175+
"help": (
176+
"Learning rate multiplier for LoRA parameters relative to the base learning rate."
177+
)
178+
},
179+
)
172180

173181

174182
def train():
@@ -285,6 +293,7 @@ def train():
285293
processing_class=tokenizer,
286294
args=training_args,
287295
callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
296+
lora_lr_multiplier=eagle_args.eagle_base_lora_lr_multiplier,
288297
**data_module,
289298
)
290299

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,10 @@ def forward(
10121012
# base model predict +1 tok, while eagle predict +2
10131013
# so we shift base model outputs compared to eagle outputs
10141014
# additionally, we mask the first n tok of eagle outputs at nth TTT step
1015-
base_outputs.logits[:, 1 + i + ttt_step :],
1015+
# Detach so the EAGLE loss treats base logits as fixed soft labels and does
1016+
# not backprop into the base model through this path. LoRA still receives
1017+
# EAGLE gradients via the hidden-state path (out_hiddens -> eagle_input_hiddens).
1018+
base_outputs.logits.detach()[:, 1 + i + ttt_step :],
10161019
eagle_logit[:, ttt_step : -(1 + i)],
10171020
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
10181021
)

0 commit comments

Comments
 (0)