Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def make_speculative_data_module(
train_len=None,
answer_only_loss=False,
shift_labels=True,
seed: int = 0,
) -> dict:
"""Create data module for speculative decoding training.

Expand All @@ -74,7 +75,36 @@ def make_speculative_data_module(
chat_template = f.read()
print_rank_0(f"Loaded chat template from {template_path}")

if data_args.offline_data_path is None:
mode = getattr(data_args, "mode", "online")
if mode == "streaming":
# ``train_len`` right-truncates during tokenization and is also the collator's
# pad target; caller must ensure ``train_len <= vllm.max_model_len``.
print_rank_0(f"Streaming hidden states from {data_args.streaming_server_url}")
from modelopt.torch.speculative.plugins.hf_streaming_dataset import (
EagleVllmStreamingConfig,
EagleVllmStreamingDataset,
)

ds = load_dataset("json", data_files=data_args.data_path, split="train")
if data_args.sample_size > 0:
ds = ds.select(range(data_args.sample_size))
streaming_cfg = EagleVllmStreamingConfig(
server_url=data_args.streaming_server_url,
model=data_args.streaming_model_name,
shared_storage_root=data_args.streaming_shared_storage_path,
max_seq_len=train_len,
answer_only_loss=answer_only_loss,
prefetch=data_args.streaming_prefetch,
seed=seed,
)
train_dataset = EagleVllmStreamingDataset(
entries=ds,
tokenizer=tokenizer,
config=streaming_cfg,
)
data_collator = EagleOfflineDataCollator(train_len=train_len)

elif mode == "online":
train_dataset = ShardedDataset("json", data_files=data_args.data_path)

if not data_args.vlm_processor:
Expand Down
16 changes: 11 additions & 5 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ def train():
training_args = HfTrainingArguments(**recipe.training.model_dump())
init_distributed_env(training_args)

if not dry_run and not recipe.data.data_path and not recipe.data.offline_data_path:
raise ValueError(
"Either data.data_path or data.offline_data_path must be set in the config."
)
if not dry_run and recipe.data.mode in ("online", "streaming") and not recipe.data.data_path:
raise ValueError(f"data.mode={recipe.data.mode!r} requires data.data_path.")
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
Expand All @@ -181,7 +179,7 @@ def train():

checkpoint = training_args.resume_from_checkpoint or last_checkpoint

use_offline_training = recipe.data.offline_data_path is not None
use_offline_training = recipe.data.mode != "online"

if checkpoint:
with patch_transformers5_params_loading():
Expand Down Expand Up @@ -269,6 +267,7 @@ def train():
train_len=training_args.training_seq_len,
answer_only_loss=training_args.answer_only_loss,
shift_labels=not is_dflash,
seed=training_args.seed,
)

callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
Expand All @@ -278,6 +277,13 @@ def train():
and recipe.eagle.eagle_base_lora_warmup_steps > 0
):
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
if recipe.data.mode == "streaming":
# Skip-on-resume happens inside the dataset (no re-fetch from server);
# disable HF Trainer's own data skip so the offset isn't applied twice.
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback

training_args.ignore_data_skip = True
callbacks.append(StreamingResumeCallback())

trainer = EagleTrainerWithAccLog(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion modelopt/recipe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):

@model_validator(mode="after")
def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
self.eagle.eagle_offline = self.data.offline_data_path is not None
self.eagle.eagle_offline = self.data.mode != "online"
return self

@model_validator(mode="after")
Expand Down
5 changes: 3 additions & 2 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ class EagleConfig(ModeloptBaseConfig):
eagle_offline: bool = ModeloptField(
default=False,
description=(
"Whether to use detached Eagle. Derived by ModelOptEagleRecipe from "
"data.offline_data_path; not user-configurable."
"Whether the Eagle module consumes pre-computed hidden states (offline or streaming) "
"instead of running the base model in-process. Derived by ModelOptEagleRecipe from "
"``data.mode``; not user-configurable."
),
)

Expand Down
Loading
Loading