
Commit 07ae8e7

yeyu-nvidia and claude authored
Add LoRA co-training support for HF EAGLE speculative decoding (#1060)
### What does this PR do?

**Type of change:** New feature + bug fixes

Adds **LoRA co-training** support for HF EAGLE speculative decoding. When `eagle_base_lora=True`, HF PEFT LoRA adapters are injected into the base model and co-trained alongside the EAGLE draft module in a single online training pass. A preservation loss (KL divergence between the original frozen base model output and the LoRA-adapted output) prevents base model drift. LoRA adapter weights are exported in standard peft format alongside EAGLE draft artifacts.

### Key features

- **LoRA injection**: `peft.inject_adapter_in_model` applied in-place (no wrapper), keeping the existing `HFEagleModel` structure intact.
- **Preservation loss**: Cross-entropy `H(ref, lora)` — equivalent gradient to `KL(ref || lora)` since `H(ref)` is constant w.r.t. LoRA params.
- **Warmup schedule**: `eagle_base_lora_warmup_steps` freezes LoRA for N steps while the EAGLE head stabilizes, then enables co-training via a `LoRAWarmupCallback`.
- **Logits detach regularization**: `eagle_base_lora_logits_detach_prob` stochastically detaches base logits from the EAGLE loss path, preventing LoRA from degenerating to maximize EAGLE accuracy at the cost of base model quality.
- **Export**: Standard peft format (`adapter_model.safetensors` + `adapter_config.json`) alongside the EAGLE draft model.
- **Merge script**: `scripts/merge_lora.py` merges LoRA weights into the base model and restores the original `config.json` (avoids transformers 5.x rewriting `rope_theta` → `rope_parameters`, which breaks vLLM/TRT-LLM).
- **Multinode fix**: `dp_shard_size` now uses `WORLD_SIZE` instead of the local GPU count.

### Config options

```python
mtsp.convert(model, mode=[("eagle", {
    "eagle_base_lora": True,          # enable LoRA co-training
    "eagle_base_lora_rank": 64,       # LoRA rank
    "eagle_base_lora_alpha": 16.0,    # LoRA scaling
    "eagle_base_lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "eagle_base_lora_preservation_loss_weight": 0.1,  # preservation loss weight
    "eagle_base_lora_warmup_steps": 0,        # freeze LoRA for N steps
    "eagle_base_lora_logits_detach_prob": 0.5,  # detach prob (0=never, 1=always)
})])
```

### Experimental results (Qwen3-8B, checkpoint-60000)

Base model quality is preserved across the detach_prob sweep (lm_eval: IFEval, ARC-C, Winogrande — results pending final collection).

**Acceptance rate** (mt_bench, draft_length=3, output_length=4096, temperature=0):

| detach_prob | vLLM AR | TRT-LLM AR |
|---|---|---|
| baseline (no LoRA) | 2.14 | 2.15 |
| 0.5 | 1.45 | 1.44 |
| 0.8 | **3.06** | **3.01** |
| 0.85 | 2.90 | 2.90 |
| 0.9 | 2.76 | 2.77 |
| 0.95 | 2.51 | 2.58 |
| 0.99 | 2.37 | 2.37 |
| 0.999 | 2.30 | 2.27 |
| 0.9999 | 2.31 | 2.26 |

Best AR at `detach_prob=0.8`: ~40% improvement over baseline.

### Testing

`tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py` (5 tests):

- `test_lora_layers_injected` — LoRA layers present after conversion
- `test_trainable_params` — only `lora_*` and `eagle_module` params are trainable
- `test_forward_returns_loss` — forward returns a non-zero scalar loss
- `test_eagle_offline_incompatible` — `eagle_base_lora=True` + `eagle_offline=True` raises `ValueError`
- `test_export_lora_artifacts` — export produces standard peft adapter files

### Bug fixes (included in this PR)

1. **`launch_train.sh` case pattern ordering**: the glob `--eagle_base_lora*` came before the more specific patterns (`--eagle_base_lora_rank*`, etc.), silently swallowing LoRA args.
2. **LoRA optimizer exclusion during warmup**: warmup freezing excluded LoRA from the optimizer entirely; fixed with `add_param_group` in the callback.
3. **`merge_lora.py` config.json**: `save_pretrained()` with transformers >=5.x rewrites `rope_theta` → `rope_parameters`, breaking vLLM positional embeddings. Fixed by copying the original base model config.
4. **Multinode `dp_shard_size`**: used the local GPU count instead of `WORLD_SIZE`.

### Checklist

- [x] Backward compatible (all new config fields have defaults)
- [x] Uses `peft` via lazy imports (no hard dependency)
- [x] Unit tests added
- [x] Online HF training only (`eagle_offline=True` blocked)

---------

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
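The preservation-loss equivalence claimed in the PR description (cross-entropy `H(ref, lora)` has the same gradient as `KL(ref || lora)`, because `H(ref)` does not depend on the LoRA parameters) can be checked numerically. The sketch below is illustrative numpy, not the PR's torch implementation; `preservation_loss` and the helper names are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def preservation_loss(ref_logits, lora_logits):
    """Cross-entropy H(ref, lora) between the frozen reference distribution
    and the LoRA-adapted distribution, averaged over positions."""
    p_ref = softmax(ref_logits)  # frozen: no gradient would flow through this term
    shifted = lora_logits - lora_logits.max(axis=-1, keepdims=True)
    log_q = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # log_softmax
    return -(p_ref * log_q).sum(axis=-1).mean()

# KL(ref || lora) = H(ref, lora) - H(ref). Since H(ref) is a constant w.r.t.
# the LoRA params, minimizing the cross-entropy and minimizing the KL give
# identical gradients.
ref = np.array([[2.0, 0.0, -1.0]])
p = softmax(ref)
h_ref = -(p * np.log(p)).sum()
# When lora == ref, the cross-entropy collapses to H(ref) and the KL term is 0.
print(np.isclose(preservation_loss(ref, ref), h_ref))
```

Any drift of the LoRA-adapted logits away from the reference strictly increases this loss (cross-entropy is minimized when the two distributions match), which is what keeps the base model from degrading.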
1 parent 361f7e3 commit 07ae8e7

File tree

12 files changed: +586 −39 lines


examples/specdec_bench/run.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -157,6 +157,7 @@ def run_simple(args):
         tensor_parallel_size=args.tp_size,
         moe_expert_parallel_size=args.ep_size,
         trust_remote_code=args.trust_remote_code,
+        tokenizer_path=args.tokenizer,
         **engine_args,
     )
```

examples/specdec_bench/specdec_bench/models/vllm.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -72,6 +72,7 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
         num_speculative_tokens = specdec.get("num_speculative_tokens", 3)
         engine_args = AsyncEngineArgs(
             model=model_dir,
+            tokenizer=kwargs.get("tokenizer_path"),
             trust_remote_code=kwargs.get("trust_remote_code", False),
             tensor_parallel_size=kwargs.get("tensor_parallel_size", 1),
             enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1,
```

examples/speculative_decoding/eagle_utils.py

Lines changed: 102 additions & 3 deletions
```diff
@@ -120,17 +120,107 @@ def make_speculative_data_module(
 class EagleTrainerWithAccLog(Trainer):
     """Wrapper around Trainer that logs training accuracy."""
 
+    def __init__(
+        self,
+        *args,
+        lora_lr_multiplier: float = 1.0,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.lora_lr_multiplier = lora_lr_multiplier
+
+    def create_optimizer(self):
+        """Override to give LoRA parameters a higher learning rate."""
+        super().create_optimizer()
+        if self.lora_lr_multiplier != 1.0:
+            lora_ids = {
+                id(p) for n, p in self.model.named_parameters() if "lora_" in n and p.requires_grad
+            }
+            if lora_ids:
+                new_groups = []
+                for group in self.optimizer.param_groups:
+                    lora = [p for p in group["params"] if id(p) in lora_ids]
+                    others = [p for p in group["params"] if id(p) not in lora_ids]
+                    if lora and others:
+                        new_groups.append({**group, "params": others})
+                        new_groups.append(
+                            {**group, "params": lora, "lr": group["lr"] * self.lora_lr_multiplier}
+                        )
+                    elif lora:
+                        new_groups.append({**group, "lr": group["lr"] * self.lora_lr_multiplier})
+                    else:
+                        new_groups.append(group)
+                self.optimizer.param_groups = new_groups
+        return self.optimizer
+
     def compute_loss(self, *args, **kwargs):
-        """Override compute_loss to save train accs in trainer state."""
+        """Override compute_loss to save train accs and per-component losses in trainer state."""
         if not hasattr(self.state, "training_accs"):
             self.state.training_accs = []
+        if not hasattr(self.state, "component_losses"):
+            self.state.component_losses = {"eagle": [], "preservation": []}
         kwargs.pop("num_items_in_batch", None)
         loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
-        if hasattr(outputs, "train_acc"):
+        if hasattr(outputs, "train_acc") and any(outputs.train_acc):
             self.state.training_accs.append(outputs.train_acc)
+        # Track per-component losses
+        for key, attr in [
+            ("eagle", "eagle_loss"),
+            ("preservation", "preservation_loss"),
+        ]:
+            val = getattr(outputs, attr, None)
+            if val is not None:
+                self.state.component_losses[key].append(val.item())
         return loss
 
 
+class LoRAWarmupCallback(TrainerCallback):
+    """Manages LoRA warmup: freezes LoRA during warmup, unfreezes after."""
+
+    def __init__(self, warmup_steps: int):
+        self.warmup_steps = warmup_steps
+        self._activated = False
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        """Check if warmup is over and activate LoRA co-training."""
+        if self._activated:
+            return control
+        if state.global_step >= self.warmup_steps:
+            model = kwargs["model"]
+            # Unwrap DDP/FSDP if needed
+            raw_model = model.module if hasattr(model, "module") else model
+            if hasattr(raw_model, "_lora_cotraining_active"):
+                raw_model._lora_cotraining_active = True
+                # Unfreeze LoRA parameters
+                lora_params = []
+                for name, param in raw_model._base_model.named_parameters():
+                    if "lora_" in name:
+                        param.requires_grad = True
+                        lora_params.append(param)
+
+                # Add LoRA params to optimizer — they were excluded at creation time
+                # because requires_grad was False during warmup.
+                optimizer = kwargs.get("optimizer")
+                if optimizer is not None and lora_params:
+                    existing_ids = {id(p) for g in optimizer.param_groups for p in g["params"]}
+                    new_params = [p for p in lora_params if id(p) not in existing_ids]
+                    if new_params:
+                        optimizer.add_param_group(
+                            {
+                                "params": new_params,
+                                "lr": optimizer.param_groups[0]["lr"],
+                                "weight_decay": 0.0,
+                            }
+                        )
+                        print_rank_0(f"  Added {len(new_params)} LoRA params to optimizer")
+
+            print_rank_0(
+                f"Step {state.global_step}: LoRA warmup complete, enabling co-training."
+            )
+            self._activated = True
+        return control
+
+
 class EagleTrainingPlot(TrainerCallback):
     """Callback that plot training acc and AR during training."""
 
@@ -176,8 +266,16 @@ def on_log(self, args, state, control, **kwargs):
         if logs:
             wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step)
 
-        # reset training_accs
+        # Log per-component losses
+        if hasattr(state, "component_losses"):
+            for key, vals in state.component_losses.items():
+                if vals:
+                    wandb.log({f"{key}_loss": np.mean(vals)}, step=state.global_step)
+
+        # reset training_accs and component_losses
         state.training_accs = []
+        if hasattr(state, "component_losses"):
+            state.component_losses = {"eagle": [], "preservation": []}
         return control
 
     def on_step_end(self, args, state, control, **kwargs):
@@ -186,6 +284,7 @@ def on_step_end(self, args, state, control, **kwargs):
             return control
         if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
             print_rank_0("Running AR validation...")
+            torch.cuda.empty_cache()
             try:
                 ars = validate_ar(
                     model=kwargs["model"],
```
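The effect of `eagle_base_lora_logits_detach_prob` described in the PR summary can be sketched in isolation: on each step, the base logits are detached from the EAGLE loss graph with probability `p`, so LoRA receives gradient through the EAGLE objective on only roughly a `1 - p` fraction of steps, while the preservation loss applies every step. A pure-Python simulation (the `eagle_sees_grad` name is hypothetical, not project code):

```python
import random

def eagle_sees_grad(detach_prob, trials=100_000, seed=0):
    """Estimate the fraction of training steps in which base-model logits stay
    attached to the EAGLE loss graph (i.e. LoRA gets gradient via that path)."""
    rng = random.Random(seed)
    attached = sum(rng.random() >= detach_prob for _ in range(trials))
    return attached / trials

# At the best-performing setting detach_prob=0.8 from the sweep above, LoRA is
# optimized against the EAGLE objective on roughly 20% of steps.
print(eagle_sees_grad(0.8))
```

This is why very high detach probabilities (0.99+) in the sweep drift back toward the no-LoRA baseline: LoRA almost never sees the EAGLE objective and contributes little.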

examples/speculative_decoding/main.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -40,6 +40,7 @@
 from eagle_utils import (
     EagleTrainerWithAccLog,
     EagleTrainingPlot,
+    LoRAWarmupCallback,
     make_speculative_data_module,
     patch_ring_attention_for_ttt,
 )
@@ -183,7 +184,9 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic
 
     if hf_cfg.get("dp_shard_size") is None:
         cp_size = hf_cfg.get("cp_size", 1)
-        hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size
+        # Use WORLD_SIZE (total GPUs across all nodes) when available, else local GPU count.
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        hf_cfg["dp_shard_size"] = world_size // cp_size
 
     return hf_cfg, eagle_cfg, dflash_cfg
 
@@ -316,6 +319,20 @@ def train():
     else:
         raise Exception(f"{training_args.mode} is not supported!")
 
+    # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast
+    # them. We iterate named_buffers and reassign via the owning module to
+    # keep the module tree consistent. Parameters are left on CPU — the HF
+    # Trainer will move them during init.
+    if torch.cuda.is_available():
+        _target_dev = torch.device("cuda", 0)
+        for name, buf in list(model.named_buffers()):
+            if buf.device.type == "cpu":
+                parts = name.split(".")
+                mod = model
+                for p in parts[:-1]:
+                    mod = getattr(mod, p)
+                setattr(mod, parts[-1], buf.to(_target_dev))
+
     print_rank_0("Loading dataset...")
     is_dflash = training_args.mode == "dflash"
     if training_args.mode in ("eagle3", "dflash"):
@@ -327,11 +344,15 @@
         shift_labels=not is_dflash,
     )
 
+    callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
+    if eagle_cfg.get("eagle_base_lora") and eagle_cfg.get("eagle_base_lora_warmup_steps", 0) > 0:
+        callbacks.append(LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"]))
+
     trainer = EagleTrainerWithAccLog(
         model=model,
         processing_class=tokenizer,
         args=training_args,
-        callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)],
+        callbacks=callbacks,
         **data_module,
     )
```
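The multinode `dp_shard_size` fix boils down to preferring `WORLD_SIZE` (total ranks across all nodes, set by the launcher) over the local GPU count. A minimal standalone sketch, with a hypothetical `resolve_dp_shard_size` helper mirroring the diffed logic:

```python
import os

def resolve_dp_shard_size(cp_size: int, local_gpu_count: int, env=os.environ) -> int:
    """Prefer WORLD_SIZE over the local GPU count when computing the
    data-parallel shard size; fall back to local count for single-node runs."""
    world_size = int(env.get("WORLD_SIZE", local_gpu_count))
    return world_size // cp_size

# Single node, 8 GPUs, cp_size=2: shard size 4.
print(resolve_dp_shard_size(2, 8, env={}))
# Two nodes (launcher exports WORLD_SIZE=16): shard size 8, not 4.
print(resolve_dp_shard_size(2, 8, env={"WORLD_SIZE": "16"}))
```

Before the fix, a 2-node run with 8 local GPUs would shard as if only one node existed, which is exactly the bug listed in the PR description.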

Lines changed: 2 additions & 0 deletions
```diff
@@ -1 +1,3 @@
+accelerate>=1.12.0
+peft==0.18.1
 transformers>=5.0,<5.4
```
scripts/merge_lora.py

Lines changed: 138 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merge LoRA weights from an exported EAGLE checkpoint into the base model and save.

Usage:
    python merge_lora.py \
        --base_model_path /path/to/original/base/model \
        --exported_lora_dir /path/to/exported/eagle/checkpoint \
        --output_path /path/to/merged/output

The exported checkpoint (from export_hf_checkpoint.py) contains
adapter_model.safetensors and adapter_config.json in standard peft format.
This script loads the original base model, applies the trained LoRA adapters,
merges them into the base weights, and saves the fused model + tokenizer.
"""

import argparse
from pathlib import Path

from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser(
        description="Merge LoRA weights from an exported EAGLE checkpoint into the base model."
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        required=True,
        help="Path to the original base model (HF model name or local path).",
    )
    parser.add_argument(
        "--exported_lora_dir",
        type=str,
        required=True,
        help="Path to the exported EAGLE checkpoint containing adapter_model.safetensors.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Directory to save the merged (fused) base model.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    lora_dir = Path(args.exported_lora_dir)

    # Verify exported files exist (standard peft naming)
    config_path = lora_dir / "adapter_config.json"
    weights_path = lora_dir / "adapter_model.safetensors"
    if not config_path.exists() or not weights_path.exists():
        raise FileNotFoundError(
            f"Expected adapter_config.json and adapter_model.safetensors "
            f"in {lora_dir}. Run export_hf_checkpoint.py first."
        )

    lora_sd = load_file(weights_path)
    print(f"Loaded {len(lora_sd)} LoRA tensors from {lora_dir}")
    print(f"  Sample keys: {list(lora_sd.keys())[:4]}")

    # Load the original base model
    print(f"Loading base model from {args.base_model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, trust_remote_code=True)

    # Load LoRA adapter into the base model (export dir uses standard peft naming)
    print("Loading LoRA adapter via PeftModel.from_pretrained...")
    from peft import PeftModel

    model = PeftModel.from_pretrained(model, str(lora_dir))
    print("  PeftModel loaded successfully")

    # Debug: check adapter file keys vs model keys and values
    adapter_keys = set(lora_sd.keys())
    model_lora_keys = {k for k in model.state_dict() if ".lora_A." in k or ".lora_B." in k}
    print(f"  Adapter file keys (first 4): {sorted(adapter_keys)[:4]}")
    print(f"  Model LoRA keys (first 4): {sorted(model_lora_keys)[:4]}")
    # Check if exported lora_B values are actually non-zero
    for k, v in lora_sd.items():
        if ".lora_B." in k:
            print(f"  Exported {k}: shape={v.shape}, norm={v.norm().item():.6f}")
            break

    # Verify lora_B weights are non-zero (B is init'd to zero, so non-zero means loaded)
    lora_b_norms = [v.norm().item() for k, v in model.state_dict().items() if ".lora_B." in k]
    if not lora_b_norms or all(n == 0 for n in lora_b_norms):
        raise RuntimeError("LoRA-B weights are all zero — adapter loading failed.")
    print(
        f"  Verified: {len(lora_b_norms)} LoRA-B matrices "
        f"(mean norm={sum(lora_b_norms) / len(lora_b_norms):.4f})"
    )

    # Merge LoRA into base weights and remove adapter wrappers
    model = model.merge_and_unload()
    print("LoRA merged successfully.")

    # Save
    print(f"Saving merged model to {args.output_path}...")
    model.save_pretrained(args.output_path)
    tokenizer.save_pretrained(args.output_path)

    # Restore the original base model's config.json. save_pretrained() with newer
    # transformers (>=5.x) rewrites config fields (e.g. rope_theta → rope_parameters,
    # torch_dtype → dtype) which can confuse downstream engines like TRT-LLM or vLLM.
    # Since LoRA only changes weights — not architecture — the original config is correct.
    import shutil

    base_config = Path(args.base_model_path) / "config.json"
    output_config = Path(args.output_path) / "config.json"
    if base_config.exists():
        shutil.copy2(str(base_config), str(output_config))
        print(f"  Restored original config.json from {base_config}")

    print(f"Done! Merged model saved to {args.output_path}")


if __name__ == "__main__":
    main()
```