Skip to content

Commit 56f459f

Browse files
yeyu-nvidia and claude committed
Address PR review feedback for eagle_base_lora feature
- Move peft imports (LoraConfig, inject_adapter_in_model, LoraLayer) inside the methods that use them (_inject_base_lora, _set_base_lora_enabled) so peft is not a hard top-level dependency for all speculative decoding users
- Change eagle_base_lora_target_modules default from [] to None to avoid mutable default shared across config instances
- Tighten LoRA key filtering from "lora_A" in k to ".lora_A." in k to avoid false positives, and add fail-fast RuntimeError when no LoRA tensors found

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 4087c80 commit 56f459f

3 files changed

Lines changed: 14 additions & 7 deletions

File tree

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,12 @@ def _export_lora(self, export_dir: Path, full_sd: dict):
195195
"""Export base model LoRA adapter weights alongside the eagle module artifacts."""
196196
from peft import LoraConfig
197197

198-
lora_sd = {k: v for k, v in full_sd.items() if "lora_A" in k or "lora_B" in k}
198+
lora_sd = {k: v for k, v in full_sd.items() if ".lora_A." in k or ".lora_B." in k}
199+
if not lora_sd:
200+
raise RuntimeError(
201+
"No LoRA adapter tensors found in the model state dict. "
202+
"Ensure eagle_base_lora=True and the model was converted with LoRA adapters."
203+
)
199204
save_file(lora_sd, export_dir / "lora_adapter_model.safetensors")
200205

201206
lora_config = LoraConfig(

modelopt/torch/speculative/config.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -129,11 +129,11 @@ class EagleConfig(ModeloptBaseConfig):
129129
description="LoRA alpha (scaling) for the base model adapters.",
130130
)
131131

132-
eagle_base_lora_target_modules: list = ModeloptField(
133-
default=[],
132+
eagle_base_lora_target_modules: list | None = ModeloptField(
133+
default=None,
134134
description=(
135135
"List of module name patterns to apply LoRA to in the base model "
136-
"(e.g. ['q_proj', 'v_proj']). Empty list uses peft defaults."
136+
"(e.g. ['q_proj', 'v_proj']). None uses peft defaults."
137137
),
138138
)
139139

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,6 @@
3737
import torch
3838
import transformers
3939
from packaging.version import Version
40-
from peft import LoraConfig
41-
from peft.mapping import inject_adapter_in_model
42-
from peft.tuners.lora import LoraLayer
4340
from torch import nn
4441
from torch.nn import CrossEntropyLoss
4542
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
@@ -552,6 +549,9 @@ def _get_eagle_device(self):
552549

553550
def _inject_base_lora(self):
554551
"""Inject HF PEFT LoRA adapters into the base model in-place and unfreeze them."""
552+
from peft import LoraConfig
553+
from peft.mapping import inject_adapter_in_model
554+
555555
target_modules = self.eagle_base_lora_target_modules or None
556556
lora_config = LoraConfig(
557557
r=self.eagle_base_lora_rank,
@@ -567,6 +567,8 @@ def _inject_base_lora(self):
567567

568568
def _set_base_lora_enabled(self, enabled: bool) -> None:
569569
"""Enable or disable LoRA adapters in the base model."""
570+
from peft.tuners.lora import LoraLayer
571+
570572
for module in self._base_model.modules():
571573
if isinstance(module, LoraLayer):
572574
module.enable_adapters(enabled)

0 commit comments

Comments
 (0)