Skip to content

Commit 18f472c

Browse files
authored
feat: support Hunyuan Eagle3 training (#126)
1 parent b125081 commit 18f472c

9 files changed

Lines changed: 223 additions & 48 deletions

File tree

angelslim/compressor/speculative/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,31 @@
1313
# limitations under the License.
1414

1515
from .benchmark import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
16+
from .train import (
17+
DataCollatorWithPadding,
18+
DatasetManager,
19+
DraftModelConfig,
20+
OnlineEagle3Trainer,
21+
convert_sharegpt_data,
22+
convert_ultrachat_data,
23+
create_draft_model,
24+
create_target_model,
25+
data_generation_work_flow,
26+
get_supported_chat_template_type_strings,
27+
)
1628

17-
__all__ = ["BenchmarkEngine", "BenchmarkConfig", "BenchmarkMode"]
29+
__all__ = [
30+
"BenchmarkEngine",
31+
"BenchmarkConfig",
32+
"BenchmarkMode",
33+
"create_draft_model",
34+
"DraftModelConfig",
35+
"create_target_model",
36+
"OnlineEagle3Trainer",
37+
"data_generation_work_flow",
38+
"DataCollatorWithPadding",
39+
"convert_sharegpt_data",
40+
"convert_ultrachat_data",
41+
"DatasetManager",
42+
"get_supported_chat_template_type_strings",
43+
]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from .data import (
2+
DataCollatorWithPadding,
3+
DatasetManager,
4+
convert_sharegpt_data,
5+
convert_ultrachat_data,
6+
data_generation_work_flow,
7+
get_supported_chat_template_type_strings,
8+
)
9+
from .models import DraftModelConfig, create_draft_model, create_target_model
10+
from .trainer import OnlineEagle3Trainer
11+
12+
__all__ = [
13+
"create_draft_model",
14+
"DraftModelConfig",
15+
"create_target_model",
16+
"OnlineEagle3Trainer",
17+
"data_generation_work_flow",
18+
"DataCollatorWithPadding",
19+
"convert_sharegpt_data",
20+
"convert_ultrachat_data",
21+
"DatasetManager",
22+
"get_supported_chat_template_type_strings",
23+
]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"add_classification_head": false,
3+
"architectures": [
4+
"Eagle3LlamaForCausalLM"
5+
],
6+
"attention_bias": false,
7+
"attention_dropout": 0.0,
8+
"attention_head_dim": 128,
9+
"bos_token_id": 1,
10+
"cla_share_factor": 2,
11+
"class_num": 0,
12+
"dense_list": [
13+
3072,
14+
0
15+
],
16+
"eod_token_id": 120026,
17+
"eos_token_id": 120020,
18+
"head_dim": 128,
19+
"hidden_act": "silu",
20+
"hidden_size": 3072,
21+
"im_end_id": 5,
22+
"im_newline_id": 11,
23+
"im_start_id": 4,
24+
"initializer_range": 0.02,
25+
"intermediate_size": 8192,
26+
"mask_init_id": 12,
27+
"max_position_embeddings": 262144,
28+
"mlp_bias": false,
29+
"model_type": "llama",
30+
"norm_type": "rms",
31+
"num_attention_heads": 32,
32+
"num_hidden_layers": 1,
33+
"num_key_value_heads": 8,
34+
"org_vocab_size": 120818,
35+
"pad_id": 120002,
36+
"pad_token_id": 120002,
37+
"pool_type": "last",
38+
"pretraining_tp": 1,
39+
"rms_norm_eps": 1e-05,
40+
"rope_theta": 10000.0,
41+
"sep_token_id": 120007,
42+
"text_end_id": 7,
43+
"text_start_id": 6,
44+
"tie_word_embeddings": true,
45+
"torch_dtype": "bfloat16",
46+
"transformers_version": "4.41.2",
47+
"use_cache": true,
48+
"use_cla": false,
49+
"use_qk_norm": true,
50+
"use_rotary_pos_emb": true,
51+
"vocab_size": 120818,
52+
"draft_vocab_size": 32000
53+
}

angelslim/compressor/speculative/train/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .chat_templates import get_supported_chat_template_type_strings
1516
from .data_generation import data_generation_work_flow
1617
from .data_utils import (
1718
DataCollatorWithPadding,
@@ -26,4 +27,5 @@
2627
"convert_sharegpt_data",
2728
"convert_ultrachat_data",
2829
"data_generation_work_flow",
30+
"get_supported_chat_template_type_strings",
2931
]

angelslim/compressor/speculative/train/data/chat_templates.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ class ChatTemplateType(Enum):
2727
"""Supported chat template types."""
2828

2929
QWEN3 = "qwen3"
30+
HUNYUAN = "hunyuan"
3031

3132

3233
# String to ChatTemplateType mapping
3334
CHAT_TEMPLATE_TYPE_MAPPING = {
3435
"qwen3": ChatTemplateType.QWEN3,
36+
"hunyuan": ChatTemplateType.HUNYUAN,
3537
}
3638

3739

@@ -75,7 +77,22 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
7577
"correct. If you don't know the answer to a question, "
7678
"please don't share false information."
7779
),
78-
)
80+
),
81+
ChatTemplateType.HUNYUAN: ChatTemplate(
82+
user_header="<|hy_User|>",
83+
assistant_header="<|hy_Assistant|>",
84+
system_prompt=(
85+
"You are a helpful, respectful and honest assistant. "
86+
"Always answer as helpfully as possible, while being safe. "
87+
"Your answers should not include any harmful, unethical, racist, "
88+
"sexist, toxic, dangerous, or illegal content. Please ensure that "
89+
"your responses are socially unbiased and positive in nature.\n\n"
90+
"If a question does not make any sense, or is not factually "
91+
"coherent, explain why instead of answering something not "
92+
"correct. If you don't know the answer to a question, "
93+
"please don't share false information."
94+
),
95+
),
7996
}
8097

8198
def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:

angelslim/compressor/speculative/train/data/dataset.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,82 @@ def __init__(
3636
max_length: int = 2048,
3737
shuffle_seed: int = 42,
3838
chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3,
39+
display: bool = False,
3940
):
4041
self.tokenizer = tokenizer
4142
self.max_length = max_length
4243
self.shuffle_seed = shuffle_seed
4344
self.chat_template_type = chat_template_type
45+
self.display = display
46+
self.display_count = 0 # Track how many samples have been displayed
4447

4548
# Get chat template
4649
template = template_manager.get_template_dict(chat_template_type)
4750
self.user_header = template["user_header"]
4851
self.assistant_header = template["assistant_header"]
4952
self.system_prompt = template["system_prompt"]
5053

54+
def _visualize_loss_mask(
55+
self, input_ids: torch.Tensor, loss_mask: torch.Tensor, conversation: str
56+
) -> None:
57+
"""
58+
Visualize loss_mask with color-coded output.
59+
60+
Args:
61+
input_ids: Token IDs
62+
loss_mask: Loss mask tensor (1 for training, 0 for ignoring)
63+
conversation: Original conversation text
64+
"""
65+
# ANSI color codes
66+
RED = "\033[91m" # For masked out tokens (loss_mask=0)
67+
GREEN = "\033[92m" # For training tokens (loss_mask=1)
68+
RESET = "\033[0m" # Reset color
69+
BOLD = "\033[1m"
70+
71+
rank0_print("\n" + "=" * 80)
72+
rank0_print(f"{BOLD}Loss Mask Visualization{RESET}")
73+
rank0_print("=" * 80)
74+
75+
# Display legend
76+
rank0_print(f"\n{BOLD}Legend:{RESET}")
77+
rank0_print(f"{GREEN}■ Green: Training tokens (loss_mask=1){RESET}")
78+
rank0_print(f"{RED}■ Red: Ignored tokens (loss_mask=0){RESET}")
79+
80+
# Display statistics
81+
total_tokens = len(loss_mask)
82+
training_tokens = loss_mask.sum().item()
83+
ignored_tokens = total_tokens - training_tokens
84+
training_ratio = training_tokens / total_tokens * 100 if total_tokens > 0 else 0
85+
86+
rank0_print(f"\n{BOLD}Statistics:{RESET}")
87+
rank0_print(f"Total tokens: {total_tokens}")
88+
rank0_print(f"Training tokens: {training_tokens} ({training_ratio:.2f}%)")
89+
rank0_print(f"Ignored tokens: {ignored_tokens} ({100-training_ratio:.2f}%)")
90+
91+
# Display token-by-token visualization
92+
rank0_print(f"\n{BOLD}Token-by-token visualization:{RESET}")
93+
rank0_print("-" * 80)
94+
95+
decoded_tokens = []
96+
for token_id, mask_value in zip(input_ids, loss_mask):
97+
token_text = self.tokenizer.decode([token_id], skip_special_tokens=False)
98+
99+
# Choose color based on mask value
100+
color = GREEN if mask_value == 1 else RED
101+
102+
# Format token with color
103+
colored_token = f"{color}{token_text}{RESET}"
104+
decoded_tokens.append(colored_token)
105+
106+
# Print all tokens directly
107+
rank0_print("".join(decoded_tokens))
108+
109+
# Display original conversation for reference
110+
rank0_print(f"\n{BOLD}Original conversation:{RESET}")
111+
rank0_print("-" * 80)
112+
rank0_print(conversation)
113+
rank0_print("=" * 80 + "\n")
114+
51115
def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset:
52116
try:
53117
# Load and shuffle dataset
@@ -67,8 +131,10 @@ def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset:
67131
desc="Processing conversations",
68132
)
69133

70-
# Filter out None results
71-
processed_ds = processed_ds.filter(lambda x: x["input_ids"] is not None)
134+
# Filter out None results with multiprocessing support
135+
processed_ds = processed_ds.filter(
136+
lambda x: x["input_ids"] is not None, num_proc=num_proc
137+
)
72138
processed_ds.set_format(type="torch")
73139

74140
return processed_ds
@@ -134,6 +200,11 @@ def _process_single_conversation(
134200
input_ids = torch.tensor(input_ids)
135201
attention_mask = torch.ones_like(input_ids)
136202

203+
# Visualize loss mask if display mode is enabled
204+
if self.display and self.display_count == 0:
205+
self._visualize_loss_mask(input_ids, loss_mask, conversation)
206+
self.display_count += 1
207+
137208
return {
138209
"input_ids": input_ids[None, :],
139210
"attention_mask": attention_mask[None, :],
@@ -262,6 +333,7 @@ def __init__(
262333
tokenizer: AutoTokenizer,
263334
model_max_length: int = 2048,
264335
chat_template_type: Optional[Union[str, ChatTemplateType]] = None,
336+
display: bool = False,
265337
):
266338
"""
267339
Initialize DatasetManager with DataArguments.
@@ -274,10 +346,12 @@ def __init__(
274346
- ChatTemplateType enum value (e.g., ChatTemplateType.QWEN3)
275347
- String (e.g., "llama", "qwen")
276348
- None (will default to LLAMA)
349+
display: Whether to display loss mask visualization for the first sample
277350
"""
278351
self.data_args = data_args
279352
self.tokenizer = tokenizer
280353
self.model_max_length = model_max_length
354+
self.display = display
281355

282356
# Convert chat_template_type to ChatTemplateType enum
283357
if chat_template_type is None:
@@ -293,6 +367,7 @@ def __init__(
293367
max_length=model_max_length,
294368
shuffle_seed=data_args.shuffle_seed,
295369
chat_template_type=chat_template_type,
370+
display=display,
296371
)
297372

298373
def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
@@ -305,8 +380,8 @@ def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
305380
"""
306381
# Determine number of processes
307382
num_proc = self.data_args.num_proc
308-
if self.data_args.preprocessing_num_workers is not None:
309-
num_proc = self.data_args.preprocessing_num_workers
383+
if self.display:
384+
num_proc = None
310385

311386
# Create train dataset
312387
train_dataset = self.dataset_builder.build_dataset(
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .draft import DraftModelConfig, create_draft_model
2+
from .target import create_target_model
3+
4+
__all__ = ["create_draft_model", "DraftModelConfig", "create_target_model"]

angelslim/compressor/speculative/train/models/target/target_model_wrapper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def load_model(self):
6565
param.requires_grad = False
6666

6767
self.model.eval()
68-
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
68+
self.tokenizer = AutoTokenizer.from_pretrained(
69+
self.model_path, trust_remote_code=True
70+
)
6971

7072
def get_hidden_states_and_logits(
7173
self,
@@ -122,8 +124,8 @@ def __init__(self, backend: str, model_path: str, **kwargs):
122124
Initialize TargetModel with specified backend
123125
124126
Args:
125-
backend: One of ["hf", "vllm_local", "vllm_serving"]
126-
model_path: Path to model or serving endpoint
127+
backend: One of ["hf"]
128+
model_path: Path to model
127129
**kwargs: Additional arguments for backend initialization
128130
"""
129131
if backend not in self.BACKENDS:
@@ -148,8 +150,6 @@ def get_hidden_states_and_logits(
148150
Args:
149151
input_ids: Input token ids, shape [batch_size, seq_len]
150152
attention_mask: Attention mask, shape [batch_size, seq_len]
151-
position_ids: Position ids, shape [batch_size, seq_len]
152-
past_key_values: Past key values for generation
153153
154154
Returns:
155155
Tuple of (hidden_states, logits)

0 commit comments

Comments
 (0)