diff --git a/angelslim/compressor/speculative/__init__.py b/angelslim/compressor/speculative/__init__.py index b7e38f46..391fe606 100644 --- a/angelslim/compressor/speculative/__init__.py +++ b/angelslim/compressor/speculative/__init__.py @@ -17,7 +17,9 @@ DataCollatorWithPadding, DatasetManager, DraftModelConfig, + OfflineEagle3Trainer, OnlineEagle3Trainer, + TargetHead, convert_sharegpt_data, convert_ultrachat_data, create_draft_model, @@ -34,10 +36,12 @@ "DraftModelConfig", "create_target_model", "OnlineEagle3Trainer", + "OfflineEagle3Trainer", "data_generation_work_flow", "DataCollatorWithPadding", "convert_sharegpt_data", "convert_ultrachat_data", "DatasetManager", "get_supported_chat_template_type_strings", + "TargetHead", ] diff --git a/angelslim/compressor/speculative/train/__init__.py b/angelslim/compressor/speculative/train/__init__.py index fb97a38d..8857f11a 100644 --- a/angelslim/compressor/speculative/train/__init__.py +++ b/angelslim/compressor/speculative/train/__init__.py @@ -6,18 +6,25 @@ data_generation_work_flow, get_supported_chat_template_type_strings, ) -from .models import DraftModelConfig, create_draft_model, create_target_model -from .trainer import OnlineEagle3Trainer +from .models import ( + DraftModelConfig, + TargetHead, + create_draft_model, + create_target_model, +) +from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer __all__ = [ "create_draft_model", "DraftModelConfig", "create_target_model", "OnlineEagle3Trainer", + "OfflineEagle3Trainer", "data_generation_work_flow", "DataCollatorWithPadding", "convert_sharegpt_data", "convert_ultrachat_data", "DatasetManager", "get_supported_chat_template_type_strings", + "TargetHead", ] diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index fe97d9bc..b9c08b73 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -133,5 +133,24 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: "input_ids": batch_input_ids, "attention_mask": batch_attention_mask, "loss_mask": batch_loss_mask, + "hidden_states": None, + "target_hiddens": None, } + + # Check if both hidden_states and target_hiddens exist in all features + if all( + "hidden_states" in item and "target_hiddens" in item for item in features + ): + batch["hidden_states"] = torch.cat( + [ + self.paddingtensor(item["hidden_states"], max_length) + for item in features + ] + ) + batch["target_hiddens"] = torch.cat( + [ + self.paddingtensor(item["target_hiddens"], max_length) + for item in features + ] + ) return batch diff --git a/angelslim/compressor/speculative/train/data/dataset.py b/angelslim/compressor/speculative/train/data/dataset.py index 3394e944..0a95a1fd 100644 --- a/angelslim/compressor/speculative/train/data/dataset.py +++ b/angelslim/compressor/speculative/train/data/dataset.py @@ -13,7 +13,9 @@ # limitations under the License. import re -from typing import Any, Dict, List, Optional, Tuple, Union +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union import torch from datasets import load_dataset @@ -112,11 +114,18 @@ def _visualize_loss_mask( rank0_print(conversation) rank0_print("=" * 80 + "\n") - def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset: + def build_dataset( + self, datapath: str, num_proc: int = 8, shuffle: bool = True + ) -> Dataset: try: - # Load and shuffle dataset + # Load dataset ds = load_dataset("json", data_files=datapath) - ds = ds["train"].shuffle(seed=self.shuffle_seed) + + # Conditionally shuffle dataset + if shuffle: + ds = ds["train"].shuffle(seed=self.shuffle_seed) + else: + ds = ds["train"] # Store original columns for removal original_columns = ds.column_names @@ -282,97 +291,122 @@ def _build_messages(self, source: List[Dict]) -> List[Dict]: return messages if len(messages) > 1 else [] -class DataCollatorWithPadding: - def paddingtensor(self, intensors, N): - B, n, S = intensors.shape - # padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype) - padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor), dim=1) - return outtensors - - def paddingtensor2D(self, intensors, N): - B, n = intensors.shape - padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor), dim=1) - return outtensors - - def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - max_length = max(item["input_ids"].shape[1] for item in features) - batch_input_ids = torch.cat( - [self.paddingtensor2D(item["input_ids"], max_length) for item in features] - ) - batch_attention_mask = torch.cat( - [ - self.paddingtensor2D(item["attention_mask"], max_length) - for item in features - ] - ) - batch_loss_mask = torch.cat( - [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] - ) - - batch = { - "input_ids": batch_input_ids, - "attention_mask": batch_attention_mask, - "loss_mask": batch_loss_mask, - } - return batch - - class DatasetManager: """ - Simplified DatasetManager for train_eagle3_online.py. + Unified DatasetManager for EAGLE3 training. + + This manager supports creating datasets for: + - Offline mode: Loads pre-computed hidden states from .ckpt files for training + - Online mode: Processes raw conversation data on-the-fly - This manager is designed to work with DataArguments from train_eagle3_online.py - and provides a simple interface to create train and eval datasets. + Can create both types of datasets simultaneously when needed. """ def __init__( self, data_args, - tokenizer: AutoTokenizer, + tokenizer: Optional[AutoTokenizer] = None, model_max_length: int = 2048, chat_template_type: Optional[Union[str, ChatTemplateType]] = None, display: bool = False, + cache_in_memory: bool = False, ): """ Initialize DatasetManager with DataArguments. Args: - data_args: DataArguments object from train_eagle3_online.py - tokenizer: Tokenizer for the model + data_args: DataArguments object containing data paths and configurations + tokenizer: Tokenizer for the model (required for online dataset processing) model_max_length: Maximum sequence length - chat_template_type: Chat template type. Can be: + chat_template_type: Chat template type for conversation formatting. Can be: - ChatTemplateType enum value (e.g., ChatTemplateType.QWEN3) - String (e.g., "llama", "qwen") - - None (will default to LLAMA) + - None (will default to QWEN3) display: Whether to display loss mask visualization for the first sample + cache_in_memory: Whether to cache all data in memory for offline datasets """ self.data_args = data_args self.tokenizer = tokenizer self.model_max_length = model_max_length self.display = display + self.cache_in_memory = cache_in_memory # Convert chat_template_type to ChatTemplateType enum if chat_template_type is None: - # Default to QWEN3 chat_template_type = ChatTemplateType.QWEN3 elif isinstance(chat_template_type, str): - # Convert string to enum chat_template_type = string_to_chat_template_type(chat_template_type) - # Create dataset builder - self.dataset_builder = DatasetBuilder( - tokenizer=tokenizer, - max_length=model_max_length, - shuffle_seed=data_args.shuffle_seed, - chat_template_type=chat_template_type, - display=display, + self.chat_template_type = chat_template_type + + # Create dataset builder for online processing + self.dataset_builder = None + if tokenizer is not None: + self.dataset_builder = DatasetBuilder( + tokenizer=tokenizer, + max_length=model_max_length, + shuffle_seed=data_args.shuffle_seed, + chat_template_type=chat_template_type, + display=display, + ) + + def create_all_datasets( + self, + ) -> Tuple[Dataset, Optional[Dataset], Dataset, Optional[Dataset]]: + """ + Create all required datasets: offline and online datasets. + + Returns: + Tuple of (offline_train_dataset, offline_eval_dataset, + online_train_dataset, online_eval_dataset) + - offline_train_dataset: Offline training dataset from .ckpt files + - offline_eval_dataset: Offline evaluation dataset (None if not provided) + - online_train_dataset: Online training dataset from raw conversation data + - online_eval_dataset: Online evaluation dataset (None if not provided) + + Raises: + ValueError: If required paths are not provided + """ + # Create offline datasets (from .ckpt files) + offline_train_dataset, offline_eval_dataset = self._create_offline_datasets() + + # Create online datasets (from raw JSON data) if tokenizer is provided + online_train_dataset, online_eval_dataset = None, None + if self.tokenizer is not None and self.dataset_builder is not None: + online_train_dataset, online_eval_dataset = self._create_online_datasets() + + return ( + offline_train_dataset, + offline_eval_dataset, + online_train_dataset, + online_eval_dataset, ) - def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]: + def create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset]]: + """ + Create offline datasets only. + + Returns: + Tuple of (train_dataset, eval_dataset) + eval_dataset will be None if eval_hidden_path is not provided + """ + return self._create_offline_datasets() + + def create_online_datasets(self) -> Tuple[Optional[Dataset], Optional[Dataset]]: """ - Create train and eval datasets based on DataArguments. + Create online datasets only. + + Returns: + Tuple of (train_dataset, eval_dataset) + Both will be None if tokenizer not provided + """ + if self.tokenizer is None or self.dataset_builder is None: + return None, None + return self._create_online_datasets() + + def _create_online_datasets(self) -> Tuple[Optional[Dataset], Optional[Dataset]]: + """ + Create online datasets from raw conversation data. Returns: Tuple of (train_dataset, eval_dataset) @@ -383,16 +417,217 @@ def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]: if self.display: num_proc = None + # Create training dataset + train_dataset = None + if self.data_args.train_data_path is not None: + train_dataset = self.dataset_builder.build_dataset( + self.data_args.train_data_path, num_proc=num_proc, shuffle=True + ) + + # Create evaluation dataset + eval_dataset = None + if self.data_args.eval_data_path is not None: + eval_dataset = self.dataset_builder.build_dataset( + self.data_args.eval_data_path, num_proc=num_proc, shuffle=False + ) + + return train_dataset, eval_dataset + + def _create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset]]: + """ + Create offline datasets from pre-computed .ckpt files. + + Returns: + Tuple of (train_dataset, eval_dataset) + """ # Create train dataset - train_dataset = self.dataset_builder.build_dataset( - self.data_args.train_data_path, num_proc=num_proc + train_dataset = OfflineEagle3Dataset( + data_dir=self.data_args.train_hidden_path, + file_pattern="*.ckpt", + cache_in_memory=self.cache_in_memory, ) # Create eval dataset if path is provided eval_dataset = None - if self.data_args.eval_data_path is not None: - eval_dataset = self.dataset_builder.build_dataset( - self.data_args.eval_data_path, num_proc=num_proc + if self.data_args.eval_hidden_path is not None: + eval_dataset = OfflineEagle3Dataset( + data_dir=self.data_args.eval_hidden_path, + file_pattern="*.ckpt", + cache_in_memory=self.cache_in_memory, ) return train_dataset, eval_dataset + + +class OfflineEagle3Dataset(Dataset): + """ + Offline Dataset for EAGLE3 training. + + Loads pre-computed hidden states, logits, and other data from .ckpt files. + Each .ckpt file contains a dictionary with keys: input_ids, target_logits, + hidden_states, and loss_mask. + """ + + def __init__( + self, data_dir: str, file_pattern: str = "*.ckpt", cache_in_memory: bool = False + ): + """ + Initialize the OfflineEagle3Dataset. + + Args: + data_dir: Directory containing .ckpt files + (will search recursively in subdirectories) + file_pattern: Pattern to match checkpoint files (default: "*.ckpt") + cache_in_memory: Whether to cache all data in memory (default: False) + """ + self.data_dir = Path(data_dir) + self.cache_in_memory = cache_in_memory + + if not self.data_dir.exists(): + raise ValueError(f"Data directory does not exist: {data_dir}") + + # Recursively find all checkpoint files in subdirectories + self.ckpt_files = sorted(list(self.data_dir.rglob(file_pattern))) + + if len(self.ckpt_files) == 0: + raise ValueError( + f"No checkpoint files found in {data_dir} " + f"(including subdirectories) with pattern {file_pattern}" + ) + + rank0_print( + f"Found {len(self.ckpt_files)} checkpoint files " + f"in {data_dir} (including subdirectories)" + ) + + # Track valid indices (files that can be loaded successfully) + self.valid_indices = list(range(len(self.ckpt_files))) + + # Cache data in memory if requested + self.cached_data: Optional[List[Dict[str, torch.Tensor]]] = None + if self.cache_in_memory: + rank0_print("Caching all data in memory...") + self.cached_data = [] + failed_count = 0 + for i in range(len(self.ckpt_files)): + data = self._load_ckpt(i) + if data is not None: + self.cached_data.append(data) + else: + failed_count += 1 + + # Update valid indices based on successful loads + self.valid_indices = list(range(len(self.cached_data))) + + if failed_count > 0: + rank0_print( + f"Data caching completed. " + f"Successfully loaded {len(self.cached_data)} files, " + f"failed to load {failed_count} files" + ) + else: + rank0_print("Data caching completed") + + def _load_ckpt(self, idx: int) -> Optional[Dict[str, torch.Tensor]]: + """ + Load a checkpoint file. + + Args: + idx: Index of the checkpoint file + + Returns: + Dictionary containing input_ids, target_hiddens, + hidden_states, and loss_mask, or None if loading fails + """ + ckpt_path = self.ckpt_files[idx] + + try: + data = torch.load(ckpt_path, map_location="cpu") + except Exception as e: + warnings.warn( + f"Failed to load checkpoint {ckpt_path}: {e}. Skipping this file.", + RuntimeWarning, + stacklevel=2, + ) + return None + + # Validate required keys + required_keys = [ + "input_ids", # B, N + "target_hiddens", # B, N, D + "hidden_states", # B, N, 3*D + "loss_mask", # B, N + ] + missing_keys = [key for key in required_keys if key not in data] + + if missing_keys: + warnings.warn( + f"Checkpoint {ckpt_path} is missing required keys: {missing_keys}. " + f"Skipping this file.", + RuntimeWarning, + stacklevel=2, + ) + return None + + # Validate tensor types + for key in required_keys: + if not isinstance(data[key], torch.Tensor): + warnings.warn( + f"Value for key '{key}' in {ckpt_path} is not a torch.Tensor. " + f"Skipping this file.", + RuntimeWarning, + stacklevel=2, + ) + return None + + attention_mask = torch.ones_like(data["input_ids"]) + data["attention_mask"] = attention_mask # B, N + return data + + def __len__(self) -> int: + """Return the number of valid samples in the dataset.""" + if self.cached_data is not None: + return len(self.cached_data) + return len(self.valid_indices) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Get a sample from the dataset. + + Args: + idx: Index of the sample + + Returns: + Dictionary containing: + - input_ids: Token IDs (torch.Tensor) + - target_logits: Pre-computed logits from target + model (torch.Tensor) + - hidden_states: Pre-computed hidden states from + target model (torch.Tensor) + - loss_mask: Mask for loss computation (torch.Tensor) + """ + if self.cached_data is not None: + return self.cached_data[idx] + else: + # Try to load the checkpoint, retry with next valid index if fails + max_retries = len(self.valid_indices) + for _attempt in range(max_retries): + actual_idx = self.valid_indices[idx % len(self.valid_indices)] + data = self._load_ckpt(actual_idx) + if data is not None: + return data + else: + # Remove failed index from valid_indices + self.valid_indices.remove(actual_idx) + if len(self.valid_indices) == 0: + raise RuntimeError( + "All checkpoint files failed to load. " + "Cannot continue training." + ) + # Try next index + idx += 1 + + # If all retries failed, raise error + raise RuntimeError( + f"Failed to load any valid checkpoint after {max_retries} attempts" + ) diff --git a/angelslim/compressor/speculative/train/models/__init__.py b/angelslim/compressor/speculative/train/models/__init__.py index 3a2a5ea3..e04709d8 100644 --- a/angelslim/compressor/speculative/train/models/__init__.py +++ b/angelslim/compressor/speculative/train/models/__init__.py @@ -1,4 +1,9 @@ from .draft import DraftModelConfig, create_draft_model -from .target import create_target_model +from .target import TargetHead, create_target_model -__all__ = ["create_draft_model", "DraftModelConfig", "create_target_model"] +__all__ = [ + "create_draft_model", + "DraftModelConfig", + "create_target_model", + "TargetHead", +] diff --git a/angelslim/compressor/speculative/train/models/target/__init__.py b/angelslim/compressor/speculative/train/models/target/__init__.py index 1cbbebed..98d7bb40 100644 --- a/angelslim/compressor/speculative/train/models/target/__init__.py +++ b/angelslim/compressor/speculative/train/models/target/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .target_head import TargetHead from .target_model_wrapper import create_target_model -__all__ = ["create_target_model"] +__all__ = ["create_target_model", "TargetHead"] diff --git a/angelslim/compressor/speculative/train/models/target/target_head.py b/angelslim/compressor/speculative/train/models/target/target_head.py new file mode 100644 index 00000000..d4bd3758 --- /dev/null +++ b/angelslim/compressor/speculative/train/models/target/target_head.py @@ -0,0 +1,133 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import torch +from safetensors import safe_open +from torch import nn +from transformers import AutoConfig + +from angelslim.utils import decide_device_for_distributed + + +class TargetHead(nn.Module): + """ + Target Head for computing logits from hidden states in offline EAGLE3 training. + + This module takes the last hidden states from the target model and projects them + to vocabulary logits, which are used as training targets for the draft model. + """ + + def __init__(self, lm_head: nn.Module): + """ + Initialize the TargetHead. + + Args: + lm_head: Language model head (typically nn.Linear) that projects + hidden states to vocabulary logits. This should be loaded + from the target model. + """ + super().__init__() + self.lm_head = lm_head + + # Freeze the lm_head parameters since we only use it for inference + for param in self.lm_head.parameters(): + param.requires_grad = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Compute logits from hidden states. + + Args: + hidden_states: Hidden states from target model with shape + (batch_size, seq_length, hidden_size) + + Returns: + Logits with shape (batch_size, seq_length, vocab_size) + """ + # Project hidden states to vocabulary logits + logits = self.lm_head(hidden_states) + return logits + + @classmethod + def from_pretrained( + cls, model_name_or_path: str, lm_head_key: str = "lm_head.weight" + ): + """ + Load TargetHead from a pretrained model efficiently. + + This method only loads the lm_head weights using safetensors index, + which is more memory-efficient than loading the entire model. + + Args: + model_name_or_path: Path to pretrained model or model identifier + **kwargs: Additional arguments for model loading (e.g., torch_dtype, + trust_remote_code, device_map) + + Returns: + TargetHead instance with loaded lm_head + """ + # Load model config to get architecture info + config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + + # Get model dimensions + hidden_size = config.hidden_size + vocab_size = config.vocab_size + + # Initialize lm_head + lm_head = nn.Linear(hidden_size, vocab_size, bias=False) + + # Load lm_head weights from safetensors + try: + # Read safetensors index to locate lm_head weights + index_path = os.path.join( + model_name_or_path, "model.safetensors.index.json" + ) + + if not os.path.exists(index_path): + raise FileNotFoundError( + f"model.safetensors.index.json not found in {model_name_or_path}. " + "Please ensure the model is saved in safetensors " + "format with sharding." + ) + + # Model is sharded, use index to find lm_head + with open(index_path, "r") as f: + index_json = json.loads(f.read()) + head_path = index_json["weight_map"][lm_head_key] + + # Load lm_head weights using safetensors + safetensors_file = os.path.join(model_name_or_path, head_path) + with safe_open(safetensors_file, framework="pt", device="cpu") as f: + tensor_slice = f.get_slice(lm_head_key) + _, hidden_dim = tensor_slice.get_shape() + tensor = tensor_slice[:, :hidden_dim] + lm_head.weight.data = tensor + + except Exception as e: + raise RuntimeError( + f"Failed to load lm_head weights from {model_name_or_path}. " + f"Error: {str(e)}" + ) + + # Create TargetHead instance + target_head = cls(lm_head) + + device = decide_device_for_distributed() + target_head.to(device) + target_head.eval() + + return target_head diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index 3352ca87..72f4ee7a 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -21,17 +21,34 @@ class BaseBackend(ABC): - """Base class for model backends""" + """ + Base class for model backends. + + This abstract class defines the interface that all backend implementations + must follow to ensure consistent behavior across different model serving frameworks. + """ def __init__(self, model_path: str, **kwargs): + """ + Initialize the backend. + + Args: + model_path: Path to the model checkpoint or serving endpoint + **kwargs: Additional backend-specific configuration parameters + """ self.model_path = model_path self.kwargs = kwargs self.model = None self.tokenizer = None @abstractmethod - def load_model(self): - """Load the backend model""" + def load_model(self) -> None: + """ + Load the backend model and tokenizer. + + This method should initialize self.model and self.tokenizer. + Implementations should handle device placement and model configuration. + """ pass @abstractmethod @@ -41,88 +58,235 @@ def get_hidden_states_and_logits( attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Get hidden states and logits from model""" + """ + Extract hidden states and logits from the model. + + Args: + input_ids: Input token IDs, shape [batch_size, seq_len] + attention_mask: Attention mask, shape [batch_size, seq_len] + **kwargs: Additional model-specific arguments + + Returns: + Tuple of (hidden_states, logits): + - hidden_states: Concatenated auxiliary hidden states, + shape [batch_size, seq_len, hidden_size * num_layers] + - logits: Model output logits, shape [batch_size, seq_len, vocab_size] + """ + pass + + @abstractmethod + def get_aux_and_target_hiddens( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract auxiliary and target hidden states from the model. + + Args: + input_ids: Input token IDs, shape [batch_size, seq_len] + attention_mask: Attention mask, shape [batch_size, seq_len] + **kwargs: Additional model-specific arguments + + Returns: + Tuple of (aux_hidden_states, target_hidden_states): + - aux_hidden_states: Concatenated auxiliary hidden states + from multiple layers + - target_hidden_states: Final layer hidden states + """ pass + def _get_default_aux_layer_ids(self, total_layers: int) -> List[int]: + """ + Calculate default auxiliary hidden state layer indices. + + Selects three representative layers: early, middle, and late in the model. + + Args: + total_layers: Total number of hidden state layers (including embedding) + + Returns: + List of three layer indices [low, mid, high] + """ + return [ + 1, # Early layer + total_layers // 2 - 1, # Middle layer + total_layers - 4, # Late layer (before final layers) + ] + + def _extract_auxiliary_hidden_states( + self, + hidden_states: Tuple[torch.Tensor, ...], + aux_layer_ids: Optional[List[int]] = None, + ) -> torch.Tensor: + """ + Extract and concatenate auxiliary hidden states from specified layers. + + Args: + hidden_states: Tuple of hidden states from all layers + aux_layer_ids: List of layer indices to extract. + If None, uses default layers. + + Returns: + Concatenated hidden states, shape [batch_size, seq_len, hidden_size * 3] + """ + if aux_layer_ids is None: + aux_layer_ids = self._get_default_aux_layer_ids(len(hidden_states)) + + # Offset by 1 to skip embedding layer + embed_offset = 1 + + selected_hiddens = [ + hidden_states[layer_id + embed_offset] for layer_id in aux_layer_ids + ] + + return torch.cat(selected_hiddens, dim=-1) + class TransformersBackend(BaseBackend): - """HuggingFace Transformers backend""" + """ + HuggingFace Transformers backend implementation. - def load_model(self): + This backend uses the transformers library's AutoModelForCausalLM + for model loading and inference. + """ + + def load_model(self) -> None: + """Load model and tokenizer using HuggingFace Transformers.""" from transformers import AutoModelForCausalLM, AutoTokenizer - # Get device based on environment + # Determine device based on distributed environment device = decide_device_for_distributed() - - # Print device information with rank details print_with_rank(f"Loading model to device: {device}") - # Update kwargs with default values + # Prepare model loading configuration + model_kwargs = self._prepare_model_kwargs(device) + + # Load and configure model + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, **model_kwargs + ) + self._freeze_model_parameters() + self.model.eval() + + # Load tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, trust_remote_code=True + ) + + def _prepare_model_kwargs(self, device: str) -> dict: + """ + Prepare keyword arguments for model loading. + + Args: + device: Target device for model placement + + Returns: + Dictionary of model loading arguments + """ default_kwargs = { "dtype": torch.bfloat16, "device_map": device, "trust_remote_code": True, } default_kwargs.update(self.kwargs) + return default_kwargs - # Load model to specific device based on rank - self.model = AutoModelForCausalLM.from_pretrained( - self.model_path, **default_kwargs - ) - - # Freeze the base model + def _freeze_model_parameters(self) -> None: + """Freeze all model parameters to prevent training.""" for param in self.model.parameters(): param.requires_grad = False - self.model.eval() - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_path, trust_remote_code=True - ) - def get_hidden_states_and_logits( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract hidden states and logits using Transformers backend. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers + + Returns: + Tuple of (concatenated_hidden_states, logits) + """ with torch.no_grad(): outputs = self.model( - input_ids, attention_mask, output_hidden_states=True, output_logits=True + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + output_logits=True, ) - aux_hidden_states_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) - if aux_hidden_states_layer_ids is None: - out_hidden_nums = len(outputs.hidden_states) - aux_hidden_states_layer_ids = [ - 1, - out_hidden_nums // 2 - 1, - out_hidden_nums - 4, - ] + # Extract auxiliary hidden states + aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) + hidden_states = self._extract_auxiliary_hidden_states( + outputs.hidden_states, aux_layer_ids + ) - embed_offset = 1 - low_hidden_states = outputs.hidden_states[ - aux_hidden_states_layer_ids[0] + embed_offset - ] - mid_hidden_states = outputs.hidden_states[ - aux_hidden_states_layer_ids[1] + embed_offset - ] - high_hidden_states = outputs.hidden_states[ - aux_hidden_states_layer_ids[2] + embed_offset - ] - hidden_states = torch.cat( - [low_hidden_states, mid_hidden_states, high_hidden_states], dim=-1 + # Return hidden states and logits on the same device as input + return hidden_states, outputs.logits.to(input_ids.device) + + def get_aux_and_target_hiddens( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract auxiliary and final layer hidden states. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers + + Returns: + Tuple of (auxiliary_hidden_states, final_hidden_states) + """ + with torch.no_grad(): + outputs = self.model( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + output_logits=True, + ) + + # Extract auxiliary hidden states + aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) + aux_hidden_states = self._extract_auxiliary_hidden_states( + outputs.hidden_states, aux_layer_ids ) - target = outputs.logits - return hidden_states, target.to(input_ids.device) + # Get final layer hidden states + target_hidden_states = outputs.hidden_states[-1] + + return aux_hidden_states, target_hidden_states class TargetModelWrapper: """ - Target model wrapper for Eagle3 training. - - Supports three backends: - - hf: HuggingFace Transformers AutoModelForCausalLM + Unified wrapper for target models in Eagle3 training. + + This wrapper provides a consistent interface across + different backend implementations, allowing seamless switching + between model serving frameworks. + + Supported backends: + - hf: HuggingFace Transformers (AutoModelForCausalLM) + + Example: + >>> wrapper = TargetModelWrapper( + ... backend="hf", + ... model_path="/path/to/model", + ... dtype=torch.bfloat16 + ... ) + >>> hidden_states, logits = wrapper.get_hidden_states_and_logits(input_ids) """ BACKENDS = { @@ -131,23 +295,39 @@ class TargetModelWrapper: def __init__(self, backend: str, model_path: str, **kwargs): """ - Initialize TargetModel with specified backend + Initialize TargetModelWrapper with specified backend. Args: - backend: One of ["hf"] - model_path: Path to model - **kwargs: Additional arguments for backend initialization + backend: Backend identifier, one of ["hf"] + model_path: Path to model checkpoint or serving endpoint + **kwargs: Backend-specific configuration parameters + + Raises: + ValueError: If backend is not supported """ - if backend not in self.BACKENDS: - raise ValueError( - f"Unsupported backend: {backend}. " - f"Available backends: {list(self.BACKENDS.keys())}" - ) + self._validate_backend(backend) self.backend_name = backend self.backend = self.BACKENDS[backend](model_path, **kwargs) self.backend.load_model() + def _validate_backend(self, backend: str) -> None: + """ + Validate that the requested backend is supported. + + Args: + backend: Backend identifier to validate + + Raises: + ValueError: If backend is not in BACKENDS + """ + if backend not in self.BACKENDS: + available = ", ".join(self.BACKENDS.keys()) + raise ValueError( + f"Unsupported backend: '{backend}'. " + f"Available backends: [{available}]" + ) + def get_hidden_states_and_logits( self, input_ids: torch.Tensor, @@ -155,16 +335,18 @@ def get_hidden_states_and_logits( **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Get hidden states and logits from target model + Get hidden states and logits from target model. Args: - input_ids: Input token ids, shape [batch_size, seq_len] + input_ids: Input token IDs, shape [batch_size, seq_len] attention_mask: Attention mask, shape [batch_size, seq_len] + **kwargs: Additional backend-specific arguments Returns: - Tuple of (hidden_states, logits) - - hidden_states: shape [batch_size, seq_len, hidden_size] - - logits: shape [batch_size, seq_len, vocab_size] + Tuple of (hidden_states, logits): + - hidden_states: Concatenated auxiliary hidden states, + shape [batch_size, seq_len, hidden_size * num_aux_layers] + - logits: Model output logits, shape [batch_size, seq_len, vocab_size] """ return self.backend.get_hidden_states_and_logits( input_ids=input_ids, @@ -172,20 +354,59 @@ def get_hidden_states_and_logits( **kwargs, ) + def get_aux_and_target_hiddens( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get auxiliary and target hidden states from model. + + Args: + input_ids: Input token IDs, shape [batch_size, seq_len] + attention_mask: Attention mask, shape [batch_size, seq_len] + **kwargs: Additional backend-specific arguments + + Returns: + Tuple of (aux_hidden_states, target_hidden_states) + """ + return self.backend.get_aux_and_target_hiddens( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + @property def model(self): - """Access underlying model""" + """ + Access the underlying model instance. + + Returns: + The backend's model object + """ return self.backend.model @property def tokenizer(self): - """Access underlying tokenizer""" + """ + Access the underlying tokenizer instance. + + Returns: + The backend's tokenizer object + + Raises: + AttributeError: If backend doesn't support tokenizers + ValueError: If tokenizer is not initialized + """ if not hasattr(self.backend, "tokenizer"): raise AttributeError( - f"Backend '{self.backend_name}' does not have a tokenizer attribute" + f"Backend '{self.backend_name}' does not support tokenizers" ) if self.backend.tokenizer is None: - raise ValueError(f"Backend '{self.backend_name}' does not have a tokenizer") + raise ValueError( + f"Tokenizer not initialized for backend '{self.backend_name}'" + ) return self.backend.tokenizer @@ -199,28 +420,42 @@ def create_target_model( """ Factory function to create target model with appropriate backend configuration. + This function provides a convenient way to instantiate a TargetModelWrapper + with commonly used default settings. + Args: backend: Backend type, one of ["hf"] - model_path: Path to model or serving endpoint URL + model_path: Path to model checkpoint or serving endpoint URL torch_dtype: Data type for model weights (for HF backend) - trust_remote_code: Whether to trust remote code - tokenizer_path: Path to tokenizer + trust_remote_code: Whether to trust and execute remote code **extra_kwargs: Additional backend-specific arguments Returns: - TargetModelWrapper instance + Configured TargetModelWrapper instance + + Raises: + ValueError: If backend is not supported + + Example: + >>> model = create_target_model( + ... backend="hf", + ... model_path="/path/to/llama-7b", + ... torch_dtype=torch.float16 + ... ) """ - # Prepare common kwargs - kwargs = {"trust_remote_code": trust_remote_code, **extra_kwargs} + # Prepare common configuration + kwargs = { + "trust_remote_code": trust_remote_code, + **extra_kwargs, + } - # Add backend-specific kwargs + # Add backend-specific configuration if backend == "hf": - kwargs.update( - { - "dtype": torch_dtype, - } - ) + kwargs["dtype"] = torch_dtype else: - raise ValueError(f"Unsupported backend: {backend}") + raise ValueError( + f"Unsupported backend: '{backend}'. " + f"Use one of: {list(TargetModelWrapper.BACKENDS.keys())}" + ) return TargetModelWrapper(backend=backend, model_path=model_path, **kwargs) diff --git a/angelslim/compressor/speculative/train/trainer/__init__.py b/angelslim/compressor/speculative/train/trainer/__init__.py index 092a7d36..2417866a 100644 --- a/angelslim/compressor/speculative/train/trainer/__init__.py +++ b/angelslim/compressor/speculative/train/trainer/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .eagle3_trainer import OnlineEagle3Trainer +from .eagle3_trainer import OfflineEagle3Trainer, OnlineEagle3Trainer -__all__ = ["OnlineEagle3Trainer"] +__all__ = ["OnlineEagle3Trainer", "OfflineEagle3Trainer"] diff --git a/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py b/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py index 43890b92..f3cc00c8 100644 --- a/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/eagle3_trainer.py @@ -72,12 +72,12 @@ def compute_loss( """ data_for_draft_model = self.prepare_data_for_draft_model(inputs) - attention_mask = data_for_draft_model["attention_mask"] - position_ids = data_for_draft_model["position_ids"] - input_ids = data_for_draft_model["input_ids"] - target_logits = data_for_draft_model["target_logits"] - loss_mask = data_for_draft_model["loss_mask"] - hidden_states = data_for_draft_model["hidden_states"] + attention_mask = data_for_draft_model["attention_mask"] # Batch x Seq + position_ids = data_for_draft_model["position_ids"] # Batch x Seq + input_ids = data_for_draft_model["input_ids"] # Batch x Seq + target_logits = data_for_draft_model["target_logits"] # Batch x Seq x Vocab + loss_mask = data_for_draft_model["loss_mask"] # Batch x Seq x 1 + hidden_states = data_for_draft_model["hidden_states"] # Batch x Seq x Hidden hidden_states = self.down_project_hidden_states(hidden_states) attention_mask, position_ids = self.prepare_attention_mask_and_position_ids( @@ -353,3 +353,80 @@ def prepare_data_for_draft_model(self, inputs): "position_ids": position_ids, "attention_mask": attention_mask, } + + +class OfflineEagle3Trainer(Eagle3Trainer): + """ + Offline EAGLE3 Trainer for speculative decoding training. + + Uses pre-computed hidden states and logits from offline processing, + avoiding the need for online target model inference. + """ + + def __init__( + self, draft_model: nn.Module, target_head: nn.Module, length: int, **kwargs + ): + """ + Initialize the OnlineEagle3Trainer. + + Args: + draft_model: Draft model for token prediction + length: Number of speculative decoding steps + **kwargs: Additional arguments passed to parent Trainer + """ + super().__init__(draft_model=draft_model, length=length, **kwargs) + self.target_head = target_head + + def prepare_data_for_draft_model( + self, inputs: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Prepare data for draft model training from offline-generated inputs. + + Args: + inputs: Dictionary containing: + - input_ids: Token IDs + - target_hiddens: Pre-computed last hidden states from target model + - hidden_states: Pre-computed aux hidden states from target model + - attention_mask: Attention mask + - loss_mask: Mask for loss computation + - position_ids (optional): Position IDs + + Returns: + Dictionary with prepared data for draft model training + """ + # + inputs_fields = [ + "input_ids", + "target_hiddens", + "hidden_states", + "attention_mask", + "loss_mask", + ] + output_fields = [ + "input_ids", + "target_logits", + "hidden_states", + "attention_mask", + "loss_mask", + "position_ids", + ] + + # + target_logits = self.target_head(inputs["target_hiddens"]) + position_ids = inputs.get("position_ids", None) + loss_mask = inputs["loss_mask"] + input_ids = inputs["input_ids"] + + # Apply right padding and move tensors to correct device + target_logits = padding(target_logits, left=False).to(input_ids.device) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None].to(input_ids.device) + + outputs = {k: inputs[k] for k in inputs_fields if k in output_fields} + outputs["target_logits"] = target_logits + outputs["position_ids"] = position_ids + outputs["loss_mask"] = loss_mask + outputs["input_ids"] = input_ids + + return outputs diff --git a/docs/source/features/speculative_decoding/eagle.md b/docs/source/features/speculative_decoding/eagle.md index 272165c3..9cc2d3c6 100644 --- a/docs/source/features/speculative_decoding/eagle.md +++ b/docs/source/features/speculative_decoding/eagle.md @@ -1,48 +1,212 @@ # EAGLE -本项目将集成包括但不限于[Eagle](https://arxiv.org/pdf/2503.01840)系列的投机采样算法。 -我们计划将对投机采样算法的代码以及部分开源大模型的Eagle3权重开源。 -作为第一批开源内容,我们将提供Qwen3系列的[Eagle3权重](https://huggingface.co/collections/AngelSlim/qwen3-eagle-686787e3258f84fb09019f32)。 -后续,更多的代码和其他大模型的权重也将陆续开源,敬请关注。 +[Eagle3](https://arxiv.org/pdf/2503.01840)是目前最常用、加速效果最好的投机采样算法。 +本项目包括Eagle3的训练以及benchmark测试,并开源了Qwen3和Hunyuan系列的[Eagle3权重](https://huggingface.co/collections/AngelSlim/eagle3)。 + 我们训练的Qwen3系列Eagle3模型的表现可以参见基准测试[benchmarks](../../performance/speculative_decoding/benchmarks.md), 其中全部数据都是在单张H20上使用pytorch推理获得。 -## 快速测试 - -### SGLang -目前sglang已经支持Qwen3-8B/14B/30B-A3B模型的eagle3部署,你可以选择使用sglang作为推理后端快速验证Eagle3模型的加速效果。 -在已经安装sglang的环境中使用以下命令可以快速启动一个兼容Openai的服务,然后即可以通过本地端口进行请求了。 -- 启动兼容OpenAI格式的API服务 - - 以下指令将启动兼容OpenAI API格式的服务,默认在 http://0.0.0.0:8080 地址进行访问: - - ```shell - python3 -m sglang.launch_server \ - --model ${TARGET_MODEL_PATH_OR_NAME} \ - --speculative-algorithm EAGLE3 \ - --speculative-draft-model-path ${EAGLE3_MODEL_PATH} \ - --speculative-num-steps 6 \ - --speculative-eagle-topk 10 \ - --speculative-num-draft-tokens 32 \ - --mem-fraction 0.9 \ - --dtype bfloat16 - ``` - 其中: - - `TARGET_MODEL_PATH_OR_NAME`为本地路径或模型在huggingface上的名字; - - `EAGLE3_MODEL_PATH`为Eagle3模型路径或在huggingface上的名字; - - -### vLLM -目前vllm已经支持Hunyuan-1.8B-Instruct/4B-Instruct/7B-Instruct模型的eagle3部署,你可以选择使用vllm作为推理后端快速验证Eagle3模型的加速效果。 -在已经安装正确的[vllm commit](https://github.com/vllm-project/vllm/pull/22080) 的环境中使用以下命令可以快速启动一个兼容Openai的服务,然后即可以通过本地端口进行请求了。 -- 启动兼容OpenAI格式的API服务 - - ```shell - python3 -m vllm.entrypoints.openai.api_server --tensor-parallel-size 1 \ - --port 8000 \ - --speculative_config '{"model": "AngelSlim/Hunyuan-1.8B-Instruct_eagle3", "method" : "eagle3", "draft_tensor_parallel_size" : 1, "num_speculative_tokens": 2}' --trust-remote-code \ - --model tencent/Hunyuan-1.8B-Instruct - ``` -但是由于vllm最新版本Eagle3并不支持tree attention, 因此推理验证时为chain-base推理模式。 - -## 训练及创新 -Comming soon. \ No newline at end of file +## 1. 数据生成 + +数据生成包括:1)为目标模型生成采样数据,2)为Eagle3模型离线生成目标模型的hidden states。 + +### 1.1 为目标模型生成采样数据 + +生成采样数据为可选项,当有足够数量以及足够质量的目标模型SFT数据时,此步可略过。当训练数据和目标模型不配套时,则需要为目标模型重新采样生成数据。 + +**步骤1:启动vLLM server** + +首先需要启动vLLM server来提供模型推理服务: + +```shell +bash scripts/speculative/run_vllm_server.sh +``` + +**server配置说明:** +- 该脚本会启动目标基础模型的vLLM推理服务 +- 确保服务器成功启动后再进行下一步数据生成 +- 可以通过修改脚本中的参数来调整vLLM server配置(如vLLM启动参数、GPU数量等),来适应不同的目标模型 + +**步骤2:生成采样数据** + +vLLM server启动后,使用 `scripts/speculative/generate_data_for_target_model.sh` 脚本生成训练数据: + +```shell +bash scripts/speculative/generate_data_for_target_model.sh +``` + +**脚本功能说明:** +- 通过vLLM server调用目标基础模型对输入数据进行采样 +- 生成 `.jsonl` 格式的训练数据集 +- 数据将用于后续Eagle模型的在线训练 + +**脚本参数说明:** + +在使用前,需要在脚本中配置以下参数: + +- `DATA_NAME_OR_PATH`: 输入数据集的HF名称或本地路径 +- `OUTPUT_DIR`: 生成的数据集输出路径 +- `DATA_FORMAT`: 输入数据集的格式(sharegpt|ultrachat) +- `DATA_SHARD_SIZE`: 生成数据集的切分子集大小 +- `BASE_PORT`: vLLM server的端口号 + +**注意事项:** +- 确保vLLM服务器已成功启动并正常运行 +- 数据生成过程可能需要较长时间,取决于样本数量和模型规模 + + +### 1.2 为Eagle3模型生成hidden states + +目前仅支持以HF为后端生成hidden states,调用脚本如下: +```shell +bash scripts/speculative/generate_hidden_for_draft_model.sh +``` + +**脚本参数说明:** + +在使用前,需要在脚本中配置以下参数: + +- `DATASET_PATH`: 输入数据集的HF名称或本地路径 +- `MODEL_NAME`: 目标模型的HF名称或本地路径 +- `TARGET_BACKEND`: 目标模型后端,目前仅支持HF +- `MODEL_MAX_LENGTH`: 生成数据的上下文长度 +- `CHAT_TEMPLATE_TYPE`: 目标模型的目标类型,目前支持qwen3/hunyuan +- `OUTPUT_DIR`: 生成的数据集输出路径 + + +## 2. 训练Eagle3模型 + +目前支持在线训练和离线训练两种模式:在线训练适合显存足够、目标模型不大、训练上下文长度不要求极长的场景, +离线训练适合大尺寸目标模型、磁盘空间足够、长上下文训练场景。 + +### 2.1 在线训练 + +使用 `scripts/speculative/train_eagle3_online.sh` 脚本进行Eagle3模型的在线训练: + +```shell +bash scripts/speculative/train_eagle3_online.sh +``` + +**脚本参数说明:** + +在使用前,需要在脚本中配置以下参数: + +- `TARGET_MODEL_NAME_OR_PATH`: 目标模型的HF名称或本地名称 +- `DRAFT_MODEL_CONFIG_PATH`: 草稿模型的config路径 +- `TRAIN_DATA_PATH`: 训练数据路径 +- `EVAL_DATA_PATH`: 验证数据路径 +- `OUTPUT_DIR`: Eagle3模型输出路径 +- `MODEL_MAX_LENGTH`: 训练数据的最大长度 +- `CHAT_TEMPLATE_TYPE`: 目标模型的数据模板类型 + +### 2.2 离线训练 + +在离线训练前,必须要完成`1.2` 为Eagle3模型生成hidden states。 +使用 `scripts/speculative/train_eagle3_offline.sh` 脚本进行Eagle3模型的离线训练: + +```shell +bash scripts/speculative/train_eagle3_offline.sh +``` + +**脚本参数说明:** + +在使用前,需要在脚本中配置以下参数: + +- `TARGET_MODEL_NAME_OR_PATH`: 目标模型的HF名称或本地名称 +- `DRAFT_MODEL_CONFIG_PATH`: 草稿模型的config路径 +- `TRAIN_DATA_PATH`: 训练数据路径,.jsonl格式 +- `TRAIN_HIDDEN_PATH`: 训练hidden states数据路径 +- `EVAL_HIDDEN_PATH`: 验证hidden states数据路径 +- `OUTPUT_DIR`: Eagle3模型输出路径 +- `MODEL_MAX_LENGTH`: 训练数据的最大长度 +- `CHAT_TEMPLATE_TYPE`: 目标模型的数据模板类型 +- `LM_HEAD_KEY`: 目标模型lm head的weight key名称,可以在model.safetensors.index.json中查看,默认为lm_head.weight时可不指定这个参数。当为model.embed_tokens.weight时,需要指定。 +- `RUN_NAME`: 当`report_to`设为wand时,可以指定该参数设置wand中的run name。 + + +## 3. 基准测试 + +AngelSlim提供了完整的Eagle3基准测试工具,用于评估投机采样的性能提升。 + +### 3.1 基本用法 + +使用 `tools/spec_benchmark.py` 脚本进行投机采样基准测试: + +```shell +python3 tools/spec_benchmark.py \ + --base-model-path ${BASE_MODEL_PATH} \ + --eagle-model-path ${EAGLE_MODEL_PATH} \ + --model-id ${MODEL_ID} \ + --mode both +``` + +### 3.2 参数说明 + +**模型配置参数:** +- `--base-model-path`: 基础模型路径(必需) +- `--eagle-model-path`: Eagle辅助模型路径(必需) +- `--model-id`: 模型标识符(必需) + +**基准测试配置:** +- `--bench-name`: 基准数据集名称,默认为 `mt_bench`, 可选【`alpaca`,`gsm8k`,`humaneval`,`mt_bench`】 +- `--mode`: 执行模式,可选 `eagle`(仅投机采样)、`baseline`(仅基线)、`both`(两者都执行),默认为 `both` +- `--output-dir`: 结果输出目录 + +**生成参数:** +- `--temperature`: 采样温度,默认为 1.0 +- `--max-new-token`: 最大生成token数,默认为 1024 +- `--total-token`: 草稿树中的总节点数,默认为 60 +- `--depth`: 树深度,默认为 5 +- `--top-k`: Top-k采样,默认为 10 + +**硬件配置:** +- `--num-gpus-per-model`: 每个模型使用的GPU数量,默认为 1 +- `--num-gpus-total`: 总GPU数量,默认为 1 +- `--max-gpu-memory`: 每个GPU的最大内存限制 + +**其他设置:** +- `--seed`: 随机种子,默认为 42 +- `--question-begin`: 问题起始索引(用于调试) +- `--question-end`: 问题结束索引(用于调试) +- `--no-metrics`: 跳过自动指标计算 + +### 3.3 使用示例 + +**完整基准测试(推荐):** +```shell +python3 tools/spec_benchmark.py \ + --base-model-path /path/to/base/model \ + --eagle-model-path /path/to/eagle/model \ + --model-id qwen3-8b \ + --mode both \ + --output-dir ./results \ + --max-new-token 512 \ + --temperature 0.0 +``` + +**仅运行投机采样:** +```shell +python3 tools/spec_benchmark.py \ + --base-model-path /path/to/base/model \ + --eagle-model-path /path/to/eagle/model \ + --model-id qwen3-8b \ + --mode eagle +``` + +**多GPU配置:** +```shell +python3 tools/spec_benchmark.py \ + --base-model-path /path/to/base/model \ + --eagle-model-path /path/to/eagle/model \ + --model-id qwen3-8b \ + --num-gpus-per-model 1 \ + --num-gpus-total 8 +``` + +### 3.4 性能报告 + +运行完成后,工具会自动生成性能报告,包括: +- 投机采样与基线模型的性能对比 +- 加速比统计 +- 生成质量指标(如果启用) + +结果将保存在指定的输出目录中,便于后续分析和比较。 \ No newline at end of file diff --git a/docs/source/getting_started/quickstrat.md b/docs/source/getting_started/quickstrat.md index 532ef46d..0fa74bf9 100644 --- a/docs/source/getting_started/quickstrat.md +++ b/docs/source/getting_started/quickstrat.md @@ -47,163 +47,10 @@ python3 tools/fp8_quant_blockwise.py \ ### 投机采样 -投机采样(Speculative Decoding)是一种加速大语言模型推理的技术,通过使用较小的辅助模型来预测后续token,然后由主模型进行验证,从而提高生成效率。AngelSlim提供了完整的Eagle3训练和基准测试工具。 +投机采样(Speculative Decoding)是一种加速大语言模型推理的技术,通过草稿模型预测多个token,然后由主模型一次性进行验证,从而提高生成效率。 +Eagle3是目前使用最多、加速效果最好的投机采样方法,AngelSlim提供了完整的Eagle3训练和基准测试工具。 -#### 1. 训练Eagle3模型 - -Eagle3模型的训练分为两个步骤:数据准备和在线训练。 - -##### 1.1 准备训练数据 - -训练数据的准备分为两个步骤:启动vLLM server和生成采样数据。 - -**步骤1:启动vLLM server** - -首先需要启动vLLM server来提供模型推理服务: - -```shell -bash scripts/speculative/run_vllm_server.sh -``` - -**server配置说明:** -- 该脚本会启动目标基础模型的vLLM推理服务 -- 确保服务器成功启动后再进行下一步数据生成 -- 可以通过修改脚本中的参数来调整vLLM server配置(如vLLM启动参数、GPU数量等),来适应不同的目标模型 - -**步骤2:生成采样数据** - -vLLM server启动后,使用 `scripts/speculative/generate_data_for_target_model.sh` 脚本生成训练数据: - -```shell -bash scripts/speculative/generate_data_for_target_model.sh -``` - -**脚本功能说明:** -- 通过vLLM server调用目标基础模型对输入数据进行采样 -- 生成 `.jsonl` 格式的训练数据集 -- 数据将用于后续Eagle模型的在线训练 - -**脚本参数说明:** - -在使用前,需要在脚本中配置以下参数: - -- `DATA_NAME_OR_PATH`: 输入数据集的HF名称或本地路径 -- `OUTPUT_DIR`: 生成的数据集输出路径 -- `DATA_FORMAT`: 输入数据集的格式(sharegpt|ultrachat) -- `DATA_SHARD_SIZE`: 生成数据集的切分子集大小 -- `BASE_PORT`: vLLM server的端口号 - -**注意事项:** -- 确保vLLM服务器已成功启动并正常运行 -- 数据生成过程可能需要较长时间,取决于样本数量和模型规模 - -##### 1.2 在线训练 - -使用 `scripts/speculative/train_eagle3_online.sh` 脚本进行Eagle3模型的在线训练: - -```shell -bash scripts/speculative/train_eagle3_online.sh -``` - -**脚本参数说明:** - -在使用前,需要在脚本中配置以下参数: - -- `TARGET_MODEL_NAME_OR_PATH`: 目标模型的HF名称或本地名称 -- `DRAFT_MODEL_CONFIG_PATH`: 草稿模型的config路径 -- `TRAIN_DATA_PATH`: 训练数据路径 -- `EVAL_DATA_PATH`: 验证数据路径 -- `OUTPUT_DIR`: Eagle3模型输出路径 -- `MODEL_MAX_LENGTH`: 训练数据的最大长度 -- `CHAT_TEMPLATE_TYPE`: 目标模型的数据模板类型 - -#### 2. 基准测试 - -AngelSlim提供了完整的Eagle3基准测试工具,用于评估投机采样的性能提升。 - -##### 2.1 基本用法 - -使用 `tools/spec_benchmark.py` 脚本进行投机采样基准测试: - -```shell -python3 tools/spec_benchmark.py \ - --base-model-path ${BASE_MODEL_PATH} \ - --eagle-model-path ${EAGLE_MODEL_PATH} \ - --model-id ${MODEL_ID} \ - --mode both -``` - -##### 2.2 参数说明 - -**模型配置参数:** -- `--base-model-path`: 基础模型路径(必需) -- `--eagle-model-path`: Eagle辅助模型路径(必需) -- `--model-id`: 模型标识符(必需) - -**基准测试配置:** -- `--bench-name`: 基准数据集名称,默认为 `mt_bench`, 可选【`alpaca`,`gsm8k`,`humaneval`,`mt_bench`】 -- `--mode`: 执行模式,可选 `eagle`(仅投机采样)、`baseline`(仅基线)、`both`(两者都执行),默认为 `both` -- `--output-dir`: 结果输出目录 - -**生成参数:** -- `--temperature`: 采样温度,默认为 1.0 -- `--max-new-token`: 最大生成token数,默认为 1024 -- `--total-token`: 草稿树中的总节点数,默认为 60 -- `--depth`: 树深度,默认为 5 -- `--top-k`: Top-k采样,默认为 10 - -**硬件配置:** -- `--num-gpus-per-model`: 每个模型使用的GPU数量,默认为 1 -- `--num-gpus-total`: 总GPU数量,默认为 1 -- `--max-gpu-memory`: 每个GPU的最大内存限制 - -**其他设置:** -- `--seed`: 随机种子,默认为 42 -- `--question-begin`: 问题起始索引(用于调试) -- `--question-end`: 问题结束索引(用于调试) -- `--no-metrics`: 跳过自动指标计算 - -##### 2.3 使用示例 - -**完整基准测试(推荐):** -```shell -python3 tools/spec_benchmark.py \ - --base-model-path /path/to/base/model \ - --eagle-model-path /path/to/eagle/model \ - --model-id qwen3-8b \ - --mode both \ - --output-dir ./results \ - --max-new-token 512 \ - --temperature 0.0 -``` - -**仅运行投机采样:** -```shell -python3 tools/spec_benchmark.py \ - --base-model-path /path/to/base/model \ - --eagle-model-path /path/to/eagle/model \ - --model-id qwen3-8b \ - --mode eagle -``` - -**多GPU配置:** -```shell -python3 tools/spec_benchmark.py \ - --base-model-path /path/to/base/model \ - --eagle-model-path /path/to/eagle/model \ - --model-id qwen3-8b \ - --num-gpus-per-model 1 \ - --num-gpus-total 8 -``` - -##### 2.4 性能报告 - -运行完成后,工具会自动生成性能报告,包括: -- 投机采样与基线模型的性能对比 -- 加速比统计 -- 生成质量指标(如果启用) - -结果将保存在指定的输出目录中,便于后续分析和比较。 +具体使用请参考[使用文档](../features/speculative_decoding/eagle.md)。 ## 部署 diff --git a/scripts/speculative/generate_hidden_for_draft_model.sh b/scripts/speculative/generate_hidden_for_draft_model.sh new file mode 100644 index 00000000..3a12fa73 --- /dev/null +++ b/scripts/speculative/generate_hidden_for_draft_model.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +DATASET_PATH= +MODEL_NAME= +TARGET_BACKEND=hf +MODEL_MAX_LENGTH=2048 +CHAT_TEMPLATE_TYPE=qwen3 +OUTPUT_DIR= + +torchrun --nproc_per_node=8 \ + tools/generate_hidden_for_draft_model.py \ + --dataset_path $DATASET_PATH \ + --model_name $MODEL_NAME \ + --target_backend $TARGET_BACKEND \ + --torch_dtype bfloat16 \ + --model_max_length $MODEL_MAX_LENGTH \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --outdir $OUTPUT_DIR diff --git a/scripts/speculative/train_eagle3_offline.sh b/scripts/speculative/train_eagle3_offline.sh new file mode 100644 index 00000000..305c7332 --- /dev/null +++ b/scripts/speculative/train_eagle3_offline.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +export CONFIG_DIR=angelslim/compressor/speculative/train/configs +export TARGET_MODEL_NAME_OR_PATH= +export DRAFT_MODEL_CONFIG_PATH= +export TRAIN_DATA_PATH= +export TRAIN_HIDDEN_PATH= +export EVAL_HIDDEN_PATH= +export OUTPUT_DIR= +export RUN_NAME= +export MODEL_MAX_LENGTH=4096 +export LM_HEAD_KEY= +export CHAT_TEMPLATE_TYPE=qwen3 + +torchrun --nproc_per_node=8 tools/train_eagle3_offline.py \ + --target_model_name_or_path $TARGET_MODEL_NAME_OR_PATH \ + --draft_model_config_path $DRAFT_MODEL_CONFIG_PATH \ + --train_data_path $TRAIN_DATA_PATH \ + --train_hidden_path $TRAIN_HIDDEN_PATH \ + --eval_hidden_path $EVAL_HIDDEN_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 10 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --save_strategy "steps" \ + --save_steps 0.01 \ + --learning_rate 1e-4 \ + --weight_decay 0.0 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "constant" \ + --logging_steps 100 \ + --model_max_length $MODEL_MAX_LENGTH \ + --chat_template_type $CHAT_TEMPLATE_TYPE \ + --lm_head_key $LM_HEAD_KEY \ + --deepspeed $CONFIG_DIR/deepspeed_zero3.json \ + --report_to wandb \ + --run_name $RUN_NAME \ + --num_proc 48 \ + --bf16 \ No newline at end of file diff --git a/tools/generate_data_for_target_model.py b/tools/generate_data_for_target_model.py index 8c1a1e2a..9ad7af32 100644 --- a/tools/generate_data_for_target_model.py +++ b/tools/generate_data_for_target_model.py @@ -1,6 +1,20 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse -from angelslim.compressor.speculative.train.data import data_generation_work_flow +from angelslim.compressor.speculative import data_generation_work_flow def parse_arguments() -> argparse.Namespace: diff --git a/tools/generate_hidden_for_draft_model.py b/tools/generate_hidden_for_draft_model.py new file mode 100644 index 00000000..9daa0470 --- /dev/null +++ b/tools/generate_hidden_for_draft_model.py @@ -0,0 +1,442 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from pathlib import Path +from typing import Any, Dict, Tuple + +import torch +import torch.distributed as dist +from tqdm import tqdm + +from angelslim.compressor.speculative import DatasetManager, create_target_model +from angelslim.utils import decide_device_for_distributed + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - [Rank %(rank)s] - %(message)s", +) +logger = logging.getLogger(__name__) + + +def setup_distributed(): + """ + Setup distributed training environment. + + Returns: + Tuple of (rank, world_size, local_rank) or (0, 1, 0) if not distributed + """ + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + + # Initialize process group + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + + return rank, world_size, local_rank + else: + # Single process mode + return 0, 1, 0 + + +def cleanup_distributed(): + """Cleanup distributed training environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + +class HiddenStateGenerator: + """Generator for creating hidden states from target model.""" + + def __init__( + self, target_model, output_dir: str, group_size: int = 5000, rank: int = 0 + ): + """ + Initialize the hidden state generator. + + Args: + target_model: The target model for generating hidden states + output_dir: Directory to save generated hidden states + group_size: Number of samples per subdirectory group + rank: Process rank for distributed training + """ + self.target_model = target_model + self.output_dir = Path(output_dir) + self.group_size = group_size + self.rank = rank + self.output_dir.mkdir(parents=True, exist_ok=True) + + def _get_output_path(self, idx: int) -> Path: + """ + Get the output file path for a given sample index. + + Args: + idx: Sample index + + Returns: + Path object for the output file + """ + start = (idx // self.group_size) * self.group_size + end = start + self.group_size + grouped_subdir = f"rows_{start}-{end}" + grouped_path = self.output_dir / grouped_subdir + grouped_path.mkdir(parents=True, exist_ok=True) + + return grouped_path / f"data_{idx}.ckpt" + + def _process_single_sample(self, idx: int, row: Dict[str, Any]) -> bool: + """ + Process a single sample and save its hidden states. + + Args: + idx: Sample index + row: Sample data containing input_ids and loss_mask + + Returns: + True if processing succeeded, False otherwise + """ + output_file = self._get_output_path(idx) + + # Skip if file already exists + if output_file.exists(): + logger.debug( + f"Skipping existing file: {output_file}", extra={"rank": self.rank} + ) + return True + + try: + # Generate aux and target hiddens + device = decide_device_for_distributed() + aux_hiddens, target_hiddens = self.target_model.get_aux_and_target_hiddens( + input_ids=row["input_ids"].to(device), + ) + + # Prepare data point + data_point = { + "input_ids": row["input_ids"].cpu(), # B, N + "loss_mask": row["loss_mask"].cpu(), # B, N + "hidden_states": aux_hiddens.cpu(), # B, N, 3*D + "target_hiddens": target_hiddens.cpu(), # B, N, D + } + + # Save to disk + torch.save(data_point, output_file) + return True + + except Exception as e: + logger.error( + f"Error processing sample {idx}: {str(e)}", extra={"rank": self.rank} + ) + return False + + def generate(self, dataset) -> Tuple[int, int]: + """ + Generate hidden states for all samples in the dataset. + + Args: + dataset: Dataset to process + + Returns: + Tuple of (successful_count, failed_count) + """ + successful = 0 + failed = 0 + + # Only show progress bar on rank 0 + iterator = ( + tqdm( + enumerate(dataset), + total=len(dataset), + desc=f"Rank {self.rank} processing", + ) + if self.rank == 0 + else enumerate(dataset) + ) + + for idx, row in iterator: + if self._process_single_sample(idx, row): + successful += 1 + else: + failed += 1 + + logger.info( + f"Processing complete. Success: {successful}, Failed: {failed}", + extra={"rank": self.rank}, + ) + logger.info(f"Results saved to {self.output_dir}", extra={"rank": self.rank}) + + return successful, failed + + +def parse_arguments() -> argparse.Namespace: + """ + Parse command line arguments. + + Returns: + Parsed arguments namespace + """ + parser = argparse.ArgumentParser( + description="Generate hidden states for draft model training", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Dataset range arguments + parser.add_argument( + "--start", + type=int, + default=0, + help="Global start index of dataset (applies before distribution to GPUs)", + ) + parser.add_argument( + "--end", + type=int, + default=None, + help="Global end index of dataset (None means use full dataset). " + "The range [start, end) will be automatically distributed across all GPUs.", + ) + + # Output configuration + parser.add_argument( + "--outdir", + type=str, + default="outdir0", + help="Output directory for generated hidden states", + ) + + # Model configuration + parser.add_argument( + "--model_name", type=str, default="Qwen/Qwen3-4B", help="Model name or path" + ) + parser.add_argument( + "--target_model_name_or_path", + type=str, + help="Target model name or path (if different from model_name)", + ) + parser.add_argument( + "--target_backend", + type=str, + default="hf", + choices=["hf"], + help="Backend for target model", + ) + parser.add_argument( + "--torch_dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + help="Torch dtype for model", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Trust remote code when loading model", + ) + + # Dataset configuration + parser.add_argument( + "--dataset_path", type=str, nargs="+", required=True, help="Dataset to use" + ) + parser.add_argument( + "--model_max_length", type=int, default=2048, help="Maximum token length" + ) + parser.add_argument( + "--chat_template_type", type=str, default="default", help="Chat template type" + ) + parser.add_argument( + "--display", + action="store_true", + help="Display dataset samples (only on rank 0)", + ) + parser.add_argument( + "--num_proc", + type=int, + default=16, + help="Number of processes for data preprocessing", + ) + parser.add_argument( + "--shuffle_seed", type=int, default=42, help="Random seed for shuffling dataset" + ) + + return parser.parse_args() + + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """ + Convert string dtype to torch dtype. + + Args: + dtype_str: String representation of dtype + + Returns: + Corresponding torch dtype + """ + dtype_mapping = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + return dtype_mapping.get(dtype_str, torch.bfloat16) + + +def load_dataset(args: argparse.Namespace, tokenizer, rank: int): + """ + Load and prepare dataset. + + Args: + args: Parsed command line arguments + tokenizer: Tokenizer from target model + rank: Process rank + + Returns: + Prepared dataset + """ + logger.info(f"Loading dataset: {args.dataset_path}", extra={"rank": rank}) + + # Only display on rank 0 + display = args.display and rank == 0 + + args.train_data_path = None + args.eval_data_path = args.dataset_path + dataset_manager = DatasetManager( + data_args=args, + tokenizer=tokenizer, + model_max_length=args.model_max_length, + chat_template_type=args.chat_template_type, + display=display, + ) + + _, dataset = dataset_manager.create_online_datasets() + logger.info(f"Dataset loaded: {len(dataset)} samples", extra={"rank": rank}) + + return dataset + + +def split_dataset_for_rank( + dataset, rank: int, world_size: int, start: int = 0, end: int = None +): + """ + Split dataset for distributed processing. + + The dataset is first sliced to [start:end] range (global range), + then evenly distributed across all ranks. + + Args: + dataset: Full dataset + rank: Current process rank (0 to world_size-1) + world_size: Total number of processes + start: Global start index (default: 0) + end: Global end index (default: None, means len(dataset)) + + Returns: + Dataset slice for current rank + + Example: + Dataset has 10000 samples, world_size=4, start=1000, end=5000 + - Global range: [1000, 5000) = 4000 samples + - Rank 0: [1000, 2000) = 1000 samples + - Rank 1: [2000, 3000) = 1000 samples + - Rank 2: [3000, 4000) = 1000 samples + - Rank 3: [4000, 5000) = 1000 samples + """ + # Determine the global range to process + if end is None: + end = len(dataset) + + # Validate range + if start < 0 or end > len(dataset) or start >= end: + raise ValueError( + f"Invalid range: start={start}, end={end}, dataset_size={len(dataset)}" + ) + + total_samples = end - start + samples_per_rank = total_samples // world_size + remainder = total_samples % world_size + + # Calculate start and end for this rank + rank_start = start + rank * samples_per_rank + min(rank, remainder) + rank_end = rank_start + samples_per_rank + (1 if rank < remainder else 0) + + logger.info( + f"Rank {rank}/{world_size}: Processing global range [{start}, {end}) -> " + f"assigned range [{rank_start}, {rank_end}) ({rank_end - rank_start} samples)", + extra={"rank": rank}, + ) + + return dataset.select(range(rank_start, rank_end)) + + +def main(): + """Main execution function.""" + # Setup distributed environment + rank, world_size, local_rank = setup_distributed() + + # Parse arguments + args = parse_arguments() + args.train_data_path = None + args.eval_data_path = args.dataset_path + + try: + # Load target model + torch_dtype = get_torch_dtype(args.torch_dtype) + target_model = create_target_model( + backend=args.target_backend, + model_path=args.target_model_name_or_path or args.model_name, + torch_dtype=torch_dtype, + trust_remote_code=args.trust_remote_code, + ) + + # Load dataset + dataset = load_dataset(args, target_model.tokenizer, rank) + + # Split dataset for this rank + dataset_slice = split_dataset_for_rank( + dataset, rank, world_size, args.start, args.end + ) + + # Generate hidden states + output_dir = f"{args.outdir}/rank_{rank}" + generator = HiddenStateGenerator(target_model, output_dir, rank=rank) + successful, failed = generator.generate(dataset_slice) + + # Synchronize all processes + if world_size > 1: + dist.barrier() + + # Log final statistics (only on rank 0) + if rank == 0: + logger.info("=" * 50, extra={"rank": rank}) + logger.info("Generation Complete!", extra={"rank": rank}) + logger.info( + f"Total samples processed across all ranks: {len(dataset)}", + extra={"rank": rank}, + ) + logger.info("=" * 50, extra={"rank": rank}) + + logger.info( + f"Rank {rank} - Successful: {successful}, Failed: {failed}", + extra={"rank": rank}, + ) + + finally: + # Cleanup distributed environment + cleanup_distributed() + + +if __name__ == "__main__": + main() diff --git a/tools/train_eagle3_offline.py b/tools/train_eagle3_offline.py new file mode 100644 index 00000000..b9dac759 --- /dev/null +++ b/tools/train_eagle3_offline.py @@ -0,0 +1,392 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from pathlib import Path + +import transformers +from transformers import AutoTokenizer + +from angelslim.compressor.speculative import ( + DataCollatorWithPadding, + DatasetManager, + DraftModelConfig, + OfflineEagle3Trainer, + TargetHead, + create_draft_model, + get_supported_chat_template_type_strings, +) +from angelslim.utils import rank0_print + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Train EAGLE3 offline model") + + # Model arguments + model_group = parser.add_argument_group("Model Arguments") + model_group.add_argument( + "--target_model_name_or_path", + type=str, + default=None, + help="Path to target model, defaults to model_name_or_path if not specified", + ) + model_group.add_argument( + "--draft_model_config_path", + type=str, + default=None, + help="Path to draft model config", + ) + model_group.add_argument( + "--target_backend", + type=str, + default="hf", + choices=["hf"], + help=("Target model backend: hf (HuggingFace Transformers)"), + ) + model_group.add_argument( + "--torch_dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16", "float32"], + help="Data type for model weights: float16, bfloat16, float32", + ) + model_group.add_argument( + "--trust_remote_code", + action="store_true", + default=True, + help="Whether to trust remote code when loading models", + ) + model_group.add_argument( + "--lm_head_key", + type=str, + default="lm_head.weight", + help="Key for lm head in model config", + ) + + # Data arguments + data_group = parser.add_argument_group("Data Arguments") + data_group.add_argument( + "--train_data_path", + type=str, + nargs="+", + required=True, + help="Path to training data file(s) (JSON format). Can specify multiple files.", + ) + data_group.add_argument( + "--eval_data_path", + type=str, + default=None, + help="Path to evaluation data file", + ) + data_group.add_argument( + "--train_hidden_path", + type=str, + required=True, + help="Path to training hidden file", + ) + data_group.add_argument( + "--eval_hidden_path", + type=str, + default=None, + help="Path to evaluation hidden file", + ) + data_group.add_argument( + "--chat_template_type", + type=str, + default="qwen3", + help=( + f"Chat template type for conversation formatting. " + f"Supported types: {', '.join(get_supported_chat_template_type_strings())}" + ), + ) + data_group.add_argument( + "--num_proc", + type=int, + default=16, + help="Number of processes for data preprocessing", + ) + data_group.add_argument( + "--shuffle_seed", type=int, default=42, help="Random seed for shuffling dataset" + ) + data_group.add_argument( + "--display", + action="store_true", + default=False, + help="Display data samples during preprocessing (default: False)", + ) + + # Training arguments + training_group = parser.add_argument_group("Training Arguments") + training_group.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory for model checkpoints", + ) + training_group.add_argument( + "--optim", type=str, default="adamw_torch", help="Optimizer to use" + ) + training_group.add_argument( + "--training_time_test_length", + type=int, + default=7, + help="Length of test data for training time", + ) + training_group.add_argument( + "--model_max_length", + type=int, + default=2048, + help=( + "Maximum sequence length. " + "Sequences will be right padded (and possibly truncated)." + ), + ) + training_group.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size per device during training", + ) + training_group.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size per device during evaluation", + ) + training_group.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help=( + "Number of updates steps to accumulate before " + "performing a backward/update pass" + ), + ) + training_group.add_argument( + "--num_train_epochs", + type=int, + default=3, + help="Total number of training epochs to perform", + ) + training_group.add_argument( + "--learning_rate", type=float, default=5e-5, help="Initial learning rate" + ) + training_group.add_argument( + "--weight_decay", type=float, default=0.0, help="Weight decay to apply" + ) + training_group.add_argument( + "--warmup_steps", type=int, default=0, help="Number of steps for warmup" + ) + training_group.add_argument( + "--warmup_ratio", type=float, default=0.0, help="Ratio of warmup steps" + ) + training_group.add_argument( + "--logging_steps", type=int, default=10, help="Log every X updates steps" + ) + training_group.add_argument( + "--save_steps", + type=float, + default=500, + help="Save checkpoint every X updates steps", + ) + training_group.add_argument( + "--eval_steps", type=int, default=500, help="Run evaluation every X steps" + ) + training_group.add_argument( + "--save_total_limit", + type=int, + default=None, + help="Limit the total amount of checkpoints", + ) + training_group.add_argument( + "--deepspeed", type=str, default=None, help="DeepSpeed config file" + ) + training_group.add_argument( + "--fp16", action="store_true", help="Whether to use fp16 training" + ) + training_group.add_argument( + "--bf16", action="store_true", help="Whether to use bf16 training" + ) + training_group.add_argument( + "--save_strategy", type=str, default="no", help="Save strategy for checkpoints" + ) + training_group.add_argument( + "--lr_scheduler_type", + type=str, + default="constant", + help=( + "Learning rate scheduler type. " + "Common options: 'linear', 'cosine', 'cosine_with_restarts', " + "'polynomial', 'constant', 'constant_with_warmup'" + ), + ) + training_group.add_argument( + "--run_name", type=str, default=None, help="Run name for tracking" + ) + training_group.add_argument( + "--report_to", + type=str, + default="none", + help=( + "The list of integrations to report the results and logs to. " + "Supported platforms: 'tensorboard', 'wandb', 'mlflow', 'all', 'none'" + ), + ) + + return parser.parse_args() + + +def train(): + args = parse_args() + + # Create draft model + rank0_print("Loading draft model...") + draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path) + draft_model = create_draft_model(draft_model_config) + draft_model.load_embed_weights(args.target_model_name_or_path) + draft_model.freeze_embed_weights() + rank0_print("Draft model loaded successfully") + + # Load target head for computing logits from hidden states + rank0_print("Loading target head...") + target_head = TargetHead.from_pretrained( + args.target_model_name_or_path, + lm_head_key=args.lm_head_key, + ) + rank0_print("Target head loaded successfully") + + # Load tokenizer + rank0_print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(args.target_model_name_or_path) + + # Create all datasets using unified DatasetManager + rank0_print("Creating datasets...") + rank0_print("- Offline mode: Loading pre-computed hidden states from .ckpt files") + rank0_print( + "- Online mode: Processing raw conversation data " + f"(chat template: {args.chat_template_type})" + ) + + dataset_manager = DatasetManager( + data_args=args, + tokenizer=tokenizer, + model_max_length=args.model_max_length, + chat_template_type=args.chat_template_type, + ) + + (offline_train_dataset, offline_eval_dataset, online_train_dataset, _) = ( + dataset_manager.create_all_datasets() + ) + + rank0_print( + f"Offline train dataset size: {len(offline_train_dataset)}, " + "Offline eval dataset size: " + f"{len(offline_eval_dataset) if offline_eval_dataset else 0}" + ) + + # Build vocabulary mapping for draft model using online training dataset + rank0_print("Building vocabulary mapping for draft model...") + if online_train_dataset is not None: + cache_path = os.path.join(args.output_dir, "vocab_mapping_cache.pt") + draft_model.build_vocab_mapping( + dataset=online_train_dataset, + cache_path=cache_path, + ) + rank0_print("Vocabulary mapping built successfully") + else: + rank0_print( + "Warning: No online training dataset available, " + "skipping vocab mapping build" + ) + + # Create a TrainingArguments object for the trainer + # Organize training arguments by category + basic_args = { + "output_dir": args.output_dir, + "num_train_epochs": args.num_train_epochs, + } + + batch_args = { + "per_device_train_batch_size": args.per_device_train_batch_size, + "per_device_eval_batch_size": args.per_device_eval_batch_size, + "gradient_accumulation_steps": args.gradient_accumulation_steps, + } + + optimizer_args = { + "learning_rate": args.learning_rate, + "weight_decay": args.weight_decay, + "warmup_steps": args.warmup_steps, + "optim": args.optim, + "lr_scheduler_type": args.lr_scheduler_type, + } + + precision_args = { + "fp16": args.fp16, + "bf16": args.bf16, + } + + checkpoint_args = { + "save_strategy": args.save_strategy, + "save_steps": args.save_steps, + "save_total_limit": args.save_total_limit, + } + + logging_args = { + "logging_steps": args.logging_steps, + "eval_steps": args.eval_steps, + "report_to": args.report_to, + "run_name": args.run_name, + } + + distributed_args = { + "deepspeed": args.deepspeed, + } + + training_args = transformers.TrainingArguments( + **basic_args, + **batch_args, + **optimizer_args, + **precision_args, + **checkpoint_args, + **logging_args, + **distributed_args, + remove_unused_columns=False, + ) + + # Initialize trainer with offline datasets + rank0_print("Initializing trainer...") + trainer = OfflineEagle3Trainer( + draft_model=draft_model, + target_head=target_head, + length=args.training_time_test_length, + args=training_args, + train_dataset=offline_train_dataset, + eval_dataset=offline_eval_dataset, + data_collator=DataCollatorWithPadding(), + ) + + # Start training + if list(Path(training_args.output_dir).glob("checkpoint-*")): + rank0_print("Resuming training from checkpoint...") + trainer.train(resume_from_checkpoint=True) + else: + rank0_print("Starting training...") + trainer.train() + rank0_print("Training completed!") + + +if __name__ == "__main__": + train() diff --git a/tools/train_eagle3_online.py b/tools/train_eagle3_online.py index 278d481b..15f5d711 100644 --- a/tools/train_eagle3_online.py +++ b/tools/train_eagle3_online.py @@ -1,3 +1,17 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import os from pathlib import Path @@ -258,7 +272,7 @@ def train(): chat_template_type=args.chat_template_type, display=args.display, ) - train_dataset, eval_dataset = dataset_manager.create_datasets() + train_dataset, eval_dataset = dataset_manager.create_online_datasets() rank0_print( f"Train dataset size: {len(train_dataset)}, " f"Eval dataset size: {len(eval_dataset) if eval_dataset else 0}"