Add LoRA co-training support for HF EAGLE speculative decoding (#1060)
### What does this PR do?
Type of change: New feature + bug fixes
Adds **LoRA co-training** support for HF EAGLE speculative decoding.
When `eagle_base_lora=True`, HF PEFT LoRA adapters are injected into the
base model and co-trained alongside the EAGLE draft module in a single
online training pass. A preservation loss (KL divergence between the
original frozen base model output and the LoRA-adapted output) prevents
base model drift. LoRA adapter weights are exported in standard peft
format alongside EAGLE draft artifacts.
### Key features
- **LoRA injection**: `peft.inject_adapter_in_model` applied in-place
(no wrapper), keeping the existing `HFEagleModel` structure intact.
- **Preservation loss**: Cross-entropy `H(ref, lora)` — equivalent
gradient to `KL(ref || lora)` since `H(ref)` is constant w.r.t. LoRA
params.
- **Warmup schedule**: `eagle_base_lora_warmup_steps` freezes LoRA for N
steps while the EAGLE head stabilizes, then enables co-training via a
`LoRAWarmupCallback`.
- **Logits detach regularization**: `eagle_base_lora_logits_detach_prob`
stochastically detaches base logits from the EAGLE loss path, preventing
LoRA from degenerating to maximize EAGLE accuracy at the cost of base
model quality.
- **Export**: Standard peft format (`adapter_model.safetensors` +
`adapter_config.json`) alongside EAGLE draft model.
- **Merge script**: `scripts/merge_lora.py` merges LoRA weights into the
base model and restores the original `config.json` (avoids transformers
5.x rewriting `rope_theta` → `rope_parameters` which breaks
vLLM/TRT-LLM).
- **Multinode fix**: `dp_shard_size` now uses `WORLD_SIZE` instead of
local GPU count.
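The preservation-loss and logits-detach mechanics described above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the PR's actual implementation; `preservation_loss` and `maybe_detach_logits` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def preservation_loss(ref_logits: torch.Tensor, lora_logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy H(ref, lora). Its gradient w.r.t. the LoRA params equals
    that of KL(ref || lora), since the entropy term H(ref) is constant."""
    ref_probs = F.softmax(ref_logits.detach(), dim=-1)  # frozen reference: no grad
    log_probs = F.log_softmax(lora_logits, dim=-1)      # LoRA-adapted output
    return -(ref_probs * log_probs).sum(dim=-1).mean()

def maybe_detach_logits(logits: torch.Tensor, detach_prob: float) -> torch.Tensor:
    """Stochastically cut the EAGLE-loss gradient path into the base model
    (0 = never detach, 1 = always detach)."""
    if torch.rand(()) < detach_prob:
        return logits.detach()
    return logits
```

With `detach_prob` near 1, LoRA is trained almost entirely by the preservation loss, which is why high values keep base-model quality while the EAGLE loss still occasionally shapes the adapters.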
### Config options
```python
mtsp.convert(model, mode=[("eagle", {
"eagle_base_lora": True, # enable LoRA co-training
"eagle_base_lora_rank": 64, # LoRA rank
"eagle_base_lora_alpha": 16.0, # LoRA scaling
"eagle_base_lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
"eagle_base_lora_preservation_loss_weight": 0.1, # preservation loss weight
"eagle_base_lora_warmup_steps": 0, # freeze LoRA for N steps
"eagle_base_lora_logits_detach_prob": 0.5, # detach prob (0=never, 1=always)
})])
```
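The `eagle_base_lora_warmup_steps` behavior can be approximated with a small scheduler. The class below is a simplified stand-in for the PR's `LoRAWarmupCallback`, written without the `transformers` callback machinery; the class and method names here are illustrative assumptions.

```python
import torch

class LoRAWarmupSchedule:
    """Freeze all `lora_*` params for `warmup_steps` steps while the EAGLE
    head stabilizes, then enable LoRA co-training."""

    def __init__(self, model: torch.nn.Module, warmup_steps: int):
        self.warmup_steps = warmup_steps
        self.enabled = False
        self.lora_params = [p for n, p in model.named_parameters() if "lora_" in n]
        for p in self.lora_params:  # frozen during warmup
            p.requires_grad_(False)
        if warmup_steps == 0:
            self._enable()

    def _enable(self):
        for p in self.lora_params:
            p.requires_grad_(True)
        self.enabled = True

    def on_step(self, global_step: int):
        # Called once per optimizer step by the training loop.
        if not self.enabled and global_step >= self.warmup_steps:
            self._enable()
```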
### Experimental results (Qwen3-8B, checkpoint-60000)
Base model quality is preserved across the `detach_prob` sweep (lm_eval: IFEval,
ARC-C, Winogrande; results pending final collection).
**Acceptance rate** (mt_bench, draft_length=3, output_length=4096,
temperature=0):
| detach_prob | vLLM AR | TRT-LLM AR |
|---|---|---|
| baseline (no LoRA) | 2.14 | 2.15 |
| 0.5 | 1.45 | 1.44 |
| 0.8 | **3.06** | **3.01** |
| 0.85 | 2.90 | 2.90 |
| 0.9 | 2.76 | 2.77 |
| 0.95 | 2.51 | 2.58 |
| 0.99 | 2.37 | 2.37 |
| 0.999 | 2.30 | 2.27 |
| 0.9999 | 2.31 | 2.26 |
Best AR at `detach_prob=0.8`: ~40% improvement over baseline.
### Testing
`tests/unit/torch/speculative/plugins/test_hf_speculative_lora.py` (5
tests):
- `test_lora_layers_injected` — LoRA layers present after conversion
- `test_trainable_params` — only `lora_*` and `eagle_module` params are
trainable
- `test_forward_returns_loss` — forward returns non-zero scalar loss
- `test_eagle_offline_incompatible` — `eagle_base_lora=True` +
`eagle_offline=True` raises `ValueError`
- `test_export_lora_artifacts` — export produces standard peft adapter
files
### Bug fixes (included in this PR)
1. **`launch_train.sh` case pattern ordering**: the broad glob
`--eagle_base_lora*` appeared before the more specific patterns
(`--eagle_base_lora_rank*`, etc.), so it matched first and silently
swallowed the LoRA-specific args.
2. **LoRA optimizer exclusion during warmup**: because LoRA params were
frozen when the optimizer was built, warmup freezing excluded them from
the optimizer entirely; fixed by registering them with `add_param_group`
in the callback once warmup ends.
3. **`merge_lora.py` config.json**: `save_pretrained()` with
transformers >=5.x rewrites `rope_theta` → `rope_parameters`, breaking
vLLM positional embeddings. Fixed by copying the original base model
config.
4. **Multinode `dp_shard_size`**: used local GPU count instead of
`WORLD_SIZE`.
### Checklist
- [x] Backward compatible (all new config fields have defaults)
- [x] Uses `peft` via lazy imports (no hard dependency)
- [x] Unit tests added
- [x] Online HF training only (`eagle_offline=True` blocked)
---------
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>