Skip to content

Commit 676ae7d

Browse files
authored
support eagle3 offline training with per_device_train_batch_size>1 (#264)
1 parent 26567b2 commit 676ae7d

20 files changed

Lines changed: 513 additions & 206 deletions

angelslim/compressor/speculative/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
create_target_model,
2525
data_generation_work_flow,
2626
get_supported_chat_template_type_strings,
27+
infer_model_params,
2728
)
2829

2930
__all__ = [
@@ -40,4 +41,5 @@
4041
"DatasetManager",
4142
"get_supported_chat_template_type_strings",
4243
"TargetHead",
44+
"infer_model_params",
4345
]

angelslim/compressor/speculative/train/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
TargetHead,
1111
create_draft_model,
1212
create_target_model,
13+
infer_model_params,
1314
)
1415
from .trainer import Eagle3TrainerFactory
1516

@@ -24,4 +25,5 @@
2425
"DatasetManager",
2526
"get_supported_chat_template_type_strings",
2627
"TargetHead",
28+
"infer_model_params",
2729
]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .draft import DraftModelConfig, create_draft_model
2+
from .model_utils import infer_model_params
23
from .target import TargetHead, create_target_model
34

45
__all__ = [
56
"create_draft_model",
67
"DraftModelConfig",
78
"create_target_model",
89
"TargetHead",
10+
"infer_model_params",
911
]

angelslim/compressor/speculative/train/models/draft/base_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,16 @@ def build_vocab_mapping(self, dataset, cache_path):
205205

206206
self.t2d.copy_(t2d)
207207
self.d2t.copy_(d2t)
208+
209+
def load_vocab_mapping(self, vocab_mapping_path):
    """
    Load a pre-computed vocab mapping from disk into ``self.t2d`` / ``self.d2t``.

    Args:
        vocab_mapping_path: Path to the vocab_mapping.pt file saved by generate_hidden.
            The file must deserialize to a mapping with "d2t" and "t2d" entries.

    Raises:
        KeyError: If the loaded mapping lacks "d2t" or "t2d".
    """
    # Deserialize onto CPU so a cache written from one device does not get
    # re-materialized on that same device in every process; ``copy_`` below
    # transfers the values onto whatever device the destination tensors use.
    # ``weights_only=True`` restricts unpickling to tensors/primitive containers.
    cache = torch.load(vocab_mapping_path, map_location="cpu", weights_only=True)

    missing = [key for key in ("d2t", "t2d") if key not in cache]
    if missing:
        raise KeyError(f"vocab mapping file {vocab_mapping_path!r} is missing keys: {missing}")

    self.t2d.copy_(cache["t2d"])
    self.d2t.copy_(cache["d2t"])

angelslim/compressor/speculative/train/models/draft/llama_eagle3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def apply_interleaved_mrope(self, freqs, mrope_section):
180180
@torch.no_grad()
181181
def forward(self, x, position_ids, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
182182
if position_ids.ndim == 2:
183-
position_ids = position_ids.unsqueeze(1)
184-
# position_ids = position_ids[None].expand(3, position_ids.shape[0], -1)
183+
# expand (batch, seq_len) to (3, batch, seq_len), match MRoPE T/H/W layout
184+
position_ids = position_ids[None].expand(3, position_ids.shape[0], -1)
185185

186186
inv_freq_expanded = (
187187
self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)

angelslim/compressor/speculative/train/models/model_utils.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from typing import Optional, Tuple
1616

1717
import torch
18+
from transformers import AutoConfig
1819

1920
__all__ = [
2021
"make_causal_mask",
2122
"expand_mask",
2223
"repeat_kv",
2324
"rotate_half",
2425
"apply_rotary_pos_emb",
26+
"infer_model_params",
2527
]
2628

2729

@@ -107,3 +109,79 @@ def apply_rotary_pos_emb_mrope(q, k, cos, sin, position_ids=None, unsqueeze_dim=
107109
q_embed = (q * cos) + (rotate_half(q) * sin)
108110
k_embed = (k * cos) + (rotate_half(k) * sin)
109111
return q_embed, k_embed
112+
113+
114+
# Preset lookup table used for auto-detection.
#   key:   model_type string (as reported by AutoConfig)
#   value: (lm_head_key, embed_weight_key, chat_template_type)
# For the VL entries the lm_head_key is the same tensor name as the embedding
# key — presumably those checkpoints tie the two; confirm against the weights.
# NOTE(review): Hugging Face configs for Qwen2.5 text models appear to report
# model_type "qwen2"; confirm that the "qwen2_5" key is ever matched.
# NOTE(review): "llama" maps to the "qwen3" chat template — confirm this is intended.
_LM_HEAD_KEY = "lm_head.weight"
_EMBED_TOKENS_KEY = "model.embed_tokens.weight"
_VL_EMBED_TOKENS_KEY = "model.language_model.embed_tokens.weight"

MODEL_TYPE_PARAM_MAP: dict = {
    "qwen3_vl": (_VL_EMBED_TOKENS_KEY, _VL_EMBED_TOKENS_KEY, "qwen3_vl"),
    "qwen3_vl_moe": (_VL_EMBED_TOKENS_KEY, _VL_EMBED_TOKENS_KEY, "qwen3_vl"),
    "hunyuan_vl": (_EMBED_TOKENS_KEY, _EMBED_TOKENS_KEY, "hunyuan_vl"),
    "qwen2_audio": (_LM_HEAD_KEY, "language_model.model.embed_tokens.weight", "qwen2_audio"),
    "qwen3": (_LM_HEAD_KEY, _EMBED_TOKENS_KEY, "qwen3"),
    "qwen2_5": (_LM_HEAD_KEY, _EMBED_TOKENS_KEY, "qwen2.5"),
    "llama": (_LM_HEAD_KEY, _EMBED_TOKENS_KEY, "qwen3"),
}
153+
154+
155+
def infer_model_params(
    model_name_or_path: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Auto-detect lm_head_key, embed_weight_key and chat_template_type for a target model.

    The model's config is loaded through ``AutoConfig`` and its ``model_type``
    is looked up in ``MODEL_TYPE_PARAM_MAP``. Progress and failures are reported
    with ``print``; no exception escapes this function.

    Args:
        model_name_or_path: target model path

    Returns:
        (lm_head_key, embed_weight_key, chat_template_type) when a preset exists,
        (None, None, None) when there is no preset or the config cannot be read.
    """
    not_detected = (None, None, None)
    try:
        # NOTE(review): trust_remote_code=True lets the model repo run its own
        # config code — only point this at sources you trust.
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        model_type = getattr(config, "model_type", None)
        print(f"[Auto-detect] Detected model_type: {model_type}")

        # Guard clause: unknown model types fall back to user-supplied values.
        if model_type not in MODEL_TYPE_PARAM_MAP:
            print(
                f"[Auto-detect] No preset mapping found for model_type={model_type!r}, "
                "will use command-line specified values"
            )
            return not_detected

        preset = MODEL_TYPE_PARAM_MAP[model_type]
        lm_head_key, embed_weight_key, chat_template_type = preset
        print(
            f"[Auto-detect] lm_head_key={lm_head_key}, "
            f"embed_weight_key={embed_weight_key}, "
            f"chat_template_type={chat_template_type}"
        )
        return lm_head_key, embed_weight_key, chat_template_type
    except Exception as e:
        # Deliberately best-effort: any failure just disables auto-detection.
        print(f"[Auto-detect] Failed to read model config: {e}")
        return not_detected

angelslim/compressor/speculative/train/models/target/target_model_wrapper.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,9 @@ def get_hidden_states_and_logits(
370370

371371
def hook(module, args, kwargs):
372372
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
373-
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
373+
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
374374
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
375-
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
375+
position_ids_list.append(kwargs["position_ids"].clone().detach())
376376
return args, kwargs
377377

378378
if self.target_model_type == "qwen3_vl":
@@ -440,9 +440,9 @@ def get_aux_and_target_hiddens(
440440

441441
def hook(module, args, kwargs):
442442
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
443-
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
443+
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
444444
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
445-
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
445+
position_ids_list.append(kwargs["position_ids"].clone().detach())
446446
return args, kwargs
447447

448448
if self.target_model_type == "qwen3_vl":
@@ -572,9 +572,9 @@ def get_hidden_states_and_logits(
572572

573573
def hook(module, args, kwargs):
574574
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
575-
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
575+
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
576576
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
577-
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
577+
position_ids_list.append(kwargs["position_ids"].clone().detach())
578578
return args, kwargs
579579

580580
handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True)
@@ -628,9 +628,9 @@ def get_aux_and_target_hiddens(
628628

629629
def hook(module, args, kwargs):
630630
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
631-
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach().cpu())
631+
inputs_embeds_list.append(kwargs["inputs_embeds"].clone().detach())
632632
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
633-
position_ids_list.append(kwargs["position_ids"].clone().detach().cpu())
633+
position_ids_list.append(kwargs["position_ids"].clone().detach())
634634
return args, kwargs
635635

636636
handle = self.model.language_model.register_forward_pre_hook(hook, with_kwargs=True)

angelslim/compressor/speculative/train/trainer/eagle3_trainer.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import time
1617
from abc import ABC, abstractmethod
1718
from typing import Dict, List, Optional, Tuple
1819

@@ -44,6 +45,67 @@ def __init__(self, draft_model: nn.Module, length: int, **kwargs):
4445
"""
4546
super().__init__(model=draft_model, **kwargs)
4647
self.length = length
48+
self._train_start_time = None
49+
self._pending_log: dict = (
50+
{}
51+
) # cache acc/ploss log for merging with base Trainer's loss log
52+
self._pending_log_count: int = 0 # accumulated batch count for averaging the cached log
53+
54+
def train(self, *args, **kwargs):
    """
    Stamp the wall-clock start of training, then run the base class ``train``.

    The timestamp is stored in ``self._train_start_time`` and is read by
    ``log`` to estimate the remaining training time. All arguments are
    forwarded unchanged and the base class result is returned as-is.
    """
    self._train_start_time = time.time()
    result = super().train(*args, **kwargs)
    return result
58+
59+
def log(self, logs: dict, start_time: Optional[float] = None) -> None:
    """
    Emit a log record, folding cached acc/ploss metrics into the loss record.

    When ``logs`` carries a "loss" entry and metrics have been cached in
    ``self._pending_log``, a compact record is built (step, epoch, loss,
    grad_norm, lr, averaged cached metrics, remaining_time), the cache is
    reset, and that record is passed to the base class ``log``. In every
    other case ``logs`` is forwarded to the base class untouched.

    Args:
        logs: Metrics to log.
        start_time: Forwarded unchanged to the base class ``log``.
    """
    # Guard clause: nothing to merge, so behave exactly like the base class.
    if "loss" not in logs or not self._pending_log:
        super().log(logs, start_time)
        return

    # Cached values are running sums; divide by the number of cached batches.
    # NOTE(review): one counter is shared by every cached key — if keys with
    # different prefixes are accumulated at different rates the averages would
    # be skewed; verify against the caller that fills ``_pending_log``.
    batches = max(self._pending_log_count, 1)
    averaged = {name: total / batches for name, total in self._pending_log.items()}

    trainer_state = self.state
    record = {}

    # Step counters come from the trainer state when it is available.
    total_steps = 0
    if trainer_state is not None:
        current_step = trainer_state.global_step
        total_steps = trainer_state.max_steps
        record["step"] = current_step

    # Copy the base Trainer's fields in a fixed order; learning_rate is renamed to lr.
    for source_key, target_key in (
        ("epoch", "epoch"),
        ("loss", "loss"),
        ("grad_norm", "grad_norm"),
        ("learning_rate", "lr"),
    ):
        if source_key in logs:
            record[target_key] = logs[source_key]

    record.update(averaged)

    # Remaining time = average seconds per completed step * steps still to run.
    # NOTE(review): elapsed time is measured from the last ``train`` call but
    # divided by the absolute global_step — when resuming from a checkpoint this
    # presumably underestimates the per-step time; confirm.
    if (
        trainer_state is not None
        and self._train_start_time is not None
        and current_step > 0
        and total_steps > 0
    ):
        elapsed = time.time() - self._train_start_time
        seconds_per_step = elapsed / current_step
        seconds_left = int(seconds_per_step * (total_steps - current_step))
        hours, leftover = divmod(seconds_left, 3600)
        minutes, seconds = divmod(leftover, 60)
        record["remaining_time"] = f"{hours:02d}h:{minutes:02d}m:{seconds:02d}s"

    # Reset the cache so the next logging window starts from zero.
    self._pending_log.clear()
    self._pending_log_count = 0
    super().log(record, start_time)
47109

48110
@property
49111
def draft_model(self) -> nn.Module:
@@ -131,7 +193,11 @@ def prepare_attention_mask_and_position_ids(
131193
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
132194
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
133195
else:
134-
position_ids = position_ids.view(-1, seq_length).long()
196+
if position_ids.ndim == 3:
197+
# MRoPE format: (3, batch, seq_len), keep as-is
198+
position_ids = position_ids.long()
199+
else:
200+
position_ids = position_ids.view(-1, seq_length).long()
135201

136202
if attention_mask is None:
137203
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=device)
@@ -210,15 +276,12 @@ def draft_model_training_time_test(
210276
ploss_weight = [0.8**i for i in range(len(plosses))]
211277
ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])
212278

213-
log = {f"{log_prefix}/acc_{i}": round(float(acces[i]), 3) for i in range(len(acces))}
214-
log.update(
215-
{
216-
f"{log_prefix}/ploss_{i}": round(float(plosses[i].item()), 3)
217-
for i in range(len(plosses))
218-
}
219-
)
220-
self.log(log)
221-
279+
log = {f"{log_prefix}/acc_{i}": acces[i] for i in range(len(acces))}
280+
log.update({f"{log_prefix}/ploss_{i}": plosses[i].item() for i in range(len(plosses))})
281+
# Cache log for merging with base Trainer's loss log
282+
for k, v in log.items():
283+
self._pending_log[k] = self._pending_log.get(k, 0.0) + v
284+
self._pending_log_count += 1
222285
# Step 9: Return loss
223286
return ploss
224287

docs/source/features/speculative_decoding/eagle/vlm_eagle.md

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,6 @@ bash scripts/speculative/hunyuan_ocr/generate_vlm_hidden_for_draft_model.sh
8888
# For Qwen3-VL series
8989
bash scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model.sh
9090
```
91-
- 离线hidden_states采集时,如果由于pixel_values数组太长导致 OverflowError: There was an overflow with type <class 'list'>. 请采用分batch处理方式见:
92-
93-
```shell
94-
# For HunyuanOCR
95-
bash scripts/speculative/hunyuan_ocr/generate_vlm_hidden_for_draft_model_batch.sh
96-
# For Qwen3-VL series
97-
bash scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model_batch.sh
98-
```
9991

10092
> 注意:qwen3_vl系列模型生成hidden states需要更新transformers>=5.0.0,
10193
或者cherry-pick: https://github.com/huggingface/transformers/pull/42609,
@@ -106,7 +98,8 @@ bash scripts/speculative/qwen3_vl/generate_vlm_hidden_for_draft_model.sh
10698
在使用前,需要在脚本中配置以下参数:
10799

108100
- `DATASET_PATH`: 输入数据集的HF名称或本地路径
109-
- `MODEL_NAME`: 目标模型的HF名称或本地路径
101+
- `TARGET_MODEL_NAME_OR_PATH`: 目标模型的HF名称或本地路径
102+
- `DRAFT_MODEL_CONFIG_PATH`: 草稿模型的config路径
110103
- `TARGET_BACKEND`: 目标模型后端,目前仅支持HF
111104
- `MODEL_MAX_LENGTH`: 生成数据的上下文长度
112105
- `CHAT_TEMPLATE_TYPE`: 目标模型的对话模板类型,目前支持qwen3_vl/hunyuan_vl
@@ -159,7 +152,6 @@ bash scripts/speculative/qwen3_vl/train_eagle3_vlm_offline.sh
159152

160153
- `TARGET_MODEL_NAME_OR_PATH`: 目标模型的HF名称或本地路径
161154
- `DRAFT_MODEL_CONFIG_PATH`: 草稿模型的config路径
162-
- `TRAIN_DATA_PATH`: 训练数据路径,.jsonl格式
163155
- `TRAIN_HIDDEN_PATH`: 训练hidden states数据路径
164156
- `EVAL_HIDDEN_PATH`: 验证hidden states数据路径
165157
- `OUTPUT_DIR`: Eagle3模型输出路径
@@ -175,7 +167,8 @@ AngelSlim提供了HunyuanOCR和Qwen3-VL系列模型vLLM backend的Eagle3基准
175167

176168
### 4.1 vLLM基准测试
177169

178-
> vLLM 适配参考: [Support Eagle3 for HunyuanOCR & Qwen3-VL](https://github.com/vllm-project/vllm/pull/32230)
170+
> vLLM 建议版本0.16.0以上,已支持Hunyuan/HunyuanVL/Qwen3-VL。
171+
> HunyuanOCR & Qwen3VLMoe & Qwen2Audio 适配需要cherry-pick这个PR: [feature: support eagle3 for HunyuanOCR & Qwen3VLMoe & Qwen2Audio](https://github.com/vllm-project/vllm/pull/32230)
179172
180173
#### 4.1.1 基本用法
181174

@@ -197,7 +190,7 @@ python3 tools/vllm_offline_eagle3_vlm_batch.py \
197190
- `--draft_model`: Eagle辅助模型路径(必需)
198191

199192
**基准测试配置:**
200-
- `--dataset`: 基准数据集名称,默认为 `lmms-lab/textvqa`, 可选【`lmms-lab/textvqa`,`MMMU/MMMU`,`Lin-Chen/MMStar`,`opendatalab/OmniDocBench`,`Lin-Chen/MMStar`
193+
- `--dataset`: 基准数据集名称,默认为 `lmms-lab/textvqa`, 可选【`lmms-lab/textvqa`,`MMMU/MMMU`,`Lin-Chen/MMStar`,`opendatalab/OmniDocBench`】。也支持本地的数据集路径,格式见: 2.1 数据组织形式
201194
- `--use_eagle`: 运行Eagle3推理,默认为False
202195
- `--output_file`: 输出结果文件路径
203196
- `--num_prompts`: 测试用例数量,默认为100
@@ -233,11 +226,10 @@ python3 tools/vllm_offline_eagle3_vlm_batch.py \
233226
--output_file "$OUTPUT_FILE"
234227
```
235228

236-
**Baseline基准测试:**
229+
**Baseline基准测试(不使用投机采样)**
237230
```shell
238231
python3 tools/vllm_offline_eagle3_vlm_batch.py \
239232
--target_model Qwen/Qwen3-VL-2B-Instruct \
240-
--num_spec_tokens 4 \
241233
--dataset "$task" \
242234
--num_prompts 80 \
243235
--temp 0 \

0 commit comments

Comments
 (0)