Skip to content

Commit cadfba6

Browse files
committed
squash: streaming dataset
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 40a4dd3 commit cadfba6

12 files changed

Lines changed: 1434 additions & 11 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def make_speculative_data_module(
5959
train_len=None,
6060
answer_only_loss=False,
6161
shift_labels=True,
62+
seed: int = 0,
6263
) -> dict:
6364
"""Create data module for speculative decoding training.
6465
@@ -74,7 +75,36 @@ def make_speculative_data_module(
7475
chat_template = f.read()
7576
print_rank_0(f"Loaded chat template from {template_path}")
7677

77-
if data_args.offline_data_path is None:
78+
mode = getattr(data_args, "mode", "online")
79+
if mode == "streaming":
80+
# ``train_len`` is both the right-truncation length used during tokenization and the
81+
# collator's pad target; the caller must ensure ``train_len <= vllm.max_model_len``.
82+
print_rank_0(f"Streaming hidden states from {data_args.streaming_server_url}")
83+
from modelopt.torch.speculative.plugins.hf_streaming_dataset import (
84+
EagleVllmStreamingConfig,
85+
EagleVllmStreamingDataset,
86+
)
87+
88+
ds = load_dataset("json", data_files=data_args.data_path, split="train")
89+
if data_args.sample_size > 0:
90+
ds = ds.select(range(data_args.sample_size))
91+
streaming_cfg = EagleVllmStreamingConfig(
92+
server_url=data_args.streaming_server_url,
93+
model=data_args.streaming_model_name,
94+
shared_storage_root=data_args.streaming_shared_storage_path,
95+
max_seq_len=train_len,
96+
answer_only_loss=answer_only_loss,
97+
prefetch=data_args.streaming_prefetch,
98+
seed=seed,
99+
)
100+
train_dataset = EagleVllmStreamingDataset(
101+
entries=ds,
102+
tokenizer=tokenizer,
103+
config=streaming_cfg,
104+
)
105+
data_collator = EagleOfflineDataCollator(train_len=train_len)
106+
107+
elif mode == "online":
78108
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
79109

80110
if not data_args.vlm_processor:

examples/speculative_decoding/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,8 @@ def train():
159159
training_args = HfTrainingArguments(**recipe.training.model_dump())
160160
init_distributed_env(training_args)
161161

162-
if not dry_run and not recipe.data.data_path and not recipe.data.offline_data_path:
163-
raise ValueError(
164-
"Either data.data_path or data.offline_data_path must be set in the config."
165-
)
162+
if not dry_run and recipe.data.mode in ("online", "streaming") and not recipe.data.data_path:
163+
raise ValueError(f"data.mode={recipe.data.mode!r} requires data.data_path.")
166164
if training_args.cp_size > 1:
167165
patch_ring_attention_for_ttt()
168166
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
@@ -181,7 +179,7 @@ def train():
181179

182180
checkpoint = training_args.resume_from_checkpoint or last_checkpoint
183181

184-
use_offline_training = recipe.data.offline_data_path is not None
182+
use_offline_training = recipe.data.mode != "online"
185183

186184
if checkpoint:
187185
with patch_transformers5_params_loading():
@@ -269,6 +267,7 @@ def train():
269267
train_len=training_args.training_seq_len,
270268
answer_only_loss=training_args.answer_only_loss,
271269
shift_labels=not is_dflash,
270+
seed=training_args.seed,
272271
)
273272

274273
callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
@@ -278,6 +277,13 @@ def train():
278277
and recipe.eagle.eagle_base_lora_warmup_steps > 0
279278
):
280279
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
280+
if recipe.data.mode == "streaming":
281+
# Skip-on-resume happens inside the dataset (no re-fetch from server);
282+
# disable HF Trainer's own data skip so the offset isn't applied twice.
283+
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback
284+
285+
training_args.ignore_data_skip = True
286+
callbacks.append(StreamingResumeCallback())
281287

282288
trainer = EagleTrainerWithAccLog(
283289
model=model,

modelopt/recipe/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):
149149

150150
@model_validator(mode="after")
151151
def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
152-
self.eagle.eagle_offline = self.data.offline_data_path is not None
152+
self.eagle.eagle_offline = self.data.mode != "online"
153153
return self
154154

155155
@model_validator(mode="after")

modelopt/torch/speculative/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,9 @@ class EagleConfig(ModeloptBaseConfig):
139139
eagle_offline: bool = ModeloptField(
140140
default=False,
141141
description=(
142-
"Whether to use detached Eagle. Derived by ModelOptEagleRecipe from "
143-
"data.offline_data_path; not user-configurable."
142+
"Whether the Eagle module consumes pre-computed hidden states (offline or streaming) "
143+
"instead of running the base model in-process. Derived by ModelOptEagleRecipe from "
144+
"``data.mode``; not user-configurable."
144145
),
145146
)
146147

0 commit comments

Comments
 (0)