Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "hunyuan_vl",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 120000,
"eod_token_id": 120020,
"eos_token_id": 120020,
"dtype": "bfloat16",
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"image_start_token_id": 120118,
"image_end_token_id": 120119,
"image_token_id": 120120,
"image_newline_token_id": 120121,
"initializer_range": 0.02,
"intermediate_size": 3584,
"max_position_embeddings": 32768,
"num_attention_heads": 16,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"use_cache": true,
"vocab_size": 120818,
"tie_word_embeddings": true,
"transformers_version": "4.57.1",
"draft_vocab_size": 32000,
"modal_type": "VLM"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3_vl",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"num_attention_heads": 16,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"type": "default",
"rope_type": "default",
"mrope_interleaved": true,
"mrope_section": [
24,
20,
20
]
},
"rope_theta": 5000000,
"use_cache": true,
"vocab_size": 151936,
"tie_word_embeddings": true,
"transformers_version": "4.57.1",
"image_token_id": 151655,
"video_token_id": 151656,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"draft_vocab_size": 32000,
"modal_type": "VLM"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3_vl",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"type": "default",
"rope_type": "default",
"mrope_interleaved": true,
"mrope_section": [
24,
20,
20
]
},
"rope_theta": 5000000,
"use_cache": true,
"vocab_size": 151936,
"tie_word_embeddings": false,
"transformers_version": "4.57.1",
"image_token_id": 151655,
"video_token_id": 151656,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"draft_vocab_size": 32000,
"modal_type": "VLM"
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3_vl",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
Expand Down
24 changes: 24 additions & 0 deletions angelslim/compressor/speculative/train/data/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,22 @@
class ChatTemplateType(Enum):
"""Supported chat template types."""

QWEN2_AUDIO = "qwen2_audio"
QWEN3 = "qwen3"
HUNYUAN = "hunyuan"
QWEN3_VL = "qwen3_vl"
HUNYUAN_7B = "hunyuan_7b"
HUNYUAN_VL = "hunyuan_vl"


# String to ChatTemplateType mapping
CHAT_TEMPLATE_TYPE_MAPPING = {
"qwen2_audio": ChatTemplateType.QWEN2_AUDIO,
"qwen3": ChatTemplateType.QWEN3,
"hunyuan": ChatTemplateType.HUNYUAN,
"hunyuan_7b": ChatTemplateType.HUNYUAN_7B,
"qwen3_vl": ChatTemplateType.QWEN3_VL,
"hunyuan_vl": ChatTemplateType.HUNYUAN_VL,
}


Expand Down Expand Up @@ -133,6 +137,26 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
}
],
),
ChatTemplateType.QWEN2_AUDIO: ChatTemplate(
user_header="<|im_start|>user\n",
assistant_header="<|im_start|>assistant\n",
system_prompt=[
{
"type": "text",
"text": ("You are a helpful assistant."),
}
],
),
ChatTemplateType.HUNYUAN_VL: ChatTemplate(
user_header="<|hy_Assistant|>",
assistant_header="<|hy_User|>",
system_prompt=[
{
"type": "text",
"text": "",
}
],
),
}

def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:
Expand Down
93 changes: 93 additions & 0 deletions angelslim/compressor/speculative/train/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ def paddingtensor3D_CBN(tensor_list):
return torch.cat(out_tensor_list, dim=1)


def paddingtensor3D_BCN(tensor_list):
if all(tensor is None for tensor in tensor_list):
return None
N = max(tensor.shape[-1] for tensor in tensor_list if tensor is not None)
out_tensor_list = []
for tensor in tensor_list:
b, c, n = tensor.shape
outtensor = torch.zeros(b, c, N, dtype=tensor_list[0].dtype)
if tensor is not None:
outtensor[:, :, :n] = tensor
out_tensor_list.append(outtensor)
return torch.cat(out_tensor_list, dim=0)


def paddingtensor3D_BHW(tensor_list):
if all(tensor is None for tensor in tensor_list):
return None
Expand Down Expand Up @@ -240,11 +254,90 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
batch["target_hiddens"] = torch.cat(
[paddingtensor(item["target_hiddens"], max_length) for item in features]
)
if all(
"inputs_embeds" in item and item["inputs_embeds"] is not None
for item in features
):
batch["inputs_embeds"] = torch.cat(
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
)
if all(
"position_ids" in item and item["position_ids"] is not None
for item in features
):
batch["position_ids"] = paddingtensor3D_CBN(
[item["position_ids"] for item in features]
)

return batch


class VLMHunyuanDataCollatorWithPadding:

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length = max(item["input_ids"].shape[1] for item in features)
batch_input_ids = torch.cat(
[paddingtensor2D(item["input_ids"], max_length) for item in features]
)
batch_attention_mask = torch.cat(
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
)
batch_loss_mask = torch.cat(
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
)
batch = {
"input_ids": batch_input_ids,
"attention_mask": batch_attention_mask,
"loss_mask": batch_loss_mask,
"hidden_states": None,
"target_hiddens": None,
"inputs_embeds": None,
"position_ids": None,
"input_position_ids": None,
}

if "pixel_values" in features[0]:
batch["pixel_values"] = paddingtensor3D_BHW(
[item["pixel_values"] for item in features]
)

if all(
"image_grid_thw" in item and item["image_grid_thw"] is not None
for item in features
):
batch["image_grid_thw"] = torch.cat(
[item["image_grid_thw"] for item in features], dim=0
)

# Check if both hidden_states and target_hiddens exist in all features
if all(
"hidden_states" in item and "target_hiddens" in item for item in features
):
batch["hidden_states"] = torch.cat(
[paddingtensor(item["hidden_states"], max_length) for item in features]
)
batch["target_hiddens"] = torch.cat(
[paddingtensor(item["target_hiddens"], max_length) for item in features]
)
if all(
"inputs_embeds" in item and item["inputs_embeds"] is not None
for item in features
):
batch["inputs_embeds"] = torch.cat(
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
)
if all(
"input_position_ids" in item and item["input_position_ids"] is not None
for item in features
):
batch["input_position_ids"] = paddingtensor3D_BCN(
[item["input_position_ids"] for item in features]
)
if all(
"position_ids" in item and item["position_ids"] is not None
for item in features
):
batch["position_ids"] = torch.cat(
[paddingtensor2D(item["position_ids"], max_length) for item in features]
)
return batch
4 changes: 4 additions & 0 deletions angelslim/compressor/speculative/train/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
chat_template_type: Optional[Union[str, ChatTemplateType]] = None,
display: bool = False,
cache_in_memory: bool = False,
target_model_type: Optional[str] = None,
):
"""
Initialize DatasetManager with DataArguments.
Expand All @@ -60,6 +61,7 @@ def __init__(
self.model_max_length = model_max_length
self.display = display
self.cache_in_memory = cache_in_memory
self.target_model_type = target_model_type

# Convert chat_template_type to ChatTemplateType enum
if chat_template_type is None:
Expand All @@ -76,6 +78,7 @@ def __init__(
self.online_dataset_builder = DatasetBuilderFactory.create(
training_mode="online",
modal_type=data_args.modal_type,
target_model_type=self.target_model_type,
tokenizer=tokenizer,
max_length=model_max_length,
shuffle_seed=data_args.shuffle_seed,
Expand All @@ -86,6 +89,7 @@ def __init__(
self.offline_dataset_builder = DatasetBuilderFactory.create(
training_mode="offline",
modal_type=data_args.modal_type,
target_model_type=self.target_model_type,
cache_in_memory=cache_in_memory,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@
# limitations under the License.

from .dataset_builder_factory import DatasetBuilderFactory
from .offline_dataset_builder import OfflineLLMDatasetBuilder, OfflineVLMDatasetBuilder
from .online_dataset_builder import OnlineLLMDatasetBuilder, OnlineVLMDatasetBuilder
from .offline_dataset_builder import (
OfflineLLMDatasetBuilder,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用import OfflineVLMDatasetBuilder?

OfflineVLMDatasetBuilder,
OfflineVLMHunyuanVLDatasetBuilder,
)
from .online_dataset_builder import (
OnlineLLMDatasetBuilder,
OnlineVLMDatasetBuilder,
OnlineVLMHunyuanVLDatasetBuilder,
)

__all__ = [
"OnlineLLMDatasetBuilder",
"OnlineVLMDatasetBuilder",
"OnlineVLMHunyuanVLDatasetBuilder",
"OfflineLLMDatasetBuilder",
"OfflineVLMDatasetBuilder",
"OfflineVLMHunyuanVLDatasetBuilder",
"DatasetBuilderFactory",
]
Loading
Loading