forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
379 lines (336 loc) · 15.1 KB
/
main.py
File metadata and controls
379 lines (336 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from dataclasses import dataclass, field
from typing import Literal
import torch
import transformers
from accelerate import ParallelismConfig
from eagle_utils import (
EagleTrainerWithAccLog,
EagleTrainingPlot,
LoRAWarmupCallback,
make_speculative_data_module,
patch_ring_attention_for_ttt,
)
from omegaconf import OmegaConf
from transformers.trainer_utils import get_last_checkpoint
import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0
# Seed PyTorch's global RNG with a fixed value, presumably so that any randomly
# initialized weights (e.g. newly added draft heads) are reproducible across runs.
torch.manual_seed(0)
# Runs at import time, before any model is loaded or saved in train().
# NOTE(review): presumably hooks ModelOpt state into HuggingFace save/load so that
# converted models round-trip through checkpoints — confirm against modelopt docs.
mto.enable_huggingface_checkpointing()
@dataclass
class ModelArguments:
model_name_or_path: str | None = field(
default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
metadata={"help": "HuggingFace model ID or local path to the base model."},
)
use_fake_base_for_offline: bool = field(
default=False,
metadata={
"help": "Load model architecture without real base weights. Offline training only."
},
)
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code when loading model."}
)
@dataclass
class DataArguments:
data_path: str = field(
default=None,
metadata={"help": "Path to the online training data."},
)
offline_data_path: str = field(
default=None,
metadata={
"help": "Path to offline training data directory (.pt files). This argument enables offline mode.",
},
)
lazy_preprocess: bool = True
draft_vocab_cache: str | None = field(
default=None,
metadata={"help": "Path to draft vocabulary cache file."},
)
chat_template: str = field(
default=None,
metadata={
"help": "Jinja chat template with {% generation %} tags for answer_only_loss. "
"If not set, the tokenizer's built-in template is used (must already have generation tags)."
},
)
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
sample_size: int = field(
default=-1,
metadata={"help": "Number of samples to use for training. Use -1 to use all samples."},
)
def __post_init__(self):
if self.sample_size == 0 or self.sample_size < -1:
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HuggingFace ``TrainingArguments`` plus speculative-decoding options.

    ``cp_size`` and ``dp_shard_size`` are read by ``train()`` to build an
    accelerate ``ParallelismConfig``; ``dp_shard_size`` is filled in by
    ``_load_config`` when left unset.
    """

    training_seq_len: int = field(
        default=2048,
        metadata={
            "help": "Training sequence length. "
            "Sequences will be right padded or truncated to this length."
        },
    )
    # Selects which speculative-decoding algorithm train() converts the model to.
    mode: Literal["eagle3", "medusa", "dflash"] = "eagle3"
    estimate_ar: bool = field(
        default=False,
        metadata={"help": "Whether to estimate AR using training accuracy to log."},
    )
    ar_validate_steps: int = field(
        default=1000,
        metadata={"help": "AR validation interval."},
    )
    answer_only_loss: bool = field(
        default=False,
        metadata={
            "help": "Mask loss on non-assistant tokens. Requires a chat_template with generation tags."
        },
    )
    # Parallel layout: dp_replicate_size * dp_shard_size * cp_size == world_size.
    cp_size: int = field(
        default=1,
        metadata={"help": "Context parallelism size."},
    )
    dp_shard_size: int | None = field(
        default=None,
        metadata={"help": "Data parallelism shard size. None = auto (total_gpu / cp_size)."},
    )
@dataclass
class MedusaArguments:
medusa_num_heads: int | None = field(default=1)
medusa_num_layers: int | None = field(default=1)
def _parse_cli() -> tuple[str, list[str]]:
"""Parse --config (required) from argv; return remaining args as config overrides.
Extra arguments use OmegaConf dotlist syntax, e.g.
``model.model_name_or_path=meta-llama/Llama-3.2-1B training.output_dir=ckpts/test``.
"""
p = argparse.ArgumentParser(add_help=False)
p.add_argument("--config", required=True, help="Path to the YAML config file.")
args, overrides = p.parse_known_args()
return args.config, overrides
def _load_config(
    config_path: str, overrides: list[str] | tuple[str, ...] = ()
) -> tuple[dict, dict, dict]:
    """Load the training config from a YAML file and apply CLI overrides.

    The YAML has ``model``, ``data`` and ``training`` sections plus optional
    ``eagle`` / ``dflash`` sections. *overrides* are OmegaConf dotlist entries
    (e.g. ``["model.model_name_or_path=xxx"]``) merged on top of the YAML before
    interpolations are resolved.

    A section that is missing, or present with no value (``eagle:`` loads as null),
    is treated as an empty dict.

    If ``dp_shard_size`` is unset, it defaults to ``world_size // cp_size`` clamped
    to at least 1, where world_size is ``$WORLD_SIZE`` or, if unset, the local CUDA
    device count.

    Returns:
        hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict()
        eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert()
        dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert()

    Raises:
        ValueError: if ``dp_shard_size`` must be derived and ``cp_size`` is not a
            positive integer.
    """
    merged = OmegaConf.load(config_path)
    if overrides:
        merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides)))
    cfg = OmegaConf.to_container(merged, resolve=True)

    def _section(name: str) -> dict:
        # A key written with no value loads as None, and ``cfg.get(name, {})`` would
        # return that None (breaking ``**`` unpacking and later ``.get`` calls).
        return cfg.get(name) or {}

    # Eagle/DFlash sections map directly to config fields — no field enumeration needed.
    eagle_cfg = _section("eagle")
    dflash_cfg = _section("dflash")
    hf_cfg = {**_section("model"), **_section("data"), **_section("training")}

    if hf_cfg.get("dp_shard_size") is None:
        cp_size = hf_cfg.get("cp_size", 1)
        if not isinstance(cp_size, int) or cp_size < 1:
            raise ValueError(f"training.cp_size must be a positive integer, got {cp_size!r}")
        # Use WORLD_SIZE (total GPUs across all nodes) when available, else local GPU count.
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        # Clamp to 1: on a CPU-only host (device_count() == 0) or when world_size < cp_size
        # the floor division gives 0, which train() would later use as a modulo divisor.
        hf_cfg["dp_shard_size"] = max(1, world_size // cp_size)
    return hf_cfg, eagle_cfg, dflash_cfg
def _configure_parallelism(training_args: "TrainingArguments") -> None:
    """Attach an accelerate ``ParallelismConfig`` when CP or sharded DP is requested.

    Does nothing unless ``cp_size > 1`` or ``dp_shard_size > 1``. Otherwise
    ``dp_replicate_size`` is derived so that
    ``dp_replicate_size * dp_shard_size * cp_size == world_size``, and, when
    ``cp_size > 1``, ``patch_ring_attention_for_ttt()`` is applied.

    Raises:
        ValueError: if ``dp_shard_size * cp_size`` is not a positive divisor of world_size.
    """
    if not (training_args.cp_size > 1 or training_args.dp_shard_size > 1):
        return
    # Note: torch.cuda.device_count() returns per-node GPU count, not world_size.
    # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    parallel_size = training_args.dp_shard_size * training_args.cp_size
    # A non-positive parallel_size (e.g. dp_shard_size == 0) would otherwise raise a
    # bare ZeroDivisionError or produce a negative replicate size.
    if parallel_size < 1 or world_size % parallel_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by "
            f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) "
            f"= {parallel_size}"
        )
    training_args.parallelism_config = ParallelismConfig(
        cp_size=training_args.cp_size,
        dp_shard_size=training_args.dp_shard_size,
        dp_replicate_size=world_size // parallel_size,
    )
    if training_args.cp_size > 1:
        patch_ring_attention_for_ttt()
        # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
        training_args.parallelism_config.sp_backend = None


def _load_model_and_tokenizer(model_args, training_args, checkpoint, use_offline_training):
    """Load the model and tokenizer, resuming from *checkpoint* when one is given.

    Args:
        model_args: ``ModelArguments`` (base model path and loading flags).
        training_args: ``TrainingArguments``; on a fresh run ``training_seq_len``
            becomes the tokenizer's ``model_max_length``.
        checkpoint: Checkpoint directory to resume from, or a falsy value for a fresh run.
        use_offline_training: Whether ``data.offline_data_path`` is set.

    Returns:
        ``(model, tokenizer)``. On a fresh run the model is loaded on CPU and has
        not yet been converted for speculative decoding.
    """
    if checkpoint:
        with patch_transformers5_params_loading():
            model = load_vlm_or_llm(
                checkpoint, dtype="auto", trust_remote_code=model_args.trust_remote_code
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            checkpoint, trust_remote_code=model_args.trust_remote_code
        )
        return model, tokenizer

    # To avoid OOM for large models, we load and convert model on CPU first.
    # Model will be moved to GPU during HF trainer.init().
    model_config = None
    if use_offline_training:
        # Load config first to preserve original num_hidden_layers before
        # load_vlm_or_llm may reduce layers for offline space savings.
        model_config = transformers.AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
        )
    model = load_vlm_or_llm(
        model_args.model_name_or_path,
        use_fake_base=model_args.use_fake_base_for_offline,
        use_offline_training=use_offline_training,
        dtype="auto",
        device_map="cpu",
        trust_remote_code=model_args.trust_remote_code,
    )
    if use_offline_training:
        # When doing offline training, we need to set num_hidden_layers
        # since we override it when loading the model for space savings.
        # Some models (e.g. Kimi-K2.5) use non-standard config attributes,
        # so fall back to the model's own config if the attribute is missing.
        model.config.num_orig_hidden_layers = getattr(
            model_config, "num_hidden_layers", model.config.num_hidden_layers
        )
        if hasattr(model.config, "layer_types"):
            # remove layer_types to avoid mismatch with the modified model
            del model.config.layer_types
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.training_seq_len,
        trust_remote_code=model_args.trust_remote_code,
    )
    return model, tokenizer


def _move_cpu_buffers_to_cuda(model: torch.nn.Module) -> None:
    """Move buffers still on CPU to this process's CUDA device (no-op without CUDA).

    Done so DDP (NCCL-only) can broadcast them. Parameters are left on CPU — the
    HF Trainer will move them during init. Each buffer is reassigned through its
    owning submodule to keep the module tree consistent.
    """
    if not torch.cuda.is_available():
        return
    # Use this process's local GPU instead of always cuda:0, so ranks in a multi-GPU
    # launch do not all allocate on GPU 0. LOCAL_RANK is set per process by
    # torchrun / accelerate launch; it defaults to 0 for single-process runs.
    target_device = torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
    for name, buf in list(model.named_buffers()):
        if buf.device.type != "cpu":
            continue
        owner_path, _, attr = name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        setattr(owner, attr, buf.to(target_device))


def train():
    """Parse the config, build (or resume) the speculative-decoding model, and train it.

    The config comes from ``--config <yaml>`` plus OmegaConf dotlist overrides on
    the command line (see ``_parse_cli`` and ``_load_config``).

    Raises:
        ValueError: if no training data path is configured, ``training.mode`` is
            unsupported, the parallelism layout does not divide world_size, or a
            label smoother is configured.
        FileNotFoundError: if the draft model uses a compressed vocabulary but
            ``data.draft_vocab_cache`` is unset or not an existing file.
    """
    config_path, overrides = _parse_cli()
    hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides)
    # A section written with no value (e.g. ``eagle:``) can come back as None.
    eagle_cfg = eagle_cfg or {}
    dflash_cfg = dflash_cfg or {}

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, MedusaArguments)
    )
    model_args, data_args, training_args, medusa_args = parser.parse_dict(
        hf_cfg, allow_extra_keys=True
    )
    if not data_args.data_path and not data_args.offline_data_path:
        raise ValueError(
            "Either data.data_path or data.offline_data_path must be set in the config."
        )
    # Validate the mode up front so an unsupported value fails before any model is loaded.
    supported_modes = ("eagle3", "medusa", "dflash")
    if training_args.mode not in supported_modes:
        raise ValueError(
            f"{training_args.mode} is not supported! Expected one of {supported_modes}."
        )

    _configure_parallelism(training_args)

    print_rank_0(
        f"arguments: {model_args}, {data_args}, {training_args}, {medusa_args}, "
        f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}"
    )

    # Detect checkpoint to resume from
    last_checkpoint = (
        get_last_checkpoint(training_args.output_dir)
        if os.path.isdir(training_args.output_dir)
        else None
    )
    if last_checkpoint:
        print_rank_0(f"Last checkpoint detected: {last_checkpoint}")
    checkpoint = training_args.resume_from_checkpoint or last_checkpoint
    use_offline_training = data_args.offline_data_path is not None

    model, tokenizer = _load_model_and_tokenizer(
        model_args, training_args, checkpoint, use_offline_training
    )

    # Only a freshly loaded base model is converted (still on CPU). A resumed checkpoint
    # is used as loaded — NOTE(review): its ModelOpt state is presumably restored via
    # mto.enable_huggingface_checkpointing(); confirm.
    if not checkpoint:
        if training_args.mode == "medusa":
            medusa_cfg = {
                "medusa_num_heads": medusa_args.medusa_num_heads,
                "medusa_num_layers": medusa_args.medusa_num_layers,
            }
            mtsp.convert(model, [("medusa", medusa_cfg)])
        elif training_args.mode == "eagle3":
            # Validate and rewrite eagle config fields
            eagle_cfg = EagleConfig.model_validate(
                eagle_cfg,
                context={"training_args": training_args, "data_args": data_args},
            ).model_dump()
            mtsp.convert(model, [("eagle", eagle_cfg)])
            # Load draft vocab cache if the draft model uses a compressed vocabulary
            if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
                vocab_cache = data_args.draft_vocab_cache
                # Check for an unset path first: os.path.isfile(None) raises TypeError.
                if not vocab_cache or not os.path.isfile(vocab_cache):
                    raise FileNotFoundError(
                        "The draft model uses a compressed vocabulary "
                        "(draft_vocab_size < vocab_size), so data.draft_vocab_cache must "
                        f"point to an existing file; got: {vocab_cache!r}"
                    )
                model.eagle_module.d2t = torch.load(vocab_cache, weights_only=True)
                print_rank_0(f"Loaded draft vocab cache from {vocab_cache}.")
        elif training_args.mode == "dflash":
            dflash_cfg = DFlashConfig.model_validate(
                dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
            ).model_dump()
            mtsp.convert(model, [("dflash", dflash_cfg)])

    _move_cpu_buffers_to_cuda(model)

    print_rank_0("Loading dataset...")
    data_module = make_speculative_data_module(
        tokenizer,
        data_args,
        train_len=training_args.training_seq_len,
        answer_only_loss=training_args.answer_only_loss,
        # DFlash takes unshifted labels; the other modes take shifted labels.
        shift_labels=training_args.mode != "dflash",
    )

    callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)]
    # LoRA warmup only applies to the EAGLE base-model LoRA, so gate it on eagle3 mode.
    lora_warmup_steps = eagle_cfg.get("eagle_base_lora_warmup_steps", 0)
    if (
        training_args.mode == "eagle3"
        and eagle_cfg.get("eagle_base_lora")
        and lora_warmup_steps > 0
    ):
        callbacks.append(LoRAWarmupCallback(lora_warmup_steps))

    trainer = EagleTrainerWithAccLog(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        callbacks=callbacks,
        **data_module,
    )
    # Manually enable this to return loss in eval
    trainer.can_return_loss = True
    # Make sure label_smoother is None. Raise instead of assert so the check
    # survives ``python -O``.
    if trainer.label_smoother is not None:
        raise ValueError("label_smoother is not supported in speculative decoding!")

    print_rank_0("Start training...")
    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_state()
    trainer.save_model(training_args.output_dir)
# Run training only when executed as a script (e.g. launched via torchrun or
# accelerate), not when this module is imported.
if __name__ == "__main__":
    train()