Skip to content

Commit c35fb3b

Browse files
authored
Feature: Support Qwen3-VL and HunyuanOCR eagle3 speculative decoding (#196)
1 parent d43bcbc commit c35fb3b

27 files changed

Lines changed: 1331 additions & 101 deletions
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"architectures": [
3+
"Eagle3LlamaForCausalLM"
4+
],
5+
"model_type": "llama",
6+
"target_model_type": "hunyuan_vl",
7+
"attention_bias": false,
8+
"attention_dropout": 0.0,
9+
"bos_token_id": 120000,
10+
"eod_token_id": 120020,
11+
"eos_token_id": 120020,
12+
"dtype": "bfloat16",
13+
"head_dim": 128,
14+
"hidden_act": "silu",
15+
"hidden_size": 1024,
16+
"image_start_token_id": 120118,
17+
"image_end_token_id": 120119,
18+
"image_token_id": 120120,
19+
"image_newline_token_id": 120121,
20+
"initializer_range": 0.02,
21+
"intermediate_size": 3584,
22+
"max_position_embeddings": 32768,
23+
"num_attention_heads": 16,
24+
"num_hidden_layers": 1,
25+
"num_key_value_heads": 8,
26+
"rms_norm_eps": 1e-06,
27+
"rope_theta": 10000.0,
28+
"use_cache": true,
29+
"vocab_size": 120818,
30+
"tie_word_embeddings": true,
31+
"transformers_version": "4.57.1",
32+
"draft_vocab_size": 32000,
33+
"modal_type": "VLM"
34+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"architectures": [
3+
"Eagle3LlamaForCausalLM"
4+
],
5+
"model_type": "llama",
6+
"target_model_type": "qwen3_vl",
7+
"attention_bias": false,
8+
"attention_dropout": 0.0,
9+
"bos_token_id": 151643,
10+
"dtype": "bfloat16",
11+
"eos_token_id": 151645,
12+
"head_dim": 128,
13+
"hidden_act": "silu",
14+
"hidden_size": 2048,
15+
"initializer_range": 0.02,
16+
"intermediate_size": 6144,
17+
"max_position_embeddings": 262144,
18+
"num_attention_heads": 16,
19+
"num_hidden_layers": 1,
20+
"num_key_value_heads": 8,
21+
"rms_norm_eps": 1e-06,
22+
"rope_scaling": {
23+
"type": "default",
24+
"rope_type": "default",
25+
"mrope_interleaved": true,
26+
"mrope_section": [
27+
24,
28+
20,
29+
20
30+
]
31+
},
32+
"rope_theta": 5000000,
33+
"use_cache": true,
34+
"vocab_size": 151936,
35+
"tie_word_embeddings": true,
36+
"transformers_version": "4.57.1",
37+
"image_token_id": 151655,
38+
"video_token_id": 151656,
39+
"vision_end_token_id": 151653,
40+
"vision_start_token_id": 151652,
41+
"draft_vocab_size": 32000,
42+
"modal_type": "VLM"
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"architectures": [
3+
"Eagle3LlamaForCausalLM"
4+
],
5+
"model_type": "llama",
6+
"target_model_type": "qwen3_vl",
7+
"attention_bias": false,
8+
"attention_dropout": 0.0,
9+
"bos_token_id": 151643,
10+
"dtype": "bfloat16",
11+
"eos_token_id": 151645,
12+
"head_dim": 128,
13+
"hidden_act": "silu",
14+
"hidden_size": 2048,
15+
"initializer_range": 0.02,
16+
"intermediate_size": 6144,
17+
"max_position_embeddings": 262144,
18+
"num_attention_heads": 32,
19+
"num_hidden_layers": 1,
20+
"num_key_value_heads": 4,
21+
"rms_norm_eps": 1e-06,
22+
"rope_scaling": {
23+
"type": "default",
24+
"rope_type": "default",
25+
"mrope_interleaved": true,
26+
"mrope_section": [
27+
24,
28+
20,
29+
20
30+
]
31+
},
32+
"rope_theta": 5000000,
33+
"use_cache": true,
34+
"vocab_size": 151936,
35+
"tie_word_embeddings": false,
36+
"transformers_version": "4.57.1",
37+
"image_token_id": 151655,
38+
"video_token_id": 151656,
39+
"vision_end_token_id": 151653,
40+
"vision_start_token_id": 151652,
41+
"draft_vocab_size": 32000,
42+
"modal_type": "VLM"
43+
}

angelslim/compressor/speculative/train/configs/qwen3-vl-4b-eagle3-mrope.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"Eagle3LlamaForCausalLM"
44
],
55
"model_type": "llama",
6+
"target_model_type": "qwen3_vl",
67
"attention_bias": false,
78
"attention_dropout": 0.0,
89
"bos_token_id": 151643,

angelslim/compressor/speculative/train/data/chat_templates.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@
2626
class ChatTemplateType(Enum):
2727
"""Supported chat template types."""
2828

29+
QWEN2_AUDIO = "qwen2_audio"
2930
QWEN3 = "qwen3"
3031
HUNYUAN = "hunyuan"
3132
QWEN3_VL = "qwen3_vl"
3233
HUNYUAN_7B = "hunyuan_7b"
34+
HUNYUAN_VL = "hunyuan_vl"
3335

3436

3537
# String to ChatTemplateType mapping
3638
CHAT_TEMPLATE_TYPE_MAPPING = {
39+
"qwen2_audio": ChatTemplateType.QWEN2_AUDIO,
3740
"qwen3": ChatTemplateType.QWEN3,
3841
"hunyuan": ChatTemplateType.HUNYUAN,
3942
"hunyuan_7b": ChatTemplateType.HUNYUAN_7B,
4043
"qwen3_vl": ChatTemplateType.QWEN3_VL,
44+
"hunyuan_vl": ChatTemplateType.HUNYUAN_VL,
4145
}
4246

4347

@@ -133,6 +137,26 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
133137
}
134138
],
135139
),
140+
ChatTemplateType.QWEN2_AUDIO: ChatTemplate(
141+
user_header="<|im_start|>user\n",
142+
assistant_header="<|im_start|>assistant\n",
143+
system_prompt=[
144+
{
145+
"type": "text",
146+
"text": ("You are a helpful assistant."),
147+
}
148+
],
149+
),
150+
ChatTemplateType.HUNYUAN_VL: ChatTemplate(
151+
user_header="<|hy_Assistant|>",
152+
assistant_header="<|hy_User|>",
153+
system_prompt=[
154+
{
155+
"type": "text",
156+
"text": "",
157+
}
158+
],
159+
),
136160
}
137161

138162
def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:

angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def paddingtensor3D_CBN(tensor_list):
130130
return torch.cat(out_tensor_list, dim=1)
131131

132132

133+
def paddingtensor3D_BCN(tensor_list):
134+
if all(tensor is None for tensor in tensor_list):
135+
return None
136+
N = max(tensor.shape[-1] for tensor in tensor_list if tensor is not None)
137+
out_tensor_list = []
138+
for tensor in tensor_list:
139+
b, c, n = tensor.shape
140+
outtensor = torch.zeros(b, c, N, dtype=tensor_list[0].dtype)
141+
if tensor is not None:
142+
outtensor[:, :, :n] = tensor
143+
out_tensor_list.append(outtensor)
144+
return torch.cat(out_tensor_list, dim=0)
145+
146+
133147
def paddingtensor3D_BHW(tensor_list):
134148
if all(tensor is None for tensor in tensor_list):
135149
return None
@@ -240,11 +254,90 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
240254
batch["target_hiddens"] = torch.cat(
241255
[paddingtensor(item["target_hiddens"], max_length) for item in features]
242256
)
257+
if all(
258+
"inputs_embeds" in item and item["inputs_embeds"] is not None
259+
for item in features
260+
):
243261
batch["inputs_embeds"] = torch.cat(
244262
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
245263
)
264+
if all(
265+
"position_ids" in item and item["position_ids"] is not None
266+
for item in features
267+
):
246268
batch["position_ids"] = paddingtensor3D_CBN(
247269
[item["position_ids"] for item in features]
248270
)
249271

250272
return batch
273+
274+
275+
class VLMHunyuanDataCollatorWithPadding:
276+
277+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
278+
max_length = max(item["input_ids"].shape[1] for item in features)
279+
batch_input_ids = torch.cat(
280+
[paddingtensor2D(item["input_ids"], max_length) for item in features]
281+
)
282+
batch_attention_mask = torch.cat(
283+
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
284+
)
285+
batch_loss_mask = torch.cat(
286+
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
287+
)
288+
batch = {
289+
"input_ids": batch_input_ids,
290+
"attention_mask": batch_attention_mask,
291+
"loss_mask": batch_loss_mask,
292+
"hidden_states": None,
293+
"target_hiddens": None,
294+
"inputs_embeds": None,
295+
"position_ids": None,
296+
"input_position_ids": None,
297+
}
298+
299+
if "pixel_values" in features[0]:
300+
batch["pixel_values"] = paddingtensor3D_BHW(
301+
[item["pixel_values"] for item in features]
302+
)
303+
304+
if all(
305+
"image_grid_thw" in item and item["image_grid_thw"] is not None
306+
for item in features
307+
):
308+
batch["image_grid_thw"] = torch.cat(
309+
[item["image_grid_thw"] for item in features], dim=0
310+
)
311+
312+
# Check if both hidden_states and target_hiddens exist in all features
313+
if all(
314+
"hidden_states" in item and "target_hiddens" in item for item in features
315+
):
316+
batch["hidden_states"] = torch.cat(
317+
[paddingtensor(item["hidden_states"], max_length) for item in features]
318+
)
319+
batch["target_hiddens"] = torch.cat(
320+
[paddingtensor(item["target_hiddens"], max_length) for item in features]
321+
)
322+
if all(
323+
"inputs_embeds" in item and item["inputs_embeds"] is not None
324+
for item in features
325+
):
326+
batch["inputs_embeds"] = torch.cat(
327+
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
328+
)
329+
if all(
330+
"input_position_ids" in item and item["input_position_ids"] is not None
331+
for item in features
332+
):
333+
batch["input_position_ids"] = paddingtensor3D_BCN(
334+
[item["input_position_ids"] for item in features]
335+
)
336+
if all(
337+
"position_ids" in item and item["position_ids"] is not None
338+
for item in features
339+
):
340+
batch["position_ids"] = torch.cat(
341+
[paddingtensor2D(item["position_ids"], max_length) for item in features]
342+
)
343+
return batch

angelslim/compressor/speculative/train/data/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
chat_template_type: Optional[Union[str, ChatTemplateType]] = None,
4141
display: bool = False,
4242
cache_in_memory: bool = False,
43+
target_model_type: Optional[str] = None,
4344
):
4445
"""
4546
Initialize DatasetManager with DataArguments.
@@ -60,6 +61,7 @@ def __init__(
6061
self.model_max_length = model_max_length
6162
self.display = display
6263
self.cache_in_memory = cache_in_memory
64+
self.target_model_type = target_model_type
6365

6466
# Convert chat_template_type to ChatTemplateType enum
6567
if chat_template_type is None:
@@ -76,6 +78,7 @@ def __init__(
7678
self.online_dataset_builder = DatasetBuilderFactory.create(
7779
training_mode="online",
7880
modal_type=data_args.modal_type,
81+
target_model_type=self.target_model_type,
7982
tokenizer=tokenizer,
8083
max_length=model_max_length,
8184
shuffle_seed=data_args.shuffle_seed,
@@ -86,6 +89,7 @@ def __init__(
8689
self.offline_dataset_builder = DatasetBuilderFactory.create(
8790
training_mode="offline",
8891
modal_type=data_args.modal_type,
92+
target_model_type=self.target_model_type,
8993
cache_in_memory=cache_in_memory,
9094
)
9195

angelslim/compressor/speculative/train/data/dataset_builder/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,23 @@
1313
# limitations under the License.
1414

1515
from .dataset_builder_factory import DatasetBuilderFactory
16-
from .offline_dataset_builder import OfflineLLMDatasetBuilder, OfflineVLMDatasetBuilder
17-
from .online_dataset_builder import OnlineLLMDatasetBuilder, OnlineVLMDatasetBuilder
16+
from .offline_dataset_builder import (
17+
OfflineLLMDatasetBuilder,
18+
OfflineVLMDatasetBuilder,
19+
OfflineVLMHunyuanVLDatasetBuilder,
20+
)
21+
from .online_dataset_builder import (
22+
OnlineLLMDatasetBuilder,
23+
OnlineVLMDatasetBuilder,
24+
OnlineVLMHunyuanVLDatasetBuilder,
25+
)
1826

1927
__all__ = [
2028
"OnlineLLMDatasetBuilder",
2129
"OnlineVLMDatasetBuilder",
30+
"OnlineVLMHunyuanVLDatasetBuilder",
2231
"OfflineLLMDatasetBuilder",
2332
"OfflineVLMDatasetBuilder",
33+
"OfflineVLMHunyuanVLDatasetBuilder",
2434
"DatasetBuilderFactory",
2535
]

0 commit comments

Comments
 (0)