Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
"stage3_gather_16bit_weights_on_model_save": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},

"gradient_accumulation_steps": "auto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "hunyuan_v1_dense",
"attention_bias": false,
"attention_dropout": 0.0,
"attention_head_dim": 128,
Expand All @@ -26,7 +28,6 @@
"mask_init_id": 12,
"max_position_embeddings": 262144,
"mlp_bias": false,
"model_type": "llama",
"norm_type": "rms",
"num_attention_heads": 16,
"num_hidden_layers": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "hunyuan_v1_dense",
"attention_bias": false,
"attention_dropout": 0.0,
"attention_head_dim": 128,
Expand All @@ -26,7 +28,6 @@
"mask_init_id": 12,
"max_position_embeddings": 262144,
"mlp_bias": false,
"model_type": "llama",
"norm_type": "rms",
"num_attention_heads": 32,
"num_hidden_layers": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "hunyuan_v1_dense",
"attention_bias": false,
"attention_dropout": 0.1,
"attention_head_dim": 128,
Expand All @@ -26,7 +28,6 @@
"mask_init_id": 12,
"max_position_embeddings": 32768,
"mlp_bias": false,
"model_type": "llama",
"norm_type": "rms",
"num_attention_heads": 32,
"num_hidden_layers": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen2.5",
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
Expand All @@ -11,7 +13,6 @@
"intermediate_size": 4864,
"max_position_embeddings": 32768,
"max_window_layers": 24,
"model_type": "llama",
"num_attention_heads": 14,
"num_hidden_layers": 1,
"num_key_value_heads": 2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen2.5",
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
Expand All @@ -11,7 +13,6 @@
"intermediate_size": 8960,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "llama",
"num_attention_heads": 12,
"num_hidden_layers": 1,
"num_key_value_heads": 2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen2.5",
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
Expand All @@ -11,7 +13,6 @@
"intermediate_size": 11008,
"max_position_embeddings": 32768,
"max_window_layers": 36,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 1,
"num_key_value_heads": 2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen2.5",
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
Expand All @@ -11,7 +13,6 @@
"intermediate_size": 18944,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "llama",
"num_attention_heads": 28,
"num_hidden_layers": 1,
"num_key_value_heads": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3_moe",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"target_model_type": "qwen3",
"torch_dtype": "bfloat16",
"attention_bias": false,
"attention_dropout": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,24 @@ def _create_loss_mask_from_offsets(

return loss_mask

@staticmethod
def _normalize_content(content):
"""Normalize content to string format.

If content is a list (multimodal format like [{"type": "text", "text": "..."}]),
extract and concatenate all text items into a single string.
This ensures LLM mode can handle data in multimodal format.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text" and item.get("text"):
text_parts.append(item["text"])
return "".join(text_parts) if text_parts else ""
return content

def _build_messages(self, source: List[Dict]) -> List[Dict]:
# System message
if source[0]["role"] != "system":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ def __init__(
def get_data_collator(self) -> Any:
return DataCollatorWithPadding()

def _process_single_conversation(self, conversation_data):
"""
compatible with two format:
{"id": "0", "conversations": [
{"role": "user", "content": "xxx"},
{"role": "assistant", "content": "xxx"}
]}
{"id": "0", "conversations": [
{"role": "user", "content": [{"type": "text", "text": "xxx"}]},
{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}
]}
"""
if conversation_data:
for message in conversation_data:
content = message.get("content")
if isinstance(content, list):
message["content"] = self._normalize_content(content)
return super()._process_single_conversation(conversation_data)


@DatasetBuilderFactory.register("online", "VLM", "qwen2_5_vl")
@DatasetBuilderFactory.register("online", "VLM", "qwen3_vl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@


class Eagle3BaseDraftModel(PreTrainedModel, ABC):
supports_gradient_checkpointing = True

@abstractmethod
def compute_logits(self, input_ids, attention_mask):
pass
Expand All @@ -40,9 +42,12 @@ def combine_hidden_states(self, hidden_states):
pass

@abstractmethod
def get_input_embeddings(self, input_ids):
def embed_input_ids(self, input_ids):
pass

def get_input_embeddings(self):
return self.embed_tokens

def freeze_embed_weights(self):
for param in self.embed_tokens.parameters():
param.requires_grad = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def _init_rope(self):
if scaling_type == "mrope" or self.config.rope_scaling.get("mrope_interleaved", False):
self.rotary_emb = MRotaryEmbedding(self.config)
self.rope_apply_func = apply_rotary_pos_emb_mrope
elif scaling_type == "default":
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings
)
elif scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
Expand Down Expand Up @@ -559,6 +563,9 @@ def __init__(self, config):

self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)

# Required by new transformers gradient checkpointing format
self.gradient_checkpointing = False

def combine_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.fc(hidden_states)

Expand Down Expand Up @@ -589,7 +596,7 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.lm_head(norm_hidden_states)
return logits.float()

def get_input_embeddings(self, input_ids):
def embed_input_ids(self, input_ids):
inputs_embeds = self.embed_tokens(input_ids)
return inputs_embeds

Expand Down
16 changes: 12 additions & 4 deletions angelslim/compressor/speculative/train/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,29 @@ def apply_rotary_pos_emb_mrope(q, k, cos, sin, position_ids=None, unsqueeze_dim=


def infer_model_params(
model_name_or_path: str,
model_name_or_path: Optional[str],
model_type: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
auto-detect lm_head_key、embed_weight_key、chat_template_type from target model path
or model_type.

If model_type is provided, it will be used directly to look up the parameter map
without loading AutoConfig. Otherwise, model_name_or_path is used to auto-detect
model_type via AutoConfig.

Args:
model_name_or_path: target model path
model_name_or_path: target model path (optional if model_type is provided)
model_type: model type string, e.g. 'qwen2_5_vl', 'qwen3_vl'
(typically from draft_model_config.target_model_type)

Returns:
(lm_head_key, embed_weight_key, chat_template_type)
(None, None, None) if failed to auto-detect
"""
try:
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
model_type = getattr(config, "model_type", None)
print(f"[Auto-detect] Detected model_type: {model_type}")
print(f"model_type: {model_type}")
if model_type in MODEL_TYPE_PARAM_MAP:
lm_head_key, embed_weight_key, chat_template_type = MODEL_TYPE_PARAM_MAP[model_type]
# compatible with tie_word_embeddings=False/True
Expand Down
Loading
Loading