
Commit 9e4eeb0

[Refactor]: HFSpecDecMixin shared across HF spec-decoding plugins
Extract duplicated base-model discovery, forward pass, NVTX profiling, and torch.compile logic from HFEagleModel / HFDFlashModel into a shared mixin (hf_spec_mixin.py). HFEagleModel and HFDFlashModel now inherit from (HFSpecDecMixin, EagleModel/DFlashModel).

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent f208109 commit 9e4eeb0

3 files changed

Lines changed: 187 additions & 132 deletions
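
The core of the change is cooperative multiple inheritance: the mixin is listed first so its helpers sit ahead of the algorithm base in the MRO, while `super().forward()` inside the mixin still resolves down the chain to the wrapped HuggingFace model. Below is a minimal, self-contained sketch of that mechanism; all class names in it are illustrative stand-ins, not the actual modelopt classes.

# Minimal MRO sketch; names are illustrative, not the actual modelopt classes.

class BaseHFModel:
    # Stands in for the wrapped HuggingFace PreTrainedModel.
    def forward(self, x):
        return f"base({x})"

class SpecMixin:
    # Stands in for HFSpecDecMixin: it has no base model of its own, but
    # super() follows the *instance's* MRO, so this reaches BaseHFModel.forward.
    def base_forward(self, x):
        return super().forward(x)

class AlgoModel(BaseHFModel):
    # Stands in for EagleModel / DFlashModel.
    pass

class HFAlgoModel(SpecMixin, AlgoModel):
    # Mixin listed first, mirroring HFEagleModel(HFSpecDecMixin, EagleModel).
    pass

print(HFAlgoModel().base_forward("ids"))  # -> base(ids)
print([c.__name__ for c in HFAlgoModel.__mro__])
# ['HFAlgoModel', 'SpecMixin', 'AlgoModel', 'BaseHFModel', 'object']

Because `super()` is resolved against the instance's MRO rather than a static base class, the shared `_base_model_forward` in the real mixin can invoke the HF `forward` without knowing which algorithm base it is mixed into.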


modelopt/torch/speculative/plugins/hf_dflash.py

Lines changed: 2 additions & 64 deletions
@@ -50,96 +50,34 @@
 lazy rope pattern needed for MLA models.
 """
 
-import contextlib
 import logging
 
 import torch
 import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config
 from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput
 
 from ..dflash.conversion import DFlashDMRegistry
 from ..dflash.dflash_model import DFlashModel
+from .hf_spec_mixin import HFSpecDecMixin
 from .modeling_dflash import (  # noqa: F401
     DFlashAttention,
     DFlashBaseModelOutput,
     DFlashModule,
     build_target_layer_ids,
 )
-from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
 
 logger = logging.getLogger(__name__)
 
 __all__ = ["HFDFlashModel"]
 
 
 @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
-class HFDFlashModel(DFlashModel):
+class HFDFlashModel(HFSpecDecMixin, DFlashModel):
     """DFlash Model for HuggingFace transformers."""
 
-    @property
-    def _base_model(self):
-        return self.get_submodule(self.base_model_path)
-
-    @property
-    def _base_model_embeddings(self):
-        return self.get_submodule(self.base_model_embeddings_path)
-
-    @property
-    def _base_model_lm_head(self):
-        return self.get_submodule(self.base_model_lm_head_path)
-
-    @property
-    def _base_llm_config(self):
-        return (
-            getattr(self.config, "text_config", None)
-            or getattr(self.config, "llm_config", None)
-            or self.config
-        )
-
-    def _find_base_model_parts(self):
-        """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.
-
-        Reuses the shared path constants from modeling_fakebase (same as EAGLE).
-        """
-        for name, paths in {
-            "base_model_path": _BASE_MODEL_PATHS,
-            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
-            "base_model_lm_head_path": _LM_HEAD_PATHS,
-        }.items():
-            for path in paths:
-                try:
-                    submodule = self.get_submodule(path)
-                    assert isinstance(submodule, torch.nn.Module)
-                    setattr(self, name, path)
-                    break
-                except Exception:
-                    continue
-            else:
-                raise ValueError(f"Part {name} not found in model")
-
-    def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs):
-        """Run the base model forward pass with optional freeze and base-model loss."""
-        ctx = torch.no_grad() if freeze else contextlib.nullcontext()
-        with ctx:
-            outputs = super().forward(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                output_hidden_states=True,
-                **kwargs,
-            )
-        base_loss = None
-        if not freeze and labels is not None:
-            loss_fct = CrossEntropyLoss()
-            base_loss = loss_fct(
-                outputs.logits.view(-1, outputs.logits.shape[-1]),
-                labels.view(-1),
-            )
-        return outputs, base_loss
-
     def modify(self, config):
         """Initialize DFlash draft module."""
         super().modify(config)
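
The `_find_base_model_parts` logic removed above (and re-homed in the mixin) leans on Python's `for`/`else`: the `else` clause runs only if the loop finishes without hitting `break`. Here is a standalone sketch of the probing idiom; the helper name and candidate paths are hypothetical, and it catches only `AttributeError` where the real code catches `Exception` broadly.

# Standalone sketch of the path-probing idiom (hypothetical helper and paths).
import torch

def find_submodule_path(model: torch.nn.Module, candidates: list[str]) -> str:
    """Return the first candidate dotted path that resolves to a submodule."""
    for path in candidates:
        try:
            model.get_submodule(path)  # raises AttributeError for a missing path
            found = path
            break
        except AttributeError:
            continue
    else:
        # for/else: runs only if the loop never hit `break`.
        raise ValueError(f"none of {candidates} found in model")
    return found

wrapper = torch.nn.Module()
wrapper.model = torch.nn.Linear(4, 4)
print(find_submodule_path(wrapper, ["transformer", "model"]))  # -> model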

modelopt/torch/speculative/plugins/hf_eagle.py

Lines changed: 10 additions & 68 deletions
@@ -36,8 +36,8 @@
     get_ttt_msk_func,
     temporary_set_config_value,
 )
+from .hf_spec_mixin import HFSpecDecMixin
 from .modeling_eagle import EagleBaseModelOutput, EagleModule
-from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
 
 __all__ = ["HFARValidation", "HFEagleModel"]
 
@@ -47,75 +47,14 @@
 
 
 @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
-class HFEagleModel(EagleModel):
+class HFEagleModel(HFSpecDecMixin, EagleModel):
     """Eagle Model Class for huggingface models."""
 
-    @property
-    def _base_model(self):
-        return self.get_submodule(self.base_model_path)
-
-    @property
-    def _base_model_embeddings(self):
-        return self.get_submodule(self.base_model_embeddings_path)
-
-    @property
-    def _base_model_lm_head(self):
-        return self.get_submodule(self.base_model_lm_head_path)
-
-    @property
-    def _base_llm_config(self):
-        """Return the llm config for the base model, from LLM or VLM."""
-        return (
-            getattr(self.config, "text_config", None)
-            or getattr(self.config, "llm_config", None)
-            or self.config
-        )
-
-    def _nvtx_range(self, name):
-        """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set."""
-        if not self.eagle_enable_nvtx:
-            return contextlib.nullcontext()
-        try:
-            import torch.cuda.nvtx as nvtx
-
-            return nvtx.range(name)
-        except Exception as e:
-            print(f"Failed to create NVTX range {name}: {e}")
-            return contextlib.nullcontext()
-
-    def _find_base_model_parts(self):
-        """Find model parts from different models and set base_{part}_path attributes."""
-        for name, paths in {
-            "base_model_path": _BASE_MODEL_PATHS,
-            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
-            "base_model_lm_head_path": _LM_HEAD_PATHS,
-        }.items():
-            for path in paths:
-                try:
-                    submodule = self.get_submodule(path)
-                    assert isinstance(submodule, torch.nn.Module)
-                    setattr(self, name, path)
-                    break
-                except Exception:
-                    continue
-            else:
-                raise ValueError(f"Part {name} not found in model")
-
-    def _activate_torch_compile(self):
-        import torch._dynamo
-
-        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
-
-        compile_targets = [
-            ("_prepare_eagle_inputs", {}),
-            ("_eagle_forward", {"mode": "max-autotune"}),
-            ("_eagle_loss", {"fullgraph": True}),
-        ]
-        for name, kwargs in compile_targets:
-            try:
-                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
-            except Exception:  # noqa: PERF203
-                print(f"Disabling torch.compile for {name} due to compilation error.")
+    _compile_targets = [
+        ("_prepare_eagle_inputs", {}),
+        ("_eagle_forward", {"mode": "max-autotune"}),
+        ("_eagle_loss", {"fullgraph": True}),
+    ]
 
     def get_dummy_inputs(self) -> dict:
         """Construct dummy inputs for export forward pass."""
 
@@ -285,6 +224,9 @@ def modify(
         if self.eagle_config._attn_implementation is None:
            self.eagle_config._attn_implementation = "sdpa"
 
+        # Mixin interface attribute
+        self._enable_nvtx = self.eagle_enable_nvtx
+
         # Set default aux_hidden_state layers
         if (
             self.eagle_config.use_aux_hidden_state
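
After this refactor, `HFEagleModel` no longer implements `_activate_torch_compile` itself; it only declares `_compile_targets` and the mixin walks that list. A rough sketch of the declarative pattern follows; the class and method names in it are made up for illustration.

# Sketch of the declarative compile-target pattern (hypothetical class and method).
import torch

class CompileMixin:
    _compile_targets: list[tuple[str, dict]] = []

    def activate_compile(self):
        # Wrap each named method with torch.compile; fall back to eager on failure.
        for name, kwargs in self._compile_targets:
            try:
                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
            except Exception:
                print(f"skipping torch.compile for {name}")

class TinyModel(CompileMixin, torch.nn.Module):
    _compile_targets = [("double", {})]  # subclass declares what to compile

    def double(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

m = TinyModel()
m.activate_compile()
print(m.double(torch.ones(2)))  # tensor([2., 2.]), now via the compiled wrapper

Moving the target list into a class attribute keeps the compile machinery in one place while letting each plugin opt individual methods in with per-method kwargs.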
modelopt/torch/speculative/plugins/hf_spec_mixin.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared mixin for HuggingFace speculative decoding model classes."""
+
+# mypy: disable-error-code="attr-defined,misc"
+
+import contextlib
+
+import torch
+from torch.nn import CrossEntropyLoss
+
+from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
+
+__all__ = ["HFSpecDecMixin"]
+
+
+class HFSpecDecMixin:
+    """Mixin providing HuggingFace base-model discovery for speculative decoding plugins.
+
+    Provides shared properties and methods for locating base-model submodules
+    (backbone, embeddings, lm_head) and running the base-model forward pass.
+
+    Must be used with multiple inheritance alongside an algorithm-specific base
+    (EagleModel, DFlashModel, etc.) that inherits from DynamicModule.
+
+    Example::
+
+        @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
+        class HFEagleModel(HFSpecDecMixin, EagleModel): ...
+    """
+
+    # -- Class attributes (subclasses may override) --
+
+    # List of (method_name, compile_kwargs) for _activate_torch_compile().
+    # Example: [("_eagle_forward", {"mode": "max-autotune"}), ("_eagle_loss", {"fullgraph": True})]
+    _compile_targets: list[tuple[str, dict]] = []
+
+    # -- Properties: base model access --
+
+    @property
+    def _base_model(self):
+        return self.get_submodule(self.base_model_path)
+
+    @property
+    def _base_model_embeddings(self):
+        return self.get_submodule(self.base_model_embeddings_path)
+
+    @property
+    def _base_model_lm_head(self):
+        return self.get_submodule(self.base_model_lm_head_path)
+
+    @property
+    def _base_llm_config(self):
+        """Return the LLM config for the base model, handling VLM nesting."""
+        return (
+            getattr(self.config, "text_config", None)
+            or getattr(self.config, "llm_config", None)
+            or self.config
+        )
+
+    # -- Methods: model discovery --
+
+    def _find_base_model_parts(self):
+        """Find model parts from different models and set base_{part}_path attributes.
+
+        Iterates over candidate submodule paths from modeling_fakebase to locate the
+        base model backbone, embedding layer, and LM head.
+
+        Raises:
+            ValueError: If any required model part cannot be found.
+        """
+        for name, paths in {
+            "base_model_path": _BASE_MODEL_PATHS,
+            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
+            "base_model_lm_head_path": _LM_HEAD_PATHS,
+        }.items():
+            for path in paths:
+                try:
+                    submodule = self.get_submodule(path)
+                    assert isinstance(submodule, torch.nn.Module)
+                    setattr(self, name, path)
+                    break
+                except Exception:
+                    continue
+            else:
+                raise ValueError(f"Part {name} not found in model")
+
+    # -- Methods: base model forward --
+
+    def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs):
+        """Run the base model forward pass with optional freeze and base-model loss.
+
+        Args:
+            input_ids: Input token IDs.
+            attention_mask: Attention mask.
+            freeze: If True, run under torch.no_grad().
+            labels: Optional labels for computing base model CE loss.
+            **kwargs: Additional keyword arguments forwarded to the base model.
+
+        Returns:
+            (outputs, base_loss) tuple where outputs is the raw model output and
+            base_loss is the cross-entropy loss (None if freeze=True or labels=None).
+        """
+        ctx = torch.no_grad() if freeze else contextlib.nullcontext()
+        with ctx:
+            outputs = super().forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                **kwargs,
+            )
+        base_loss = None
+        if not freeze and labels is not None:
+            loss_fct = CrossEntropyLoss()
+            base_loss = loss_fct(
+                outputs.logits.view(-1, outputs.logits.shape[-1]),
+                labels.view(-1),
+            )
+        return outputs, base_loss
+
+    # -- Methods: profiling & compilation --
+
+    def _nvtx_range(self, name):
+        """Optionally create an NVTX range for profiling.
+
+        Enabled when the subclass sets ``self._enable_nvtx = True`` in ``modify()``.
+        """
+        if not getattr(self, "_enable_nvtx", False):
+            return contextlib.nullcontext()
+        try:
+            import torch.cuda.nvtx as nvtx
+
+            return nvtx.range(name)
+        except Exception as e:
+            print(f"Failed to create NVTX range {name}: {e}")
+            return contextlib.nullcontext()
+
+    def _activate_torch_compile(self):
+        """Apply ``torch.compile`` to methods listed in ``_compile_targets``.
+
+        Each entry is ``(method_name, extra_kwargs)`` passed to ``torch.compile(..., dynamic=False)``.
+        Failures fall back to eager mode silently.
+        """
+        import torch._dynamo
+
+        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+
+        for name, kwargs in self._compile_targets:
+            try:
+                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
+            except Exception:  # noqa: PERF203
+                print(f"Disabling torch.compile for {name} due to compilation error.")
+
+    # -- Methods: export interface (subclasses must override) --
+
+    def get_dummy_inputs(self) -> dict:
+        """Construct dummy inputs for export forward pass. Subclasses must override."""
+        raise NotImplementedError
+
+    def get_exporter(self):
+        """Return the exporter for the draft model. Subclasses must override."""
+        raise NotImplementedError
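
One detail worth noting in `_base_model_forward`: it toggles gradient tracking by swapping the context manager rather than branching the forward call. A tiny standalone sketch of that conditional-freeze pattern; the `net` module here is illustrative, not a modelopt object.

# Sketch of the conditional-freeze pattern from _base_model_forward.
import contextlib
import torch

net = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

for frozen in (True, False):
    # Pick the context up front so the forward call is written exactly once.
    ctx = torch.no_grad() if frozen else contextlib.nullcontext()
    with ctx:
        out = net(x)
    # Under no_grad the output is detached from the autograd graph.
    print(frozen, out.requires_grad)  # True False, then False True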
