Skip to content

Commit c79eb96

Browse files
committed
refactor
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent af62827 commit c79eb96

6 files changed

Lines changed: 70 additions & 75 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from datasets import load_dataset
2929
from packaging.version import Version
3030
from scripts.ar_validate import validate_ar
31-
from torch.distributed.tensor.experimental._attention import _SDPAMerger
31+
32+
# from torch.distributed.tensor.experimental._attention import _SDPAMerger
3233
from torch.utils.data import Dataset
3334
from transformers import Trainer, TrainerCallback
3435
from transformers.trainer_pt_utils import LabelSmoother
@@ -382,11 +383,11 @@ def patch_ring_attention_for_ttt():
382383
)
383384

384385
# 3. Patch merger to skip the blank shard to avoid difference in output.
385-
original_sdpa_merger_step = _SDPAMerger.step
386+
original_sdpa_merger_step = torch.distributed.tensor.experimental._attention._SDPAMerger.step
386387

387388
def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
388389
if lse.sum() <= 0:
389390
return
390391
return original_sdpa_merger_step(self, out, lse, partial)
391392

392-
_SDPAMerger.step = patched_sdpa_merger_step
393+
torch.distributed.tensor.experimental._attention._SDPAMerger.step = patched_sdpa_merger_step

examples/speculative_decoding/main.py

Lines changed: 27 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ class DataArguments:
7676
},
7777
)
7878
lazy_preprocess: bool = True
79-
draft_vocab_cache_dir: str = field(
80-
default="draft_vocab_cache",
81-
metadata={"help": "Path to the d2t cache directory."},
79+
draft_vocab_cache: str = field(
80+
default=None,
81+
metadata={"help": "Path to d2t.pt cache file."},
8282
)
8383
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
8484
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
@@ -97,7 +97,7 @@ class TrainingArguments(transformers.TrainingArguments):
9797
)
9898
dataloader_drop_last: bool = field(default=True)
9999
bf16: bool = field(default=True)
100-
mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
100+
mode: Literal["eagle3", "medusa"] = "eagle3"
101101
estimate_ar: bool = field(
102102
default=False, metadata={"help": "Whether to estimate AR during training for logging."}
103103
)
@@ -147,30 +147,35 @@ def train():
147147
training_args.parallelism_config.sp_backend = None
148148
print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}")
149149

150-
# Detecting last checkpoint.
151-
last_checkpoint = None
152-
if os.path.isdir(training_args.output_dir):
153-
last_checkpoint = get_last_checkpoint(training_args.output_dir)
150+
# Detect checkpoint to resume from
151+
last_checkpoint = (
152+
get_last_checkpoint(training_args.output_dir)
153+
if os.path.isdir(training_args.output_dir)
154+
else None
155+
)
156+
if last_checkpoint:
154157
print_rank_0(f"Last checkpoint detected: {last_checkpoint}")
155158

156-
checkpoint = None
157-
if training_args.resume_from_checkpoint is not None:
158-
checkpoint = training_args.resume_from_checkpoint
159-
elif last_checkpoint is not None:
160-
checkpoint = last_checkpoint
159+
checkpoint = training_args.resume_from_checkpoint or last_checkpoint
161160

162161
use_offline_training = data_args.offline_data_path is not None
163162

163+
model_config = transformers.AutoConfig.from_pretrained(
164+
model_args.model_name_or_path, trust_remote_code=True
165+
)
166+
if "vl" in model_config.model_type.lower():
167+
model_cls = transformers.AutoModelForVision2Seq
168+
else:
169+
model_cls = transformers.AutoModelForCausalLM
170+
164171
if checkpoint:
165-
model = transformers.AutoModelForCausalLM.from_pretrained(
166-
checkpoint, torch_dtype="auto", trust_remote_code=True
167-
)
172+
model = model_cls.from_pretrained(checkpoint, torch_dtype="auto", trust_remote_code=True)
168173
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
169174
else:
170175
# To avoid OOM for large models, we load and convert model on CPU first.
171176
# Model will be moved to GPU during HF trainer.init().
172177
offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
173-
model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
178+
model = model_cls.from_pretrained(
174179
model_args.model_name_or_path,
175180
torch_dtype="auto",
176181
device_map="cpu",
@@ -180,77 +185,38 @@ def train():
180185
if use_offline_training:
181186
# When doing offline training, we need to set num_hidden_layers
182187
# since we override it when loading the model for space savings
183-
model_config = transformers.AutoConfig.from_pretrained(
184-
model_args.model_name_or_path, trust_remote_code=True
185-
)
186188
model.config.num_orig_hidden_layers = model_config.num_hidden_layers
187189
tokenizer = transformers.AutoTokenizer.from_pretrained(
188190
model_args.model_name_or_path,
189191
model_max_length=training_args.training_seq_len,
190192
trust_remote_code=True,
191193
)
192-
if tokenizer.chat_template is None:
193-
tokenizer.chat_template = (
194-
"{%- for message in messages %}"
195-
"{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
196-
"{%- endfor %}"
197-
)
198-
if tokenizer.pad_token_id is None:
199-
tokenizer.pad_token_id = tokenizer.eos_token_id
200-
201194
if training_args.mode == "medusa":
202195
config = {
203196
"medusa_num_heads": medusa_args.medusa_num_heads,
204197
"medusa_num_layers": medusa_args.medusa_num_layers,
205198
}
206199
mtsp.convert(model, [("medusa", config)])
207-
elif training_args.mode in ["eagle1", "eagle3"]:
208-
from modelopt.torch.speculative.config import (
209-
default_eagle_config,
210-
eagle3_default_config,
211-
kimik2_eagle_default_config,
200+
elif training_args.mode == "eagle3":
201+
custom_config = (
202+
json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {}
212203
)
213204

214-
if eagle_args.eagle_decoder_type == "kimik2":
215-
eagle_architecture_config = kimik2_eagle_default_config
216-
else:
217-
eagle_architecture_config = {
218-
"eagle1": default_eagle_config,
219-
"eagle3": eagle3_default_config,
220-
}[training_args.mode]
221-
222-
if eagle_args.eagle_config:
223-
with open(eagle_args.eagle_config) as f:
224-
custom_config = json.load(f)
225-
eagle_architecture_config.update(custom_config)
226-
227205
config = {
228206
"eagle_decoder_type": eagle_args.eagle_decoder_type,
229207
"eagle_offline": use_offline_training,
230-
"eagle_architecture_config": eagle_architecture_config,
208+
"eagle_architecture_config": custom_config,
231209
}
232210

233211
mtsp.convert(model, [("eagle", config)])
234212

235-
# read draft vocab cache
236-
if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
237-
try:
238-
model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path))
239-
vocab_cache_path = os.path.join(
240-
data_args.draft_vocab_cache_dir, model_name, "d2t.pt"
241-
)
242-
vocab_cache = torch.load(vocab_cache_path)
243-
model.eagle_module.d2t = vocab_cache
244-
print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.")
245-
except Exception as e:
246-
raise e
247213
else:
248214
raise Exception(f"{training_args.mode} is not supported!")
249215

250216
print_rank_0("Loading dataset...")
251217
if training_args.mode == "medusa":
252218
data_module = make_medusa_supervised_data_module(tokenizer, data_args)
253-
elif training_args.mode in ["eagle1", "eagle3"]:
219+
elif training_args.mode == "eagle3":
254220
data_module = make_eagle_supervised_data_module(
255221
tokenizer, data_args, train_len=training_args.training_seq_len
256222
)

modelopt/torch/speculative/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,8 @@ class EagleConfig(ModeloptBaseConfig):
105105
default="llama",
106106
description=("The class of eagle decoder to use. Available options: llama, kimik2"),
107107
)
108+
109+
draft_vocab_cache: str = ModeloptField(
110+
default=None,
111+
description=("Path to d2t.pt cache file."),
112+
)

modelopt/torch/speculative/eagle/conversion.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from modelopt.torch.opt.conversion import ModelLikeModule
2121
from modelopt.torch.opt.dynamic import _DMRegistryCls
2222
from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
23+
from modelopt.torch.speculative.config import eagle3_default_config, kimik2_eagle_default_config
2324

2425
from ..config import EagleConfig
2526

@@ -38,6 +39,14 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
3839
EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls])
3940
break
4041

42+
# merge custom config with default config
43+
default_arch_config = {
44+
"llama": eagle3_default_config,
45+
"kimik2": kimik2_eagle_default_config,
46+
}[config.eagle_decoder_type]
47+
custom_config = config.eagle_architecture_config
48+
config.eagle_architecture_config = {**default_arch_config, **custom_config}
49+
4150
eagle_model = EagleDMRegistry.convert(model)
4251
eagle_model.modify(
4352
eagle_offline=config.eagle_offline,

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import contextlib
3333
import copy
34+
import os
3435
from typing import Any
3536

3637
import torch
@@ -49,6 +50,8 @@
4950
from transformers.utils import ModelOutput
5051
from transformers.utils.quantization_config import QuantizationMethod
5152

53+
from modelopt.torch.utils import print_rank_0
54+
5255
from ..eagle.conversion import EagleDMRegistry
5356
from ..eagle.eagle_model import EagleModel
5457
from ..eagle.utils import expand_mask, make_causal_mask
@@ -248,7 +251,16 @@ def __init__(self, config, decoder_layer_cls, bias=False):
248251
# Initialize the buffers to zero.
249252
# Their values depend on the specific tokenizer and calibration dataset, and should be set in the training script.
250253
if config.draft_vocab_size < config.vocab_size:
251-
self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64))
254+
if config.draft_vocab_cache is not None and os.path.isfile(
255+
config.draft_vocab_cache
256+
):
257+
self.register_buffer("d2t", torch.load(config.draft_vocab_cache))
258+
print_rank_0(f"Loaded draft vocab cache from {config.draft_vocab_cache}.")
259+
else:
260+
raise FileNotFoundError(
261+
f"Draft vocab cache file not found: {config.draft_vocab_cache}"
262+
)
263+
252264
self.lm_head = nn.Linear(
253265
config.hidden_size,
254266
config.draft_vocab_size,
@@ -425,8 +437,11 @@ def _base_model_lm_head(self):
425437
@property
426438
def _base_llm_config(self):
427439
"""Return the llm config for the base model, from LLM or VLM."""
428-
# return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
429-
return self.config.text_config
440+
return (
441+
getattr(self.config, "text_config", None)
442+
or getattr(self.config, "llm_config", None)
443+
or self.config
444+
)
430445

431446
def _find_base_model_parts(self):
432447
"""Find model parts from different models and set base_{part}_path attributes."""
@@ -574,13 +589,6 @@ def modify(
574589
):
575590
self._set_default_aux_hidden_state_layers()
576591

577-
if self._base_llm_config.hidden_size != self.eagle_config.hidden_size:
578-
raise ValueError(
579-
"EAGLE module hidden size "
580-
f"{self.eagle_config.hidden_size} must match base model hidden size "
581-
f"{self._base_llm_config.hidden_size}!"
582-
)
583-
584592
# Freeze all parameters
585593
if self.eagle_freeze_base_model:
586594
for name, param in self.named_parameters():

modelopt/torch/utils/plugins/transformers_dataset.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from datasets import load_dataset
2525
from transformers.trainer_pt_utils import LabelSmoother
2626

27+
from modelopt.torch.utils import print_rank_0
28+
2729
REMOVE_THINK_CHAT_TEMPLATE = (
2830
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
2931
)
@@ -147,10 +149,15 @@ def __init__(
147149
else:
148150
self._post_process_chat_template()
149151

152+
self._post_process_tokenizer()
150153
if self.tokenizer.chat_template is None:
151154
raise ValueError("No valid chat template!")
152155

153156
def _post_process_tokenizer(self):
157+
if self.tokenizer.pad_token_id is None:
158+
print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.")
159+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
160+
154161
if hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is None:
155162
if self.tokenizer.eos_token == "<|eot_id|>": # nosec
156163
self.tokenizer.pad_token = "<|end_of_text|>" # nosec
@@ -264,7 +271,6 @@ def __call__(self, examples):
264271
for example in examples:
265272
messages = example.get("messages", None)
266273
if messages is None:
267-
# print(example)
268274
conversations = example.get("conversations", None)
269275
if conversations is None:
270276
raise ValueError(

0 commit comments

Comments
 (0)