Skip to content

Commit f208109

Browse files
committed
[Feat]: Offline DFlash training
- Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent f488231 commit f208109

5 files changed

Lines changed: 77 additions & 27 deletions

File tree

examples/speculative_decoding/main.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
import modelopt.torch.opt as mto
5151
import modelopt.torch.speculative as mtsp
52-
from modelopt.torch.speculative.config import EagleConfig
52+
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
5353
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5454
from modelopt.torch.utils import print_rank_0
5555

@@ -303,18 +303,9 @@ def train():
303303
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
304304
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
305305
elif training_args.mode == "dflash":
306-
# Auto-detect mask_token_id from tokenizer if not set
307-
if not dflash_cfg.get("dflash_mask_token_id"):
308-
if tokenizer.mask_token_id is not None:
309-
dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
310-
print_rank_0(
311-
f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
312-
)
313-
else:
314-
raise ValueError(
315-
"mask_token_id not found in tokenizer and not set in config. "
316-
"Set dflash.dflash_mask_token_id in the training YAML."
317-
)
306+
dflash_cfg = DFlashConfig.model_validate(
307+
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
308+
).model_dump()
318309
mtsp.convert(model, [("dflash", dflash_cfg)])
319310
else:
320311
raise Exception(f"{training_args.mode} is not supported!")

modelopt/torch/speculative/config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _get_dflash_default_config():
6767
class DFlashConfig(ModeloptBaseConfig):
6868
"""DFlash config for block-wise parallel speculative decoding."""
6969

70+
dflash_offline: bool = ModeloptField(
71+
default=False,
72+
description="Whether to use detached DFlash (offline training from pre-computed hidden states).",
73+
)
74+
7075
dflash_block_size: int = ModeloptField(
7176
default=8,
7277
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
@@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
110115
description="Whether to use torch.compile on DFlash forward/loss methods.",
111116
)
112117

118+
@model_validator(mode="before")
@classmethod
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
    """Derive ``dflash_offline`` from ``data_args.offline_data_path`` in the validation context.

    NOTE(review): whenever ``data_args`` is present in the context this overwrites
    any explicitly supplied ``dflash_offline`` value — confirm that is intended.
    """
    context = info.context or {}
    data_args = context.get("data_args")
    # Nothing to derive unless we got training data args and a plain dict payload.
    if data_args is None or not isinstance(data, dict):
        return data
    data["dflash_offline"] = data_args.offline_data_path is not None
    return data
127+
128+
@model_validator(mode="before")
@classmethod
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
    """Auto-detect ``dflash_mask_token_id`` from a tokenizer passed via validation context."""
    # Only fill the value in when validating a dict that does not already carry one.
    if isinstance(data, dict) and data.get("dflash_mask_token_id") is None:
        tokenizer = (info.context or {}).get("tokenizer")
        # getattr on a missing tokenizer (None) also yields None, so one check suffices.
        mask_id = getattr(tokenizer, "mask_token_id", None)
        if mask_id is not None:
            data["dflash_mask_token_id"] = mask_id
    return data
139+
140+
@model_validator(mode="after")
def _check_mask_token_id(self) -> "DFlashConfig":
    """Ensure ``dflash_mask_token_id`` was provided or auto-resolved.

    Raises:
        ValueError: if no mask token id is available after all resolution attempts.
    """
    if self.dflash_mask_token_id is not None:
        return self
    raise ValueError(
        "dflash_mask_token_id is required. Set it in the config YAML "
        "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
        "has a mask_token_id attribute."
    )
150+
113151

114152
class MedusaConfig(ModeloptBaseConfig):
115153
"""Medusa config."""

modelopt/torch/speculative/dflash/dflash_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _setup(self):
2727

2828
def modify(self, config):
2929
"""Base DFlash Model modify function. Child class should implement the details."""
30+
self.dflash_offline = config.dflash_offline
3031
self.dflash_block_size = config.dflash_block_size
3132
self.dflash_freeze_base_model = config.dflash_freeze_base_model
3233
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor

modelopt/torch/speculative/plugins/hf_dflash.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,20 +181,17 @@ def modify(self, config):
181181
self.dflash_config.block_size = self.dflash_block_size
182182

183183
# Target layer IDs
184-
num_target_layers = base_config.num_hidden_layers
184+
num_target_layers = (
185+
base_config.num_orig_hidden_layers
186+
if self.dflash_offline
187+
else base_config.num_hidden_layers
188+
)
185189
num_draft_layers = self.dflash_config.num_hidden_layers
186190
self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
187191
self.dflash_config.target_layer_ids = self.target_layer_ids
188192

189-
# mask_token_id: set in DFlashConfig (or auto-detected by main.py from tokenizer)
190-
mask_id = config.dflash_mask_token_id
191-
if mask_id is None:
192-
raise ValueError(
193-
"dflash_mask_token_id is required. Set it in the config YAML "
194-
"(dflash.dflash_mask_token_id=TOKEN_ID) or let main.py auto-detect "
195-
"from tokenizer.mask_token_id."
196-
)
197-
self.mask_token_id = mask_id
193+
# mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
194+
self.mask_token_id = config.dflash_mask_token_id
198195
logger.info("DFlash mask_token_id: %s", self.mask_token_id)
199196

200197
# Freeze base model
@@ -207,10 +204,17 @@ def modify(self, config):
207204
self.dflash_module = DFlashModule(self.dflash_config)
208205
# Match base model dtype/device. Skip if base is on meta (during from_pretrained
209206
# restore — the model will be moved to the correct device after weight loading).
210-
base_device = next(self._base_model.layers[-1].parameters()).device
207+
if self.dflash_offline:
208+
base_device = self._base_model_lm_head.weight.device
209+
else:
210+
base_device = next(self._base_model.layers[-1].parameters()).device
211211
if base_device.type != "meta":
212212
self.dflash_module.to(self._base_model.dtype).to(base_device)
213213

214+
# Delete base model layers for offline training (save memory)
215+
if self.dflash_offline:
216+
self._base_model._modules.pop("layers")
217+
214218
self.is_quantized = False
215219
self._num_anchors = self.dflash_num_anchors
216220

@@ -465,9 +469,17 @@ def forward(
465469
)
466470

467471
# 1. Run base model → extract target hidden states
468-
base_outputs = self._dflash_base_model_forward(
469-
input_ids, attention_mask, freeze=self.dflash_freeze_base_model
470-
)
472+
if self.dflash_offline:
473+
assert "base_model_outputs" in kwargs
474+
base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
475+
if base_outputs.logits is None and self.dflash_self_logit_distillation:
476+
# Compute logits from last-layer hidden states for KD loss
477+
out_hiddens = kwargs["base_model_outputs"].get("base_model_hidden_states")
478+
base_outputs.logits = self._base_model_lm_head(out_hiddens)
479+
else:
480+
base_outputs = self._dflash_base_model_forward(
481+
input_ids, attention_mask, freeze=self.dflash_freeze_base_model
482+
)
471483

472484
# 2. Build loss mask.
473485
# When labels are provided (answer_only_loss), they already encode both

modelopt/torch/speculative/plugins/modeling_dflash.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ class DFlashBaseModelOutput:
4242
target_hidden: torch.Tensor # concatenated hidden states from target layers [B, seq, N*H]
4343
logits: torch.Tensor | None = None # base model logits [B, seq, vocab]
4444

45+
@classmethod
def from_offline_dict(cls, d: dict):
    """Construct from a dict of pre-computed base model outputs (offline training).

    Args:
        d: mapping of pre-computed tensors. Must contain ``"aux_hidden_states"``
            (becomes ``target_hidden``); ``"base_model_logits"`` is optional and
            defaults to ``None`` (callers may compute logits lazily).

    Returns:
        A ``DFlashBaseModelOutput`` built from the stored tensors.

    Raises:
        KeyError: if ``"aux_hidden_states"`` is missing — ``target_hidden`` is a
            required field, so fail loudly here instead of silently producing an
            output with ``target_hidden=None`` that crashes confusingly later.
    """
    target_hidden = d.get("aux_hidden_states")
    if target_hidden is None:
        raise KeyError("Offline base model outputs must contain 'aux_hidden_states'.")
    return cls(
        target_hidden=target_hidden,
        logits=d.get("base_model_logits"),
    )
52+
4553

4654
def build_target_layer_ids(num_target_layers, num_draft_layers):
4755
"""Select layers uniformly from the target model for feature extraction."""

0 commit comments

Comments
 (0)