Skip to content

Commit 89fefb1

Browse files
committed
squash: streaming dataset
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 7038dec commit 89fefb1

12 files changed

Lines changed: 1434 additions & 11 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ def make_speculative_data_module(
5959
train_len=None,
6060
answer_only_loss=False,
6161
shift_labels=True,
62+
seed: int = 0,
6263
) -> dict:
6364
"""Create data module for speculative decoding training.
6465
@@ -74,7 +75,36 @@ def make_speculative_data_module(
7475
chat_template = f.read()
7576
print_rank_0(f"Loaded chat template from {template_path}")
7677

77-
if data_args.offline_data_path is None:
78+
mode = getattr(data_args, "mode", "online")
79+
if mode == "streaming":
80+
# ``train_len`` right-truncates during tokenization and is also the collator's
81+
# pad target; caller must ensure ``train_len <= vllm.max_model_len``.
82+
print_rank_0(f"Streaming hidden states from {data_args.streaming_server_url}")
83+
from modelopt.torch.speculative.plugins.hf_streaming_dataset import (
84+
EagleVllmStreamingConfig,
85+
EagleVllmStreamingDataset,
86+
)
87+
88+
ds = load_dataset("json", data_files=data_args.data_path, split="train")
89+
if data_args.sample_size > 0:
90+
ds = ds.select(range(data_args.sample_size))
91+
streaming_cfg = EagleVllmStreamingConfig(
92+
server_url=data_args.streaming_server_url,
93+
model=data_args.streaming_model_name,
94+
shared_storage_root=data_args.streaming_shared_storage_path,
95+
max_seq_len=train_len,
96+
answer_only_loss=answer_only_loss,
97+
prefetch=data_args.streaming_prefetch,
98+
seed=seed,
99+
)
100+
train_dataset = EagleVllmStreamingDataset(
101+
entries=ds,
102+
tokenizer=tokenizer,
103+
config=streaming_cfg,
104+
)
105+
data_collator = EagleOfflineDataCollator(train_len=train_len)
106+
107+
elif mode == "online":
78108
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
79109

80110
if not data_args.vlm_processor:

examples/speculative_decoding/main.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -152,10 +152,8 @@ def train():
152152
training_args = HfTrainingArguments(**recipe.training.model_dump())
153153
init_distributed_env(training_args)
154154

155-
if not recipe.data.data_path and not recipe.data.offline_data_path:
156-
raise ValueError(
157-
"Either data.data_path or data.offline_data_path must be set in the config."
158-
)
155+
if recipe.data.mode in ("online", "streaming") and not recipe.data.data_path:
156+
raise ValueError(f"data.mode={recipe.data.mode!r} requires data.data_path.")
159157
if training_args.cp_size > 1:
160158
patch_ring_attention_for_ttt()
161159
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
@@ -174,7 +172,7 @@ def train():
174172

175173
checkpoint = training_args.resume_from_checkpoint or last_checkpoint
176174

177-
use_offline_training = recipe.data.offline_data_path is not None
175+
use_offline_training = recipe.data.mode != "online"
178176

179177
if checkpoint:
180178
with patch_transformers5_params_loading():
@@ -249,6 +247,7 @@ def train():
249247
train_len=training_args.training_seq_len,
250248
answer_only_loss=training_args.answer_only_loss,
251249
shift_labels=not is_dflash,
250+
seed=training_args.seed,
252251
)
253252

254253
callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
@@ -258,6 +257,13 @@ def train():
258257
and recipe.eagle.eagle_base_lora_warmup_steps > 0
259258
):
260259
callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps))
260+
if recipe.data.mode == "streaming":
261+
# Skip-on-resume happens inside the dataset (no re-fetch from server);
262+
# disable HF Trainer's own data skip so the offset isn't applied twice.
263+
from modelopt.torch.speculative.plugins.hf_streaming_dataset import StreamingResumeCallback
264+
265+
training_args.ignore_data_skip = True
266+
callbacks.append(StreamingResumeCallback())
261267

262268
trainer = EagleTrainerWithAccLog(
263269
model=model,

modelopt/recipe/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):
156156

157157
@model_validator(mode="after")
158158
def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
159-
self.eagle.eagle_offline = self.data.offline_data_path is not None
159+
self.eagle.eagle_offline = self.data.mode != "online"
160160
return self
161161

162162
@model_validator(mode="after")

modelopt/torch/speculative/config.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,9 @@ class EagleConfig(ModeloptBaseConfig):
139139
eagle_offline: bool = ModeloptField(
140140
default=False,
141141
description=(
142-
"Whether to use detached Eagle. Derived by ModelOptEagleRecipe from "
143-
"data.offline_data_path; not user-configurable."
142+
"Whether the Eagle module consumes pre-computed hidden states (offline or streaming) "
143+
"instead of running the base model in-process. Derived by ModelOptEagleRecipe from "
144+
"``data.mode``; not user-configurable."
144145
),
145146
)
146147

0 commit comments

Comments
 (0)