Skip to content

Commit e5c46c7

Browse files
committed
refactor: interface
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 864dc89 commit e5c46c7

6 files changed

Lines changed: 40 additions & 21 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ def make_speculative_data_module(
7474
chat_template = f.read()
7575
print_rank_0(f"Loaded chat template from {template_path}")
7676

77-
streaming_url = getattr(data_args, "streaming_server_url", None)
78-
if streaming_url is not None:
79-
# Streaming: trainer is a client of a running vllm serve
80-
print_rank_0(f"Streaming hidden states from {streaming_url}")
77+
mode = getattr(data_args, "mode", "online")
78+
if mode == "streaming":
79+
# Trainer is an HTTP client of a running vllm serve; samples stream in lazily.
80+
print_rank_0(f"Streaming hidden states from {data_args.streaming_server_url}")
8181
from modelopt.torch.speculative.plugins.hf_streaming_dataset import (
8282
StreamingHiddenStatesDataset,
8383
)
@@ -89,15 +89,15 @@ def make_speculative_data_module(
8989
train_dataset = StreamingHiddenStatesDataset(
9090
entries=entries,
9191
tokenizer=tokenizer,
92-
server_url=streaming_url,
92+
server_url=data_args.streaming_server_url,
9393
model=data_args.streaming_model_name,
9494
max_seq_len=data_args.streaming_max_seq_len,
9595
answer_only_loss=answer_only_loss,
9696
prefetch=data_args.streaming_prefetch,
9797
)
9898
data_collator = EagleOfflineDataCollator(train_len=train_len)
9999

100-
elif data_args.offline_data_path is None:
100+
elif mode == "online":
101101
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
102102

103103
if not data_args.vlm_processor:

examples/speculative_decoding/main.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ def train():
152152
training_args = HfTrainingArguments(**recipe.training.model_dump())
153153
init_distributed_env(training_args)
154154

155-
if not recipe.data.data_path and not recipe.data.offline_data_path:
156-
raise ValueError(
157-
"Either data.data_path or data.offline_data_path must be set in the config."
158-
)
155+
if recipe.data.mode in ("online", "streaming") and not recipe.data.data_path:
156+
raise ValueError(f"data.mode={recipe.data.mode!r} requires data.data_path.")
159157
if training_args.cp_size > 1:
160158
patch_ring_attention_for_ttt()
161159
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
@@ -174,10 +172,7 @@ def train():
174172

175173
checkpoint = training_args.resume_from_checkpoint or last_checkpoint
176174

177-
use_offline_training = (
178-
recipe.data.offline_data_path is not None
179-
or recipe.data.streaming_server_url is not None
180-
)
175+
use_offline_training = recipe.data.mode != "online"
181176

182177
if checkpoint:
183178
with patch_transformers5_params_loading():

modelopt/recipe/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,7 @@ class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):
156156

157157
@model_validator(mode="after")
158158
def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
159-
self.eagle.eagle_offline = (
160-
self.data.offline_data_path is not None or self.data.streaming_server_url is not None
161-
)
159+
self.eagle.eagle_offline = self.data.mode != "online"
162160
return self
163161

164162
@model_validator(mode="after")

modelopt/torch/speculative/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,9 @@ class EagleConfig(ModeloptBaseConfig):
139139
eagle_offline: bool = ModeloptField(
140140
default=False,
141141
description=(
142-
"Whether to use detached Eagle. Derived by ModelOptEagleRecipe from "
143-
"data.offline_data_path; not user-configurable."
142+
"Whether the Eagle module consumes pre-computed hidden states (offline or streaming) "
143+
"instead of running the base model in-process. Derived by ModelOptEagleRecipe from "
144+
"``data.mode``; not user-configurable."
144145
),
145146
)
146147

modelopt/torch/speculative/plugins/hf_training_args.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030
from __future__ import annotations
3131

32-
from pydantic import BaseModel, ConfigDict, field_validator
32+
from typing import Literal
33+
34+
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
3335

3436

3537
class ModelArguments(BaseModel):
@@ -47,6 +49,7 @@ class DataArguments(BaseModel):
4749

4850
model_config = ConfigDict(extra="forbid")
4951

52+
mode: Literal["online", "offline", "streaming"] = "online"
5053
data_path: str | None = None
5154
offline_data_path: str | None = None
5255
lazy_preprocess: bool = True
@@ -67,6 +70,25 @@ def _check_sample_size(cls, v: int) -> int:
6770
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
6871
return v
6972

73+
@model_validator(mode="after")
74+
def _check_mode_requirements(self) -> DataArguments:
75+
# Backward-compat: if ``mode`` is left at its default ("online") but a mode-specific
76+
# field is set, promote to the corresponding mode. Explicit ``mode`` always wins.
77+
if self.mode == "online":
78+
if self.offline_data_path is not None:
79+
self.mode = "offline"
80+
elif self.streaming_server_url is not None:
81+
self.mode = "streaming"
82+
if self.mode == "offline" and not self.offline_data_path:
83+
raise ValueError("data.mode='offline' requires data.offline_data_path")
84+
if self.mode == "streaming" and not (
85+
self.streaming_server_url and self.streaming_model_name
86+
):
87+
raise ValueError(
88+
"data.mode='streaming' requires data.streaming_server_url and data.streaming_model_name"
89+
)
90+
return self
91+
7092

7193
class TrainingArguments(BaseModel):
7294
"""Speculative-decoding extensions on top of ``transformers.TrainingArguments``.

modelopt_recipes/general/speculative_decoding/eagle3.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ model:
1313

1414
# maps to DataArguments (main.py)
1515
data:
16+
# online | offline | streaming. Set offline_data_path for offline; set
17+
# streaming_server_url + streaming_model_name for streaming.
18+
mode: online
1619
data_path: input_conversations/train.jsonl
1720
offline_data_path:
1821
draft_vocab_cache:
@@ -48,7 +51,7 @@ training:
4851

4952
# maps to EagleConfig (modelopt/torch/speculative/config.py).
5053
eagle:
51-
# eagle_offline is derived from data.offline_data_path; do not set here.
54+
# eagle_offline is derived from data.mode; do not set here.
5255
eagle_decoder_type: llama
5356
eagle_ttt_steps: 3
5457
eagle_mix_hidden_states: false

0 commit comments

Comments
 (0)