
Commit 9ae5302

revert behavior change
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent f488231 commit 9ae5302

3 files changed

Lines changed: 43 additions & 66 deletions


modelopt/torch/speculative/plugins/hf_dflash.py

Lines changed: 34 additions & 52 deletions
@@ -50,25 +50,18 @@
 lazy rope pattern needed for MLA models.
 """
 
-import contextlib
 import logging
 
 import torch
 import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config
 from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput
 
 from ..dflash.conversion import DFlashDMRegistry
 from ..dflash.dflash_model import DFlashModel
-from .modeling_dflash import (  # noqa: F401
-    DFlashAttention,
-    DFlashBaseModelOutput,
-    DFlashModule,
-    build_target_layer_ids,
-)
+from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids  # noqa: F401
 from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
 
 logger = logging.getLogger(__name__)
@@ -121,25 +114,6 @@ def _find_base_model_parts(self):
             else:
                 raise ValueError(f"Part {name} not found in model")
 
-    def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs):
-        """Run the base model forward pass with optional freeze and base-model loss."""
-        ctx = torch.no_grad() if freeze else contextlib.nullcontext()
-        with ctx:
-            outputs = super().forward(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                output_hidden_states=True,
-                **kwargs,
-            )
-        base_loss = None
-        if not freeze and labels is not None:
-            loss_fct = CrossEntropyLoss()
-            base_loss = loss_fct(
-                outputs.logits.view(-1, outputs.logits.shape[-1]),
-                labels.view(-1),
-            )
-        return outputs, base_loss
-
     def modify(self, config):
         """Initialize DFlash draft module."""
         super().modify(config)
@@ -406,16 +380,6 @@ def _compute_loss(
 
         return loss, accuracy
 
-    def _dflash_base_model_forward(
-        self, input_ids, attention_mask, freeze=True
-    ) -> DFlashBaseModelOutput:
-        """Run base model and extract target hidden states for DFlash."""
-        outputs, _ = self._base_model_forward(input_ids, attention_mask, freeze=freeze)
-        # hidden_states[0] is the embedding output; layer i output is at index i+1
-        selected = [outputs.hidden_states[lid + 1] for lid in self.target_layer_ids]
-        target_hidden = torch.cat(selected, dim=-1)
-        return DFlashBaseModelOutput(target_hidden=target_hidden, logits=outputs.logits)
-
     def forward(
         self,
         input_ids=None,
@@ -464,10 +428,18 @@ def forward(
                 f"Adjust training_seq_len or use padding."
             )
 
-        # 1. Run base model → extract target hidden states
-        base_outputs = self._dflash_base_model_forward(
-            input_ids, attention_mask, freeze=self.dflash_freeze_base_model
-        )
+        # 1. Run base model → hidden states
+        # TODO: For co-training the base model, remove no_grad and eval() switch.
+        with torch.no_grad():
+            base_outputs = super().forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+            )
+
+        offset = 1
+        selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids]
+        target_hidden = torch.cat(selected, dim=-1)  # [B, seq, num_layers * H]
 
         # 2. Build loss mask.
         # When labels are provided (answer_only_loss), they already encode both
@@ -497,18 +469,13 @@ def forward(
             )
         full_pos = self._build_position_ids(seq_len, anchor_positions, device)
         attn_mask = self._build_draft_attention_mask(
-            seq_len,
-            anchor_positions,
-            block_keep_mask,
-            n_blocks,
-            base_outputs.target_hidden.dtype,
-            device,
+            seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device
         )
 
         # 5. Draft forward
         hidden = self.dflash_module(
             noise_embedding=noise_embedding,
-            target_hidden=base_outputs.target_hidden,
+            target_hidden=target_hidden,
             position_ids=full_pos,
             attention_mask=attn_mask,
         )
@@ -582,14 +549,29 @@ def pseudo_speculative_generate(self, input_ids, steps=1):
             base_token: Next token from base model [B, 1].
             draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1.
         """
-        base_outputs = self._dflash_base_model_forward(input_ids, attention_mask=None, freeze=True)
-        assert base_outputs.logits is not None
-        base_token = base_outputs.logits[:, -1:, :].argmax(dim=-1).to(input_ids.device)
+        # Call the base model's inner model directly (avoids DynamicModule dispatch)
+        model_output = self._base_model(
+            input_ids=input_ids,
+            output_hidden_states=True,
+        )
+        # Compute logits via lm_head
+        base_logits = self._base_model_lm_head(model_output.last_hidden_state)
+        # Build output with hidden_states
+        base_outputs = ModelOutput(
+            logits=base_logits,
+            hidden_states=model_output.hidden_states,
+        )
+        base_logits = base_outputs.logits
+        base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device)
 
         if steps < 1:
             return base_token, None
 
-        target_hidden = base_outputs.target_hidden
+        # Extract target hidden states (raw, before FC projection)
+        hid_offset = 1
+        selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
+        target_hidden = torch.cat(selected, dim=-1)
+
         block_size = self.dflash_block_size
         bsz = input_ids.shape[0]
         device = input_ids.device
modelopt/torch/speculative/plugins/hf_eagle.py

Lines changed: 8 additions & 3 deletions
@@ -85,20 +85,25 @@ def _nvtx_range(self, name):
 
     def _find_base_model_parts(self):
         """Find model parts from different models and set base_{part}_path attributes."""
-        for name, paths in {
+        base_model_parts_mapping = {
             "base_model_path": _BASE_MODEL_PATHS,
             "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
             "base_model_lm_head_path": _LM_HEAD_PATHS,
-        }.items():
+        }
+
+        for name, paths in base_model_parts_mapping.items():
+            found_submodule = False
             for path in paths:
                 try:
                     submodule = self.get_submodule(path)
                     assert isinstance(submodule, torch.nn.Module)
+                    print(f"Found {name} at {path}")
+                    found_submodule = True
                     setattr(self, name, path)
                     break
                 except Exception:
                     continue
-            else:
+            if not found_submodule:
                 raise ValueError(f"Part {name} not found in model")
 
     def _activate_torch_compile(self):
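
The hf_eagle.py hunk above swaps Python's for/else clause for an explicit found_submodule flag (plus a progress print). The two forms behave the same when every failed lookup raises and a successful one breaks out of the loop; here is a small self-contained sketch, with a hypothetical lookup() standing in for self.get_submodule().

# Minimal sketch (hypothetical lookup(), not ModelOpt code) of the two equivalent patterns.
paths = ["model.missing", "model.layers"]

def lookup(path):
    # stand-in for self.get_submodule(path); raises when the path is absent
    if path != "model.layers":
        raise AttributeError(path)
    return object()

# for/else form (before this commit): the else branch runs only if the loop never breaks.
for path in paths:
    try:
        submodule = lookup(path)
        break
    except Exception:
        continue
else:
    raise ValueError("part not found")

# explicit-flag form (after this commit): same outcome, with room for the extra print.
found_submodule = False
for path in paths:
    try:
        submodule = lookup(path)
        found_submodule = True
        break
    except Exception:
        continue
if not found_submodule:
    raise ValueError("part not found")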

modelopt/torch/speculative/plugins/modeling_dflash.py

Lines changed: 1 addition & 11 deletions
@@ -20,8 +20,6 @@
 The draft architecture is independent of the target model.
 """
 
-from dataclasses import dataclass
-
 import torch
 from torch import nn
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -32,15 +30,7 @@
 )
 from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half
 
-__all__ = ["DFlashBaseModelOutput", "DFlashModule", "build_target_layer_ids"]
-
-
-@dataclass
-class DFlashBaseModelOutput:
-    """Output container for base model forward pass in DFlash training."""
-
-    target_hidden: torch.Tensor  # concatenated hidden states from target layers [B, seq, N*H]
-    logits: torch.Tensor | None = None  # base model logits [B, seq, vocab]
+__all__ = ["DFlashModule", "build_target_layer_ids"]
 
 
 def build_target_layer_ids(num_target_layers, num_draft_layers):
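
The deleted DFlashBaseModelOutput above was a plain dataclass container for the concatenated target hidden states and the base logits. For reference, a standalone sketch of that pattern with the same two fields; the class name here is made up, and only the field layout mirrors the removed code.

# Standalone sketch of the removed dataclass-container pattern (hypothetical name).
from dataclasses import dataclass

import torch


@dataclass
class BaseOutputSketch:
    target_hidden: torch.Tensor  # concatenated hidden states from target layers [B, seq, N*H]
    logits: torch.Tensor | None = None  # base model logits [B, seq, vocab]


out = BaseOutputSketch(target_hidden=torch.zeros(2, 4, 8))
print(out.target_hidden.shape, out.logits)  # torch.Size([2, 4, 8]) None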
