diff --git a/.gitignore b/.gitignore index 7935db5e..d68ab077 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,8 @@ dist/ eval/ *_ckpt*/ output/ +outputs/ outs/ wandb/ tools/results/ -__pycache__/ \ No newline at end of file +__pycache__/outputs/ diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index ace2b186..7725af78 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -184,11 +184,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: "target_hiddens": None, } - # Check if both hidden_states and target_hiddens exist in all features - if all("hidden_states" in item and "target_hiddens" in item for item in features): + # Handle hidden_states and target_hiddens independently + if all("hidden_states" in item for item in features): batch["hidden_states"] = torch.cat( [paddingtensor(item["hidden_states"], max_length) for item in features] ) + + if all("target_hiddens" in item for item in features): batch["target_hiddens"] = torch.cat( [paddingtensor(item["target_hiddens"], max_length) for item in features] ) diff --git a/angelslim/compressor/speculative/train/data/dataset.py b/angelslim/compressor/speculative/train/data/dataset.py index 722f0d15..9735b327 100644 --- a/angelslim/compressor/speculative/train/data/dataset.py +++ b/angelslim/compressor/speculative/train/data/dataset.py @@ -174,41 +174,56 @@ def _create_online_datasets( if self.display: num_proc = None + # Determine min_loss_tokens for DFlash filtering + min_loss_tokens = None + if self.data_args.modal_type == "DFlash": + block_size = getattr(self.data_args, "block_size", 16) + min_loss_tokens = 2 * block_size + # Create training dataset train_dataset = None - if self.data_args.train_data_path is not None: + train_path = getattr(self.data_args, "train_data_path", None) + if train_path is not None: train_dataset = 
self.online_dataset_builder.build_dataset( - self.data_args.train_data_path, + train_path, num_proc=num_proc, shuffle=True, - sample_num=self.data_args.sample_num, + sample_num=getattr(self.data_args, "sample_num", None), + min_loss_tokens=min_loss_tokens, ) # Create evaluation dataset eval_dataset = None - if self.data_args.eval_data_path is not None: + eval_path = getattr(self.data_args, "eval_data_path", None) + if eval_path is not None: eval_dataset = self.online_dataset_builder.build_dataset( - self.data_args.eval_data_path, + eval_path, num_proc=num_proc, shuffle=False, - sample_num=self.data_args.sample_num, + sample_num=getattr(self.data_args, "sample_num", None), + min_loss_tokens=min_loss_tokens, ) data_collator = self.online_dataset_builder.get_data_collator() return train_dataset, eval_dataset, data_collator - def _create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset]]: + def _create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset], Any]: """ Create offline datasets from pre-computed .ckpt files. 
Returns: - Tuple of (train_dataset, eval_dataset) + Tuple of (train_dataset, eval_dataset, data_collator) """ + if self.offline_dataset_builder is None: + return None, None, None + # Create train dataset - train_dataset = self.offline_dataset_builder.build_dataset( - self.data_args.train_hidden_path - ) + train_dataset = None + if self.data_args.train_hidden_path is not None: + train_dataset = self.offline_dataset_builder.build_dataset( + self.data_args.train_hidden_path + ) # Create eval dataset if path is provided eval_dataset = None diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/base_dataset_builder.py b/angelslim/compressor/speculative/train/data/dataset_builder/base_dataset_builder.py index 7259149c..d7fd2f5e 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/base_dataset_builder.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/base_dataset_builder.py @@ -29,7 +29,12 @@ class DatasetBuilder(metaclass=ABCMeta): @abstractmethod def build_dataset( - self, datapath: str, num_proc: int = 8, shuffle: bool = True, **kwargs + self, + datapath: str, + num_proc: int = 8, + shuffle: bool = True, + min_loss_tokens: Optional[int] = None, + **kwargs, ) -> Dataset: pass @@ -127,6 +132,7 @@ def build_dataset( num_proc: int = 8, shuffle: bool = True, sample_num: Optional[int] = None, + min_loss_tokens: Optional[int] = None, ) -> Dataset: try: # Load dataset @@ -161,6 +167,18 @@ def build_dataset( num_proc=num_proc, desc="Filtering empty input_ids", ) + + if min_loss_tokens is not None: + processed_ds = processed_ds.filter( + lambda batch: [ + sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens + for m in batch["loss_mask"] + ], + batched=True, + num_proc=num_proc, + desc=f"Filtering sequences with loss tokens < {min_loss_tokens}", + ) + processed_ds.set_format(type="torch") return processed_ds diff --git a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py 
b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py index 9a662eff..e09561e8 100644 --- a/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py +++ b/angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py @@ -94,6 +94,7 @@ def build_dataset( num_proc: int = 8, shuffle: bool = True, sample_num: Optional[int] = None, + min_loss_tokens: Optional[int] = None, ) -> Dataset: try: # Load dataset @@ -146,11 +147,19 @@ def build_dataset( num_proc=num_proc, desc="Filtering empty input_ids", ) + if min_loss_tokens is not None: + processed_ds = processed_ds.filter( + lambda batch: [ + sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens + for m in batch["loss_mask"] + ], + batched=True, + num_proc=num_proc, + desc=f"Filtering sequences with loss tokens < {min_loss_tokens}", + ) + torch_columns = [c for c in processed_ds.column_names if c != "image_paths"] processed_ds.set_format(type="torch", columns=torch_columns, output_all_columns=True) - rank0_print( - f"processed_ds size:{len(processed_ds)}, columns: {processed_ds.column_names}" - ) return processed_ds @@ -324,6 +333,7 @@ def build_dataset( num_proc: int = 8, shuffle: bool = True, sample_num: Optional[int] = None, + min_loss_tokens: Optional[int] = None, ) -> Dataset: try: # Load dataset @@ -374,6 +384,16 @@ def build_dataset( num_proc=num_proc, desc="Filtering empty input_ids", ) + if min_loss_tokens is not None: + processed_ds = processed_ds.filter( + lambda batch: [ + sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens + for m in batch["loss_mask"] + ], + batched=True, + num_proc=num_proc, + desc=f"Filtering sequences with loss tokens < {min_loss_tokens}", + ) torch_columns = [c for c in processed_ds.column_names if c != "image_paths"] processed_ds.set_format(type="torch", columns=torch_columns, output_all_columns=True) @@ -572,6 +592,7 @@ def build_dataset( num_proc: int = 8, 
shuffle: bool = True, sample_num: Optional[int] = None, + min_loss_tokens: Optional[int] = None, ) -> Dataset: try: # Load dataset @@ -623,6 +644,18 @@ def build_dataset( num_proc=num_proc, desc="Filtering empty input_ids", ) + + if min_loss_tokens is not None: + processed_ds = processed_ds.filter( + lambda batch: [ + sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens + for m in batch["loss_mask"] + ], + batched=True, + num_proc=num_proc, + desc=f"Filtering sequences with loss tokens < {min_loss_tokens}", + ) + processed_ds.set_format(type="torch") return processed_ds @@ -886,6 +919,7 @@ def build_dataset( num_proc: int = 8, shuffle: bool = True, sample_num: Optional[int] = None, + min_loss_tokens: Optional[int] = None, ) -> Dataset: try: if not isinstance(datapath, list): diff --git a/angelslim/compressor/speculative/train/models/draft/__init__.py b/angelslim/compressor/speculative/train/models/draft/__init__.py index 1b1eb4b9..c056ce23 100644 --- a/angelslim/compressor/speculative/train/models/draft/__init__.py +++ b/angelslim/compressor/speculative/train/models/draft/__init__.py @@ -14,10 +14,12 @@ from .draft_model_factory import DraftModelConfig, create_draft_model from .llama_eagle3 import CosyVoice3Eagle3LlamaForCausalLM, Eagle3LlamaForCausalLM +from .qwen_dflash import QwenDFlashDraftModel __all__ = [ "create_draft_model", "DraftModelConfig", "Eagle3LlamaForCausalLM", "CosyVoice3Eagle3LlamaForCausalLM", + "QwenDFlashDraftModel", ] diff --git a/angelslim/compressor/speculative/train/models/draft/qwen_dflash.py b/angelslim/compressor/speculative/train/models/draft/qwen_dflash.py new file mode 100755 index 00000000..faf00343 --- /dev/null +++ b/angelslim/compressor/speculative/train/models/draft/qwen_dflash.py @@ -0,0 +1,389 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Draft Model for Qwen3 architecture. + +Migrated from SpecForge's specforge/modeling/draft/dflash.py. +Uses cross-attention between draft blocks and context hidden states, +fundamentally different from Eagle3's concat + self-attention approach. +""" + +from typing import Callable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import DynamicCache +from transformers.cache_utils import Cache +from transformers.models.qwen3.modeling_qwen3 import ( + ALL_ATTENTION_FUNCTIONS, + FlashAttentionKwargs, + GradientCheckpointingLayer, + Qwen3Config, + Qwen3MLP, + Qwen3PreTrainedModel, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + eager_attention_forward, + rotate_half, +) +from typing_extensions import Unpack + +from .draft_model_factory import DraftModelFactory + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> 
List[int]: + """Compute target layer IDs to capture from the target model.""" + if num_draft_layers == 1: + return [(num_target_layers // 2)] + start = 1 + end = num_target_layers - 3 + span = end - start + target_layer_ids = [ + int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers) + ] + return target_layer_ids + + +def extract_context_feature( + hidden_states: list, + layer_ids: Optional[List[int]], +) -> torch.Tensor: + """Extract and concatenate hidden states from specified layers.""" + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + + +class Qwen3DFlashAttention(nn.Module): + """Multi-headed cross-attention for DFlash. + + Q comes from draft hidden states, KV comes from concatenation of + context (target) hidden states and draft hidden states. + """ + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = 
Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.sliding_window = ( + config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + q = self.q_proj(hidden_states) + q = q.view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + v = torch.cat([v_ctx, v_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + v = v.transpose(1, 2) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) + attn_fn: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer): + """DFlash decoder layer with 
cross-attention to context.""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + target_hidden: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@DraftModelFactory.register +class QwenDFlashDraftModel(Qwen3PreTrainedModel): + """DFlash Draft Model for Qwen3 architecture. + + Uses block-parallel cross-attention between noise-masked draft blocks + and context hidden states from the target model. 
+ """ + + config_class = Qwen3Config + _no_split_modules = ["Qwen3DFlashDecoderLayer"] + + def __init__(self, config) -> None: + super().__init__(config) + self.config = config + self.layers = nn.ModuleList( + [ + Qwen3DFlashDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + dflash_config = getattr(config, "dflash_config", {}) or {} + self.target_layer_ids = dflash_config.get( + "target_layer_ids", + build_target_layer_ids(config.num_target_layers, config.num_hidden_layers), + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + self.fc = nn.Linear( + len(self.target_layer_ids) * config.hidden_size, + config.hidden_size, + bias=False, + ) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.mask_token_id = dflash_config.get("mask_token_id", None) + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + target_hidden: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + **kwargs, + ) -> torch.Tensor: + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + return self.norm(hidden_states) + + @torch.inference_mode() + def spec_generate( + self, + target: nn.Module, + input_ids: torch.LongTensor, + max_new_tokens: int, + stop_token_ids: List[int], + temperature: float, + ): + """Speculative generation with DFlash draft 
model.""" + self.eval() + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + max_new_tokens + + block_size = self.block_size + output_ids = torch.full( + (1, max_length + block_size), + self.mask_token_id, + dtype=torch.long, + device=target.device, + ) + position_ids = torch.arange(output_ids.shape[1], device=target.device).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + # Prefill stage + output = target( + input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + use_cache=True, + logits_to_keep=1, + output_hidden_states=True, + ) + + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(output.logits, temperature) + target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids) + + # Decode stage + acceptance_lengths = [] + start = input_ids.shape[1] + while start < max_length: + block_output_ids = output_ids[:, start : start + block_size].clone() + block_position_ids = position_ids[:, start : start + block_size] + noise_embedding = target.model.embed_tokens(block_output_ids) + draft_logits = target.lm_head( + self( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[ + :, past_key_values_draft.get_seq_length() : start + block_size + ], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + )[:, -block_size + 1 :, :] + ) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = sample(draft_logits) + + output = target( + block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + use_cache=True, + output_hidden_states=True, + ) + + posterior = sample(output.logits, temperature) + acceptance_length = ( + (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item() + ) + output_ids[:, start : start + acceptance_length + 1] = 
block_output_ids[ + :, : acceptance_length + 1 + ] + output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length] + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids)[ + :, : acceptance_length + 1, : + ] + acceptance_lengths.append(acceptance_length + 1) + if stop_token_ids is not None and any( + stop_token_id in output_ids[:, num_input_tokens:] + for stop_token_id in stop_token_ids + ): + break + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != self.mask_token_id] + if stop_token_ids is not None: + stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device) + stop_token_indices = torch.isin( + output_ids[0][num_input_tokens:], stop_token_ids + ).nonzero(as_tuple=True)[0] + if stop_token_indices.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1] + + return output_ids diff --git a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py index 11d17151..e803ee27 100644 --- a/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py +++ b/angelslim/compressor/speculative/train/models/target/target_model_wrapper.py @@ -193,7 +193,7 @@ def _prepare_model_kwargs(self, device: str) -> dict: Dictionary of model loading arguments """ default_kwargs = { - "dtype": torch.bfloat16, + "torch_dtype": torch.bfloat16, "device_map": device, "trust_remote_code": True, } @@ -228,6 +228,7 @@ def get_hidden_states_and_logits( attention_mask=attention_mask, output_hidden_states=True, output_logits=True, + use_cache=False, # match SpecForge: no KV-cache during training ) # Extract auxiliary hidden states @@ -914,7 +915,7 @@ def create_target_model( # Add backend-specific configuration if backend == "hf": - kwargs["dtype"] = torch_dtype + kwargs["torch_dtype"] = 
torch_dtype else: raise ValueError( f"Unsupported backend: '{backend}'. " diff --git a/angelslim/compressor/speculative/train/trainer/__init__.py b/angelslim/compressor/speculative/train/trainer/__init__.py index ad1c6729..cd4db79c 100644 --- a/angelslim/compressor/speculative/train/trainer/__init__.py +++ b/angelslim/compressor/speculative/train/trainer/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .offline_dflash_trainer import OfflineDFlashTrainer from .offline_eagle3_trainer import OfflineEagle3Trainer, OfflineVLMEagle3Trainer +from .online_dflash_trainer import OnlineDFlashTrainer from .online_eagle3_trainer import ( OnlineEagle3Trainer, OnlineTTSEagle3Trainer, @@ -25,6 +27,8 @@ "OnlineEagle3Trainer", "OnlineVLMEagle3Trainer", "OnlineTTSEagle3Trainer", + "OnlineDFlashTrainer", + "OfflineDFlashTrainer", "OfflineEagle3Trainer", "OfflineVLMEagle3Trainer", ] diff --git a/angelslim/compressor/speculative/train/trainer/offline_dflash_trainer.py b/angelslim/compressor/speculative/train/trainer/offline_dflash_trainer.py new file mode 100644 index 00000000..b58b74cf --- /dev/null +++ b/angelslim/compressor/speculative/train/trainer/offline_dflash_trainer.py @@ -0,0 +1,44 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .online_dflash_trainer import OnlineDFlashTrainer +from .trainer_factory import Eagle3TrainerFactory + + +@Eagle3TrainerFactory.register("offline", "DFlash") +class OfflineDFlashTrainer(OnlineDFlashTrainer): + """ + DFlash trainer for offline (pre-computed hidden states) training. + + The main difference vs online: hidden_states are loaded directly from the + pre-computed .ckpt files, so prepare_data_for_draft_model() just unpacks + the batch instead of running a target-model forward pass. + """ + + def prepare_data_for_draft_model(self, inputs): + """ + Unpack pre-computed hidden states from the offline batch. + + Expected batch keys (from OfflineDFlashDataset): + input_ids [B, S] + hidden_states [B, S, D*L] + loss_mask [B, S] + attention_mask [B, S] + """ + return { + "input_ids": inputs["input_ids"], + "hidden_states": inputs["hidden_states"], + "loss_mask": inputs["loss_mask"], + "attention_mask": inputs["attention_mask"], + } diff --git a/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py b/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py new file mode 100755 index 00000000..e2ff3741 --- /dev/null +++ b/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py @@ -0,0 +1,547 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Online DFlash Trainer for speculative decoding training. 
+ +DFlash uses block-parallel cross-attention rather than Eagle3's +iterative autoregressive approach, so it overrides compute_loss +with its own block-wise CE loss logic. +""" + +import gc +import glob +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from safetensors import safe_open +from torch import nn +from transformers import AutoConfig + +from .eagle3_trainer import Eagle3Trainer +from .trainer_factory import Eagle3TrainerFactory + +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + BlockMask = None + create_block_mask = None + + +def create_dflash_block_mask( + anchor_positions: torch.Tensor, + block_keep_mask: torch.Tensor, + S: int, + block_size: int, + device: torch.device, +): + """Construct Flex Attention BlockMask for DFlash training. + + KV: [Context (S tokens) | Block_0 | Block_1 | ... | Block_{n-1}] + Q: [Block_0 | Block_1 | ... | Block_{n-1}] + + Rules: + 1. Each block sees context strictly before its anchor (kv_idx < anchor_pos). + 2. Intra-block attention is bidirectional. + 3. Different blocks are invisible to each other. + 4. Invalid blocks (block_keep_mask=False) see nothing. + """ + + def dflash_mask_mod(b, h, q_idx, kv_idx): + q_block_id = q_idx // block_size + anchor_pos = anchor_positions[b, q_block_id] + + is_context = kv_idx < S + # Strictly less than: matches inference where target_hidden[anchor_pos] + # is not available as context. 
+ mask_context = is_context & (kv_idx < anchor_pos) + + is_draft = kv_idx >= S + kv_block_id = (kv_idx - S) // block_size + mask_draft = is_draft & (q_block_id == kv_block_id) + + is_valid_block = block_keep_mask[b, q_block_id] + return (mask_context | mask_draft) & is_valid_block + + B, N = anchor_positions.shape + Q_LEN = N * block_size + KV_LEN = S + N * block_size + + return create_block_mask( + dflash_mask_mod, B=B, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device + ) + + +class TargetEmbeddingsAndHead(nn.Module): + """Efficiently loads only the embedding layer and lm_head from a pretrained model. + + Handles safetensors slicing and Weight Tying correctly. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=getattr(config, "pad_token_id", None), + ) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + @classmethod + def from_pretrained( + cls, + model_path: str, + embed_key: Optional[str] = None, + lm_head_key: Optional[str] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + trust_remote_code: bool = False, + ) -> "TargetEmbeddingsAndHead": + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + instance = cls(config) + + if embed_key is None: + embed_key = "model.embed_tokens.weight" + if lm_head_key is None: + lm_head_key = "lm_head.weight" + + tie_weights = getattr(config, "tie_word_embeddings", False) + instance._load_weights(model_path, embed_key, lm_head_key, tie_weights) + + instance.to(device=device, dtype=dtype) + instance.eval() + instance.requires_grad_(False) + + return instance + + def _load_weights(self, model_path: str, embed_key: str, lm_head_key: str, tie_weights: bool): + index_files = glob.glob(os.path.join(model_path, "*.index.json")) + files_to_load = {} + + if index_files: + with open(index_files[0], "r") as f: + index = 
json.load(f) + weight_map = index.get("weight_map", {}) + + if embed_key in weight_map: + files_to_load[embed_key] = weight_map[embed_key] + else: + raise ValueError(f"Embedding key '{embed_key}' not found in weight map.") + + if not tie_weights: + if lm_head_key in weight_map: + files_to_load[lm_head_key] = weight_map[lm_head_key] + else: + safetensors = glob.glob(os.path.join(model_path, "*.safetensors")) + bins = glob.glob(os.path.join(model_path, "*.bin")) + target_file = safetensors[0] if safetensors else (bins[0] if bins else None) + + if not target_file: + raise FileNotFoundError("No checkpoint found.") + + files_to_load[embed_key] = os.path.basename(target_file) + if not tie_weights: + files_to_load[lm_head_key] = os.path.basename(target_file) + + file_to_keys_map = {} + for key, filename in files_to_load.items(): + full_path = os.path.join(model_path, filename) + if full_path not in file_to_keys_map: + file_to_keys_map[full_path] = [] + file_to_keys_map[full_path].append(key) + + for file_path, keys in file_to_keys_map.items(): + self._load_file_content(file_path, keys, embed_key, lm_head_key) + + if tie_weights: + self.lm_head.weight = self.embed_tokens.weight + + def _load_file_content( + self, + file_path: str, + keys_to_extract: list, + target_embed_key: str, + target_head_key: str, + ): + state_dict_part = {} + + if file_path.endswith(".safetensors"): + with safe_open(file_path, framework="pt") as f: + for k in keys_to_extract: + if k in f.keys(): + state_dict_part[k] = f.get_tensor(k) + else: + full_state = torch.load(file_path, map_location="cpu") + for k in keys_to_extract: + if k in full_state: + state_dict_part[k] = full_state[k] + del full_state + gc.collect() + + for k, tensor in state_dict_part.items(): + if k == target_embed_key: + self.embed_tokens.weight.data.copy_(tensor) + elif k == target_head_key: + if tensor.shape == self.lm_head.weight.data.shape: + self.lm_head.weight.data.copy_(tensor) + + +@Eagle3TrainerFactory.register("online", 
def __init__(
    self,
    draft_model: nn.Module,
    target_model: nn.Module,
    length: int,
    draft_model_config: Dict[str, Any],
    **kwargs,
):
    """
    Initialize the OnlineDFlashTrainer.

    Args:
        draft_model: DFlash draft model
        target_model: Target model for generating hidden states
        length: Not used for DFlash (kept for interface compatibility)
        draft_model_config: Configuration object for the draft model; its
            fields are read with ``getattr`` and must contain the
            dflash-specific fields
        **kwargs: Additional arguments passed to parent Trainer

    Raises:
        ValueError: If no mask token id or no target model path can be
            resolved from the configuration.
    """
    super().__init__(draft_model=draft_model, length=length, **kwargs)
    self.target_model = target_model
    self._aux_hidden_states_layer_ids = getattr(
        draft_model_config, "aux_hidden_states_layer_ids", None
    )

    # Extract DFlash-specific config
    dflash_config = getattr(draft_model_config, "dflash_config", {}) or {}
    self.block_size = getattr(draft_model_config, "block_size", 16)
    self.num_anchors = getattr(draft_model_config, "num_anchors", 512)
    self.loss_decay_gamma = getattr(draft_model_config, "loss_decay_gamma", None)
    self.attention_backend = getattr(draft_model_config, "attention_backend", "flex_attention")

    # An explicit ``"mask_token_id": null`` in dflash_config must not block
    # the fallback to the top-level field.
    mask_token_id = dflash_config.get("mask_token_id")
    if mask_token_id is None:
        mask_token_id = getattr(draft_model_config, "mask_token_id", None)
    if mask_token_id is None:
        # Fail here instead of with an opaque torch.full() error on step one.
        raise ValueError(
            "mask_token_id must be set in draft_model_config.dflash_config "
            "or draft_model_config for DFlash training."
        )
    self.mask_token_id = mask_token_id

    # Load target model's lm_head and embed_tokens
    # In offline mode target_model may be None; fall back to config path.
    target_model_path = None
    if target_model is not None:
        target_model_path = getattr(target_model, "model_path", None)
    if target_model_path is None:
        target_model_path = getattr(draft_model_config, "target_model_name_or_path", None)
    if target_model_path is None:
        raise ValueError(
            "target_model_name_or_path must be set in draft_model_config "
            "or target_model.model_path for DFlash training."
        )

    embed_weight_key = getattr(
        draft_model_config, "embed_weight_key", "model.embed_tokens.weight"
    )
    lm_head_key = getattr(draft_model_config, "lm_head_key", "lm_head.weight")
    trust_remote_code = getattr(draft_model_config, "trust_remote_code", True)

    target_components = TargetEmbeddingsAndHead.from_pretrained(
        target_model_path,
        embed_key=embed_weight_key,
        lm_head_key=lm_head_key,
        device="cuda",
        trust_remote_code=trust_remote_code,
    )
    self.target_lm_head = target_components.lm_head
    self.target_embed_tokens = target_components.embed_tokens
def _create_position_ids(self, anchor_positions: torch.Tensor) -> torch.Tensor:
    """Return absolute position IDs for every slot of every draft block.

    Each anchor is expanded into ``self.block_size`` consecutive positions
    (``anchor, anchor + 1, ...``) and the blocks are flattened per row.
    """
    num_rows, _ = anchor_positions.shape
    within_block = torch.arange(self.block_size, device=anchor_positions.device)
    # Broadcast anchors against the in-block offsets, one row of offsets per anchor.
    absolute = anchor_positions.unsqueeze(-1) + within_block.view(1, 1, -1)
    return absolute.view(num_rows, -1)
block_starts.unsqueeze(0).expand(bsz, -1) + + valid_anchor_positions = anchor_positions.clamp(0, seq_len - 1) + anchor_tokens = torch.gather(input_ids, 1, valid_anchor_positions) + + flat_batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n) + noise_ids[flat_batch_idx, block_starts] = torch.where( + block_keep_mask, + anchor_tokens, + torch.tensor(self.mask_token_id, dtype=torch.long, device=device), + ) + + return self.target_embed_tokens(noise_ids) + + def _compute_dflash_loss_and_accuracy( + self, + model: nn.Module, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + loss_mask: torch.Tensor, + ): + """Core DFlash block-parallel loss logic (shared by train + eval). + + Steps: + 1. Sample anchor positions from valid loss_mask positions. + 2. Build noise embedding (anchor token is real, rest are MASK). + 3. Build DFlash BlockMask (context-causal + intra-block bidirectional). + 4. Run draft model forward → logits. + 5. Compute weighted CE loss with optional exponential decay. + 6. Compute accuracy (no-decay mask). + + Returns: + (loss, accuracy) — both scalar tensors. + """ + bsz, seq_len = input_ids.shape + device = input_ids.device + + # ── 1. Anchor sampling ──────────────────────────────────────────────── + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + + # No valid anchors → return zero loss connected to model params (DDP-safe) + if anchor_positions is None: + zero_loss = sum(p.sum() * 0.0 for p in model.parameters() if p.requires_grad) + return zero_loss, torch.tensor(0.0, device=device) + + # ── 2. Noise embedding ──────────────────────────────────────────────── + noise_embedding = self._create_noise_embed(input_ids, anchor_positions, block_keep_mask) + + # ── 3. 
Position IDs [B, S + N*block_size] ─────────────────────────── + context_position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + draft_position_ids = self._create_position_ids(anchor_positions) + full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) + + # ── 4. Attention mask (DFlash BlockMask) ───────────────────────────── + dflash_attn_mask = create_dflash_block_mask( + anchor_positions=anchor_positions, + block_keep_mask=block_keep_mask, + S=seq_len, + block_size=self.block_size, + device=device, + ) + + # ── 5. Draft model forward → logits [B, N*bs, vocab] ──────────────── + model_dtype = next(model.parameters()).dtype + noise_embedding = noise_embedding.to(model_dtype) + hidden_states = hidden_states.to(model_dtype) + + output_hidden = model( + noise_embedding=noise_embedding, + target_hidden=hidden_states, + attention_mask=dflash_attn_mask, + position_ids=full_position_ids, + ) + + output_hidden = output_hidden.to(self.target_lm_head.weight.dtype) + logits = self.target_lm_head(output_hidden) + + # ── 6. Labels: position k in block predicts token at (anchor + k) ──── + bs = self.block_size + label_offsets = torch.arange(0, bs, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label_mask = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, anchor_positions.size(1), -1), + dim=2, + index=safe_label_indices, + ) # [B, N, bs] + + # ── 7. 
def compute_loss(
    self,
    model: nn.Module,
    inputs: Dict[str, torch.Tensor],
    num_items_in_batch: Optional[int] = None,
    return_outputs: bool = False,
) -> "torch.Tensor | Tuple[torch.Tensor, Dict[str, torch.Tensor]]":
    """Compute the DFlash training loss.

    Unlike Eagle3's iterative multi-step loss, DFlash computes a single
    block-parallel cross-entropy loss over all sampled anchor positions.

    Args:
        model: Draft model being trained.
        inputs: Batch passed on to ``prepare_data_for_draft_model``.
        num_items_in_batch: Accepted for Trainer compatibility; unused.
        return_outputs: When true, also return an outputs dict.

    Returns:
        The scalar loss, or ``(loss, {"accuracy": accuracy})`` when
        ``return_outputs`` is true.
    """
    data = self.prepare_data_for_draft_model(inputs)

    loss, accuracy = self._compute_dflash_loss_and_accuracy(
        model=model,
        input_ids=data["input_ids"],
        hidden_states=data["hidden_states"],
        loss_mask=data["loss_mask"],
    )

    self.log(
        {
            "train/loss": round(float(loss.item()), 4),
            "train/accuracy": round(float(accuracy.item()), 4),
        }
    )

    # Honour the (loss, outputs) form callers get when they ask for outputs;
    # the flag used to be ignored.
    if return_outputs:
        return loss, {"accuracy": accuracy}
    return loss
"eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + + "num_anchors": 512, + "loss_decay_gamma": 7.0, + "attention_backend": "flex_attention" +} diff --git a/scripts/speculative/generate_dflash_data.sh b/scripts/speculative/generate_dflash_data.sh new file mode 100644 index 00000000..ce7a8706 --- /dev/null +++ b/scripts/speculative/generate_dflash_data.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# ============================================================================= +# Step 1: Pre-generate DFlash training data (hidden states) from target model. +# +# Usage: +# bash scripts/speculative/generate_qwen3_dflash_data.sh [NUM_GPUS] +# +# Output: +# One .ckpt file per training sample, saved to OUTPUT_DIR. +# Each file contains: input_ids, hidden_states (5-layer concat), loss_mask, +# attention_mask. 
+# ============================================================================= + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +# Use local source code instead of installed site-packages +export PYTHONPATH=$ROOT_DIR:$PYTHONPATH + +NUM_GPUS=${1:-8} + +# ---- Paths -- modify these to match your environment ---- +TARGET_MODEL_PATH="" +TRAIN_DATA_PATH="" +OUTPUT_DIR="${ROOT_DIR}/outputs/" # directory for .ckpt files + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/tools/generate_dflash_data.py \ + --target_model_name_or_path $TARGET_MODEL_PATH \ + --draft_model_config_path $ROOT_DIR/configs/qwen3_dflash.json \ + --train_data_path $TRAIN_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --model_max_length 3072 \ + --chat_template_type qwen3 \ + --batch_size 1 \ + --num_proc 16 \ + --sample_num 128 \ + --shard_size 10000 diff --git a/scripts/speculative/run_dflash_offline.sh b/scripts/speculative/run_dflash_offline.sh new file mode 100644 index 00000000..05524719 --- /dev/null +++ b/scripts/speculative/run_dflash_offline.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# ============================================================================= +# Step 2: Train DFlash draft model in OFFLINE mode. +# +# Prerequisites: +# Run generate_qwen3_dflash_data.sh first to produce the .ckpt files. 
+# +# Usage: +# bash scripts/speculative/run_qwen3_dflash_offline.sh [NUM_GPUS] +# ============================================================================= + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +# Use local source code instead of installed site-packages +export PYTHONPATH=$ROOT_DIR:$PYTHONPATH + +NUM_GPUS=${1:-8} + +# ---- Paths -- modify these to match your environment ---- +TARGET_MODEL_PATH="" +TRAIN_HIDDEN_PATH="" +OUTPUT_DIR="${ROOT_DIR}/outputs/" + +# WandB configuration +export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflash"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"qwen3-4b-dflash-offline"} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/tools/train_dflash_offline.py \ + --target_model_name_or_path $TARGET_MODEL_PATH \ + --draft_model_config_path $ROOT_DIR/configs/qwen3_dflash.json \ + --train_hidden_path $TRAIN_HIDDEN_PATH \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 12 \ + --per_device_train_batch_size 2 \ + --learning_rate 6e-4 \ + --warmup_ratio 0.04 \ + --max_grad_norm 1.0 \ + --model_max_length 3072 \ + --chat_template_type qwen3 \ + --attention_backend flex_attention \ + --block_size 16 \ + --num_anchors 512 \ + --loss_decay_gamma 7.0 \ + --logging_steps 50 \ + --save_strategy steps \ + --save_steps 2500 \ + --bf16 \ + --lr_scheduler_type cosine \ + --report_to wandb \ + --run_name $WANDB_RUN_NAME diff --git a/scripts/speculative/run_dflash_online.sh b/scripts/speculative/run_dflash_online.sh new file mode 100644 index 00000000..1d4a1397 --- /dev/null +++ b/scripts/speculative/run_dflash_online.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# DFlash Online Training Script for Qwen3 +# Usage: bash scripts/speculative/run_qwen3_dflash_online.sh [NUM_GPUS] [ATTENTION_BACKEND] + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +# Use local source code 
instead of installed site-packages +export PYTHONPATH=$ROOT_DIR:$PYTHONPATH + +NUM_GPUS=${1:-8} +ATTENTION_BACKEND=${2:-flex_attention} + +# Set paths - modify these to match your environment +TARGET_MODEL_PATH="" +TRAIN_DATA_PATH="" +OUTPUT_DIR="${ROOT_DIR}/outputs/" + +export CONFIG_DIR=${ROOT_DIR}/angelslim/compressor/speculative/train/configs + +# WandB configuration (mirrors SpecForge's --wandb-project / --wandb-name) +export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflash"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"qwen3-4b-dflash"} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/tools/train_dflash_online.py \ + --target_model_name_or_path $TARGET_MODEL_PATH \ + --draft_model_config_path $ROOT_DIR/configs/qwen3_dflash.json \ + --train_data_path $TRAIN_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --modal_type DFlash \ + --training_mode online \ + --num_train_epochs 6 \ + --per_device_train_batch_size 2 \ + --learning_rate 6e-4 \ + --warmup_ratio 0.04 \ + --max_grad_norm 1.0 \ + --model_max_length 3072 \ + --chat_template_type qwen3 \ + --attention_backend $ATTENTION_BACKEND \ + --block_size 16 \ + --num_anchors 512 \ + --loss_decay_gamma 7.0 \ + --logging_steps 50 \ + --save_strategy steps \ + --save_steps 2500 \ + --bf16 \ + --lr_scheduler_type cosine \ + --report_to wandb \ + --run_name $WANDB_RUN_NAME + diff --git a/tools/generate_dflash_data.py b/tools/generate_dflash_data.py new file mode 100755 index 00000000..b141f024 --- /dev/null +++ b/tools/generate_dflash_data.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# DFlash offline data pre-generation script. 
def parse_args():
    """Parse command-line options for DFlash hidden-state pre-generation.

    Returns:
        The ``argparse.Namespace`` holding model, data and sharding options.
    """
    parser = argparse.ArgumentParser(
        description="Pre-generate DFlash training data (hidden states) from target model"
    )

    # Model
    parser.add_argument("--target_model_name_or_path", type=str, required=True)
    parser.add_argument("--draft_model_config_path", type=str, required=True)
    parser.add_argument(
        "--target_backend",
        type=str,
        default="hf",
        choices=["hf"],
        help="Target model backend",
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
    )
    # store_true with default=True could never be switched off.
    # BooleanOptionalAction keeps --trust_remote_code working and adds
    # --no-trust_remote_code.
    parser.add_argument(
        "--trust_remote_code",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Allow custom model code from the checkpoint to be executed",
    )

    # Data
    parser.add_argument(
        "--train_data_path", type=str, nargs="+", required=True, help="Input JSONL file(s)"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Directory to save .ckpt files"
    )
    parser.add_argument(
        "--chat_template_type",
        type=str,
        default="qwen3",
        help=f"Supported: {', '.join(get_supported_chat_template_type_strings())}",
    )
    parser.add_argument("--model_max_length", type=int, default=3072)
    parser.add_argument(
        "--block_size", type=int, default=16, help="Block size for DFlash parallel prediction"
    )
    parser.add_argument(
        "--num_proc", type=int, default=16, help="Workers for tokenization (dataset.map)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Samples per forward pass (keep at 1 for variable-length seqs)",
    )
    parser.add_argument("--shuffle_seed", type=int, default=42)
    parser.add_argument(
        "--sample_num", type=int, default=None, help="Limit number of samples (for debugging)"
    )
    parser.add_argument(
        "--shard_size",
        type=int,
        default=0,
        help="Save a new sub-directory every N files (0 = no sharding)",
    )

    return parser.parse_args()
" + f"Please set it in {args.draft_model_config_path}" + ) + rank0_print(f"DFlash target layer IDs: {target_layer_ids}") + + # -------------------------------------------------------------------------- + # 2. Load target model (on this rank's GPU) + # -------------------------------------------------------------------------- + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32} + torch_dtype = dtype_map.get(args.torch_dtype, torch.bfloat16) + + rank0_print("Loading target model...") + target_model = create_target_model( + backend=args.target_backend, + model_path=args.target_model_name_or_path, + modal_type="LLM", + torch_dtype=torch_dtype, + trust_remote_code=args.trust_remote_code, + ) + rank0_print("Target model loaded successfully") + + # -------------------------------------------------------------------------- + # 3. Tokenize dataset (using DatasetManager, same as online training) + # -------------------------------------------------------------------------- + rank0_print("Building dataset...") + # Temporarily patch args so DatasetManager picks the correct builder + args.modal_type = "LLM" # DFlash uses the LLM tokenisation path + args.training_mode = "online" # We want the text→token builder, not offline .ckpt loader + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + args.target_model_name_or_path, trust_remote_code=True + ) + + dataset_manager = DatasetManager( + data_args=args, + tokenizer=tokenizer, + model_max_length=args.model_max_length, + chat_template_type=args.chat_template_type, + ) + + # Restore modal_type to DFlash so DFlash-specific filtering (min_loss_tokens) applies + args.modal_type = "DFlash" + + ( + _, # offline_train_dataset (unused here) + _, # offline_eval_dataset + online_train_dataset, + _, # online_eval_dataset + _, # data_collator + ) = dataset_manager.create_all_datasets() + + if online_train_dataset is None: + raise RuntimeError("No training dataset 
was created. Check --train_data_path.") + + rank0_print(f"Dataset size: {len(online_train_dataset)}") + + # -------------------------------------------------------------------------- + # 4. Distributed sampler: each rank processes its own shard + # -------------------------------------------------------------------------- + sampler = DistributedSampler( + online_train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + + def collate_fn(batch): + """Simple collate: each element is already a dict of 1-D tensors.""" + result = {} + for key in batch[0]: + tensors = [item[key] for item in batch] + # Tensors may be 1-D or 2-D ([1, S]) — keep original shape + try: + result[key] = torch.stack(tensors) + except Exception: + result[key] = tensors + return result + + dataloader = DataLoader( + online_train_dataset, + batch_size=args.batch_size, + sampler=sampler, + num_workers=4, + pin_memory=True, + collate_fn=collate_fn, + ) + + # -------------------------------------------------------------------------- + # 5. Output directory setup + # -------------------------------------------------------------------------- + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + rank0_print(f"Saving .ckpt files to: {output_dir}") + rank0_print(f"World size={world_size}, this rank={rank}") + + # -------------------------------------------------------------------------- + # 6. 
Main loop: forward target model, save hidden states + # -------------------------------------------------------------------------- + global_idx = 0 # index within this rank's portion + total = len(dataloader) + t0 = time.time() + + for batch_idx, batch in enumerate(dataloader): + input_ids = batch["input_ids"] + # Shape may be [B, 1, S] or [B, S] depending on how dataset stores it + if input_ids.dim() == 3: + input_ids = input_ids.squeeze(1) # → [B, S] + + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + if attention_mask.dim() == 3: + attention_mask = attention_mask.squeeze(1) + + loss_mask = batch["loss_mask"] + if loss_mask.dim() == 3: + loss_mask = loss_mask.squeeze(1) + + input_ids = input_ids.to(f"cuda:{local_rank}") + attention_mask = attention_mask.to(f"cuda:{local_rank}") + + # Run target model + hidden_states, _ = target_model.get_hidden_states_and_logits( + input_ids=input_ids, + attention_mask=attention_mask, + aux_hidden_states_layer_ids=target_layer_ids, + ) + # hidden_states: [B, S, D*len(target_layer_ids)] + + # Save one .ckpt per sample in the batch + for i in range(input_ids.size(0)): + # Use a globally unique name: rank × position within rank + sample_idx = rank * len(dataloader) + global_idx + global_idx += 1 + + if args.shard_size > 0: + shard_id = sample_idx // args.shard_size + save_dir = output_dir / f"shard_{shard_id:05d}" + save_dir.mkdir(parents=True, exist_ok=True) + else: + save_dir = output_dir + + ckpt_path = save_dir / f"sample_{sample_idx:08d}_rank{rank}.ckpt" + + ckpt = { + "input_ids": input_ids[i : i + 1].cpu(), # [1, S] + "hidden_states": hidden_states[i : i + 1].cpu().to(torch.bfloat16), # [1, S, D*L] + "loss_mask": loss_mask[i : i + 1].cpu(), # [1, S] + "attention_mask": attention_mask[i : i + 1].cpu(), # [1, S] + } + torch.save(ckpt, ckpt_path) + + # Progress log + if batch_idx % 100 == 0: + elapsed = time.time() - t0 + samples_done = (batch_idx + 1) * args.batch_size + speed = samples_done / elapsed 
class OfflineDFlashDataset(OfflineEagle3Dataset):
    """
    DFlash variant of the offline dataset.

    Each .ckpt file must contain:
      - input_ids:      LongTensor [1, S]
      - hidden_states:  BFloat16Tensor [1, S, D*num_target_layers]  ← multi-layer hidden states
      - loss_mask:      LongTensor [1, S]
      - attention_mask: LongTensor [1, S]  (auto-generated if missing)

    Note: DFlash does NOT need target_hiddens (only Eagle3 offline uses the
    single final-layer hidden state). The multi-layer hidden_states is the
    context feature passed directly to the DFlash cross-attention.
    """

    REQUIRED_KEYS = ["input_ids", "hidden_states", "loss_mask"]

    def _load_ckpt(self, idx: int):
        """Load one pre-computed sample, or return ``None`` if it is unusable.

        A file that cannot be read, does not hold a dict, or lacks one of
        ``REQUIRED_KEYS`` is skipped with a ``RuntimeWarning``.

        Args:
            idx: Index into ``self.ckpt_files``.

        Returns:
            The sample dict (with ``attention_mask`` filled in when absent),
            or ``None`` when the file was skipped.
        """
        import warnings

        ckpt_path = self.ckpt_files[idx]
        try:
            # NOTE(review): weights_only=False unpickles arbitrary objects; only
            # load .ckpt files from trusted sources. The documented payload is a
            # plain dict of tensors, so weights_only=True may be viable.
            data = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        except Exception as e:
            warnings.warn(
                f"Failed to load {ckpt_path}: {e}. Skipping.", RuntimeWarning, stacklevel=2
            )
            return None

        # A non-dict payload would make the membership checks below meaningless.
        if not isinstance(data, dict):
            warnings.warn(
                f"{ckpt_path} holds a {type(data).__name__}, expected a dict. Skipping.",
                RuntimeWarning,
                stacklevel=2,
            )
            return None

        missing = [k for k in self.REQUIRED_KEYS if k not in data]
        if missing:
            warnings.warn(
                f"{ckpt_path} missing keys {missing}. Skipping.", RuntimeWarning, stacklevel=2
            )
            return None

        # Auto-generate attention_mask if absent
        if "attention_mask" not in data:
            data["attention_mask"] = torch.ones_like(data["input_ids"])

        return data
m.add_argument("--lm_head_key", type=str, default="lm_head.weight") + + # DFlash-specific (override values in config JSON) + d = parser.add_argument_group("DFlash Arguments") + d.add_argument("--block_size", type=int, default=None) + d.add_argument("--num_anchors", type=int, default=None) + d.add_argument("--loss_decay_gamma", type=float, default=None) + d.add_argument("--mask_token_id", type=int, default=None) + d.add_argument( + "--attention_backend", type=str, default=None, choices=["flex_attention", "sdpa", "eager"] + ) + + # Data + da = parser.add_argument_group("Data Arguments") + da.add_argument( + "--train_hidden_path", + type=str, + required=True, + help="Directory of pre-computed training .ckpt files", + ) + da.add_argument( + "--eval_hidden_path", + type=str, + default=None, + help="Directory of pre-computed eval .ckpt files (optional)", + ) + da.add_argument( + "--chat_template_type", + type=str, + default="qwen3", + help=f"Supported: {', '.join(get_supported_chat_template_type_strings())}", + ) + da.add_argument("--model_max_length", type=int, default=3072) + da.add_argument("--num_proc", type=int, default=16) + da.add_argument( + "--cache_in_memory", + action="store_true", + default=False, + help="Cache all .ckpt files in RAM (fast but memory-intensive)", + ) + + # Training + t = parser.add_argument_group("Training Arguments") + t.add_argument("--output_dir", type=str, required=True) + t.add_argument("--optim", type=str, default="adamw_torch") + t.add_argument("--num_train_epochs", type=int, default=6) + t.add_argument("--per_device_train_batch_size", type=int, default=2) + t.add_argument("--per_device_eval_batch_size", type=int, default=2) + t.add_argument("--gradient_accumulation_steps", type=int, default=1) + t.add_argument("--learning_rate", type=float, default=6e-4) + t.add_argument("--weight_decay", type=float, default=0.0) + t.add_argument("--warmup_steps", type=int, default=0) + t.add_argument("--warmup_ratio", type=float, default=0.04) + 
t.add_argument("--max_grad_norm", type=float, default=1.0) + t.add_argument("--logging_steps", type=int, default=50) + t.add_argument("--save_steps", type=float, default=2500) + t.add_argument("--save_total_limit", type=int, default=None) + t.add_argument("--eval_steps", type=int, default=500) + t.add_argument("--save_strategy", type=str, default="steps") + t.add_argument("--eval_strategy", type=str, default="no") + t.add_argument("--lr_scheduler_type", type=str, default="cosine") + t.add_argument("--fp16", action="store_true") + t.add_argument("--bf16", action="store_true") + t.add_argument("--deepspeed", type=str, default=None) + t.add_argument("--report_to", type=str, default="none") + t.add_argument("--run_name", type=str, default=None) + t.add_argument("--training_time_test_length", type=int, default=7) + + # WandB + w = parser.add_argument_group("WandB Arguments") + w.add_argument("--wandb_project", type=str, default=None) + w.add_argument("--wandb_run_name", type=str, default=None) + + return parser.parse_args() + + +def _setup_wandb(args): + if args.report_to not in ("wandb", "all"): + return + if args.wandb_project: + os.environ["WANDB_PROJECT"] = args.wandb_project + run_name = args.wandb_run_name or args.run_name or os.environ.get("WANDB_RUN_NAME") + if run_name: + os.environ["WANDB_RUN_NAME"] = run_name + args.run_name = run_name + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if local_rank == 0: + try: + import wandb + + wandb.init( + project=os.environ.get("WANDB_PROJECT", "angelslim-dflash"), + name=run_name, + resume="allow", + ) + except ImportError: + print("[WARNING] wandb not installed. 
Install via: pip install wandb") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def train(): + args = parse_args() + _setup_wandb(args) + + # dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32} + # torch_dtype = dtype_map.get(args.torch_dtype, torch.bfloat16) + + # ------------------------------------------------------------------ + # 1. Draft model config + # ------------------------------------------------------------------ + rank0_print("Loading draft model config...") + draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path) + draft_model_config.target_model_name_or_path = args.target_model_name_or_path + draft_model_config.embed_weight_key = args.embed_weight_key + draft_model_config.trust_remote_code = args.trust_remote_code + + # Override DFlash params from CLI if specified + if args.block_size is not None: + draft_model_config.block_size = args.block_size + if args.num_anchors is not None: + draft_model_config.num_anchors = args.num_anchors + if args.loss_decay_gamma is not None: + draft_model_config.loss_decay_gamma = args.loss_decay_gamma + if args.attention_backend is not None: + draft_model_config.attention_backend = args.attention_backend + draft_model_config._attn_implementation = args.attention_backend + if args.mask_token_id is not None: + dfc = getattr(draft_model_config, "dflash_config", None) or {} + dfc["mask_token_id"] = args.mask_token_id + draft_model_config.dflash_config = dfc + + # ------------------------------------------------------------------ + # 2. 
Draft model + # ------------------------------------------------------------------ + rank0_print("Loading draft model...") + draft_model = create_draft_model(draft_model_config) + rank0_print(f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}") + + # ------------------------------------------------------------------ + # 3. Offline datasets + # ------------------------------------------------------------------ + rank0_print(f"Loading offline training data from: {args.train_hidden_path}") + train_dataset = OfflineDFlashDataset( + data_dir=args.train_hidden_path, + file_pattern="*.ckpt", + cache_in_memory=args.cache_in_memory, + ) + rank0_print(f"Training samples: {len(train_dataset)}") + + eval_dataset = None + if args.eval_hidden_path: + rank0_print(f"Loading offline eval data from: {args.eval_hidden_path}") + eval_dataset = OfflineDFlashDataset( + data_dir=args.eval_hidden_path, + file_pattern="*.ckpt", + cache_in_memory=args.cache_in_memory, + ) + rank0_print(f"Eval samples: {len(eval_dataset)}") + + data_collator = DataCollatorWithPadding() + + # ------------------------------------------------------------------ + # 4. 
TrainingArguments + # ------------------------------------------------------------------ + training_args = transformers.TrainingArguments( + output_dir=args.output_dir, + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps, + warmup_ratio=args.warmup_ratio, + max_grad_norm=args.max_grad_norm, + optim=args.optim, + lr_scheduler_type=args.lr_scheduler_type, + fp16=args.fp16, + bf16=args.bf16, + eval_strategy=args.eval_strategy, + save_strategy=args.save_strategy, + save_steps=args.save_steps, + save_total_limit=args.save_total_limit, + eval_steps=args.eval_steps, + logging_steps=args.logging_steps, + report_to=args.report_to, + run_name=args.run_name, + deepspeed=args.deepspeed, + remove_unused_columns=False, + ) + + # ------------------------------------------------------------------ + # 5. Trainer -- use Eagle3TrainerFactory + # ------------------------------------------------------------------ + rank0_print("Initializing trainer...") + trainer = Eagle3TrainerFactory.create( + training_mode="offline", + modal_type="DFlash", + draft_model=draft_model, + target_model=None, # Not needed — hidden states are pre-computed + length=args.training_time_test_length, + draft_model_config=draft_model_config, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=data_collator, + ) + + # ------------------------------------------------------------------ + # 6. 
Train + # ------------------------------------------------------------------ + output_dir = Path(training_args.output_dir) + if list(output_dir.glob("checkpoint-*")): + rank0_print("Resuming training from checkpoint...") + trainer.train(resume_from_checkpoint=True) + else: + rank0_print("Starting fresh training run...") + trainer.train() + + rank0_print("Training completed!") + + +if __name__ == "__main__": + train() diff --git a/tools/train_dflash_online.py b/tools/train_dflash_online.py new file mode 100755 index 00000000..5a33a16d --- /dev/null +++ b/tools/train_dflash_online.py @@ -0,0 +1,516 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Online Training Script. + +Based on train_eagle3_online.py but adapted for DFlash's block-parallel +cross-attention training approach. 
+"""
+
+import argparse
+import os
+from pathlib import Path
+
+import torch
+import transformers
+
+from angelslim.compressor.speculative import (
+    DatasetManager,
+    DraftModelConfig,
+    Eagle3TrainerFactory,
+    create_draft_model,
+    create_target_model,
+    get_supported_chat_template_type_strings,
+)
+from angelslim.utils import rank0_print
+
+
+def parse_args():
+    """Build the CLI parser for online DFlash training and parse ``sys.argv``.
+
+    ``--block_size``, ``--num_anchors`` and ``--attention_backend`` have
+    concrete defaults and therefore always override the draft-model config;
+    ``--loss_decay_gamma`` and ``--mask_token_id`` only do so when given.
+
+    Returns:
+        argparse.Namespace: The parsed options.
+    """
+    parser = argparse.ArgumentParser(description="Train DFlash online model")
+
+    # Model arguments
+    model_group = parser.add_argument_group("Model Arguments")
+    model_group.add_argument(
+        "--modal_type",
+        type=str,
+        default="DFlash",
+        help="Modal type, should be DFlash for DFlash training",
+    )
+    model_group.add_argument(
+        "--training_mode",
+        type=str,
+        default="online",
+        choices=["online"],
+        help="Training mode (only online is supported for DFlash)",
+    )
+    model_group.add_argument(
+        "--target_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to target model",
+    )
+    model_group.add_argument(
+        "--draft_model_config_path",
+        type=str,
+        default=None,
+        help="Path to draft model config",
+    )
+    model_group.add_argument(
+        "--target_backend",
+        type=str,
+        default="hf",
+        choices=["hf"],
+        help="Target model backend: hf (HuggingFace Transformers)",
+    )
+    model_group.add_argument(
+        "--torch_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["float16", "bfloat16", "float32"],
+        help="Data type for model weights",
+    )
+    # store_true with default=True can never produce False; the paired
+    # --no_trust_remote_code switch is the only way to turn it off.
+    model_group.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        default=True,
+        help="Whether to trust remote code when loading models",
+    )
+    model_group.add_argument(
+        "--no_trust_remote_code",
+        dest="trust_remote_code",
+        action="store_false",
+        help="Do not trust remote code when loading models",
+    )
+    model_group.add_argument(
+        "--embed_weight_key",
+        type=str,
+        default="model.embed_tokens.weight",
+        help="Key for embedding weights in model config",
+    )
+
+    # DFlash-specific arguments
+    dflash_group = parser.add_argument_group("DFlash Arguments")
+    dflash_group.add_argument(
+        "--block_size",
+        type=int,
+        default=16,
+        help="Block size for DFlash parallel prediction",
+    )
+    dflash_group.add_argument(
+        "--num_anchors",
+        type=int,
+        default=512,
+        help="Number of anchor positions per sequence",
+    )
+    dflash_group.add_argument(
+        "--loss_decay_gamma",
+        type=float,
+        default=None,
+        help=(
+            "Gamma for exponential loss decay weighting. "
+            "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. "
+            "None disables decay."
+        ),
+    )
+    dflash_group.add_argument(
+        "--attention_backend",
+        type=str,
+        default="flex_attention",
+        choices=["eager", "sdpa", "flex_attention"],
+        help="Attention backend for draft model",
+    )
+    dflash_group.add_argument(
+        "--mask_token_id",
+        type=int,
+        default=None,
+        help="MASK token ID. If not provided, uses config or auto-detect.",
+    )
+
+    # Data arguments
+    data_group = parser.add_argument_group("Data Arguments")
+    data_group.add_argument(
+        "--train_data_path",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Path to training data file(s) (JSON format). Can specify multiple files.",
+    )
+    data_group.add_argument(
+        "--eval_data_path",
+        type=str,
+        default=None,
+        help="Path to evaluation data file (JSON format)",
+    )
+    data_group.add_argument(
+        "--chat_template_type",
+        type=str,
+        default="qwen3",
+        help=(
+            f"Chat template type for conversation formatting. "
+            f"Supported types: {', '.join(get_supported_chat_template_type_strings())}"
+        ),
+    )
+    data_group.add_argument(
+        "--num_proc",
+        type=int,
+        default=16,
+        help="Number of processes for data preprocessing",
+    )
+    data_group.add_argument(
+        "--sample_num",
+        type=int,
+        default=None,
+        help="Number of max samples for data preprocessing",
+    )
+    data_group.add_argument(
+        "--shuffle_seed", type=int, default=42, help="Random seed for shuffling dataset"
+    )
+    data_group.add_argument(
+        "--display",
+        action="store_true",
+        default=False,
+        help="Display data samples during preprocessing",
+    )
+
+    # Training arguments
+    training_group = parser.add_argument_group("Training Arguments")
+    training_group.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Output directory for model checkpoints",
+    )
+    training_group.add_argument(
+        "--optim", type=str, default="adamw_torch", help="Optimizer to use"
+    )
+    training_group.add_argument(
+        "--training_time_test_length",
+        type=int,
+        default=1,
+        help="Not used for DFlash (kept for compatibility)",
+    )
+    training_group.add_argument(
+        "--model_max_length",
+        type=int,
+        default=3072,
+        help="Maximum sequence length",
+    )
+    training_group.add_argument(
+        "--per_device_train_batch_size",
+        type=int,
+        default=2,
+        help="Batch size per device during training",
+    )
+    training_group.add_argument(
+        "--per_device_eval_batch_size",
+        type=int,
+        default=2,
+        help="Batch size per device during evaluation",
+    )
+    training_group.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass",
+    )
+    training_group.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=6,
+        help="Total number of training epochs to perform",
+    )
+    training_group.add_argument(
+        "--learning_rate", type=float, default=6e-4, help="Initial learning rate"
+    )
+    training_group.add_argument(
+        "--weight_decay", type=float, default=0.0, help="Weight decay to apply"
+    )
+    training_group.add_argument(
+        "--max_grad_norm",
+        type=float,
+        default=1.0,
+        help="Maximum gradient norm for clipping",
+    )
+    training_group.add_argument(
+        "--warmup_steps", type=int, default=0, help="Number of steps for warmup"
+    )
+    training_group.add_argument(
+        "--warmup_ratio", type=float, default=0.04, help="Ratio of warmup steps"
+    )
+    training_group.add_argument(
+        "--logging_steps", type=int, default=50, help="Log every X updates steps"
+    )
+    training_group.add_argument(
+        "--save_steps",
+        type=float,
+        default=5000,
+        help="Save checkpoint every X updates steps",
+    )
+    training_group.add_argument(
+        "--eval_steps", type=int, default=1000, help="Run evaluation every X steps"
+    )
+    training_group.add_argument(
+        "--save_total_limit",
+        type=int,
+        default=None,
+        help="Limit the total amount of checkpoints",
+    )
+    training_group.add_argument(
+        "--deepspeed", type=str, default=None, help="DeepSpeed config file"
+    )
+    training_group.add_argument("--fp16", action="store_true", help="Whether to use fp16 training")
+    training_group.add_argument("--bf16", action="store_true", help="Whether to use bf16 training")
+    training_group.add_argument(
+        "--save_strategy", type=str, default="no", help="Save strategy for checkpoints"
+    )
+    training_group.add_argument(
+        "--eval_strategy", type=str, default="no", help="Evaluation strategy"
+    )
+    training_group.add_argument(
+        "--lr_scheduler_type",
+        type=str,
+        default="cosine",
+        help="Learning rate scheduler type",
+    )
+    training_group.add_argument("--run_name", type=str, default=None, help="Run name for tracking")
+    training_group.add_argument(
+        "--report_to",
+        type=str,
+        default="none",
+        help="The list of integrations to report the results and logs to (e.g. 'wandb')",
+    )
+
+    # WandB arguments (mirrors SpecForge's --wandb-project / --wandb-name)
+    wandb_group = parser.add_argument_group("WandB Arguments")
+    wandb_group.add_argument(
+        "--wandb_project",
+        type=str,
+        default=None,
+        help="WandB project name. Overrides WANDB_PROJECT env var if set.",
+    )
+    wandb_group.add_argument(
+        "--wandb_run_name",
+        type=str,
+        default=None,
+        help="WandB run name. Overrides --run_name if both are set.",
+    )
+
+    return parser.parse_args()
+
+
+def _setup_wandb(args) -> None:
+    """Set up WandB environment variables and initialize wandb run on global rank 0.
+
+    Mirrors the --wandb-project / --wandb-name pattern from SpecForge.
+    Priority: CLI args > env vars > defaults.
+    """
+    if args.report_to not in ("wandb", "all"):
+        return
+
+    # CLI args take priority over env vars
+    if args.wandb_project:
+        os.environ["WANDB_PROJECT"] = args.wandb_project
+
+    # Resolve run name: --wandb_run_name > --run_name > env WANDB_RUN_NAME
+    run_name = args.wandb_run_name or args.run_name or os.environ.get("WANDB_RUN_NAME")
+    if run_name:
+        os.environ["WANDB_RUN_NAME"] = run_name
+        # Propagate back so TrainingArguments picks it up
+        args.run_name = run_name
+
+    # Explicit wandb.init() so project/name are registered immediately. Use the
+    # global RANK when the launcher exports it: LOCAL_RANK is 0 once per node,
+    # so keying on it alone would open one WandB run per node.
+    global_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
+    if global_rank == 0:
+        try:
+            import wandb
+
+            wandb.init(
+                project=os.environ.get("WANDB_PROJECT", "angelslim-dflash"),
+                name=run_name,
+                resume="allow",
+            )
+        except ImportError:
+            print("[WARNING] wandb not installed. Install via: pip install wandb")
+
+
+def train():
+    """Run online DFlash training end to end.
+
+    Loads the draft-model config (applying CLI overrides), creates the target
+    and draft models, builds the datasets through ``DatasetManager``, then
+    trains via ``Eagle3TrainerFactory``. Training resumes automatically when
+    ``output_dir`` already contains ``checkpoint-*`` entries.
+    """
+    args = parse_args()
+    _setup_wandb(args)
+
+    # Parse torch dtype
+    dtype_mapping = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+    }
+    torch_dtype = dtype_mapping.get(args.torch_dtype, torch.bfloat16)
+
+    rank0_print("Loading draft model config...")
+    draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path)
+    target_model_type = getattr(draft_model_config, "target_model_type", None)
+
+    # Inject DFlash-specific config into the draft model config
+    # so the trainer can access them
+    draft_model_config.target_model_name_or_path = args.target_model_name_or_path
+    draft_model_config.embed_weight_key = args.embed_weight_key
+    draft_model_config.trust_remote_code = args.trust_remote_code
+
+    # Override DFlash params from CLI if specified
+    if args.block_size is not None:
+        draft_model_config.block_size = args.block_size
+    if args.num_anchors is not None:
+        draft_model_config.num_anchors = args.num_anchors
+    if args.loss_decay_gamma is not None:
+        draft_model_config.loss_decay_gamma = args.loss_decay_gamma
+    if args.attention_backend is not None:
+        draft_model_config.attention_backend = args.attention_backend
+    if args.mask_token_id is not None:
+        if (
+            not hasattr(draft_model_config, "dflash_config")
+            or draft_model_config.dflash_config is None
+        ):
+            draft_model_config.dflash_config = {}
+        draft_model_config.dflash_config["mask_token_id"] = args.mask_token_id
+
+    # Set attention implementation
+    draft_model_config._attn_implementation = args.attention_backend
+
+    # Create target model with specified backend
+    rank0_print(f"Loading target model with {args.target_backend} backend...")
+    target_model = create_target_model(
+        backend=args.target_backend,
+        model_path=args.target_model_name_or_path,
+        modal_type="LLM",  # DFlash uses standard LLM target model
+        torch_dtype=torch_dtype,
+        trust_remote_code=args.trust_remote_code,
+        target_model_type=target_model_type,
+    )
+    rank0_print("Target model loaded successfully")
+
+    # Configure target model to capture the right layers for DFlash
+    dflash_config = getattr(draft_model_config, "dflash_config", {}) or {}
+    target_layer_ids = dflash_config.get("target_layer_ids", None)
+    if target_layer_ids is not None:
+        # Set aux_hidden_states_layer_ids to match DFlash's target_layer_ids
+        draft_model_config.aux_hidden_states_layer_ids = target_layer_ids
+        rank0_print(f"DFlash target layer IDs: {target_layer_ids}")
+
+    # Create draft model
+    rank0_print("Loading draft model...")
+    rank0_print(f"draft_model_config: {draft_model_config}")
+    draft_model = create_draft_model(draft_model_config)
+    rank0_print("Draft model loaded successfully")
+    rank0_print(f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}")
+
+    # Create datasets using DatasetManager
+    rank0_print(
+        "Creating training and evaluation datasets "
+        f"with chat template type: {args.chat_template_type}..."
+    )
+    # DatasetBuilderFactory doesn't know "DFlash"; DFlash uses the same data
+    # format as "LLM", so temporarily override modal_type for dataset creation.
+    args.modal_type = "LLM"
+    dataset_manager = DatasetManager(
+        data_args=args,
+        tokenizer=target_model.tokenizer,
+        model_max_length=args.model_max_length,
+        chat_template_type=args.chat_template_type,
+        display=args.display,
+        target_model_type=target_model_type,
+    )
+    # Restore before create_online_datasets(): the online dataset creation in
+    # dataset.py checks data_args.modal_type == "DFlash" to enable
+    # min_loss_tokens filtering, and the trainer factory below needs it too.
+    args.modal_type = "DFlash"
+    train_dataset, eval_dataset, data_collator = dataset_manager.create_online_datasets()
+    rank0_print(
+        f"Train dataset size: {len(train_dataset)}, "
+        f"Eval dataset size: {len(eval_dataset) if eval_dataset else 0}"
+    )
+
+    # Create TrainingArguments
+    training_args = transformers.TrainingArguments(
+        output_dir=args.output_dir,
+        num_train_epochs=args.num_train_epochs,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        remove_unused_columns=False,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+        warmup_steps=args.warmup_steps,
+        warmup_ratio=args.warmup_ratio,
+        optim=args.optim,
+        lr_scheduler_type=args.lr_scheduler_type,
+        max_grad_norm=args.max_grad_norm,
+        fp16=args.fp16,
+        bf16=args.bf16,
+        eval_strategy=args.eval_strategy,
+        save_strategy=args.save_strategy,
+        save_steps=args.save_steps,
+        save_total_limit=args.save_total_limit,
+        logging_steps=args.logging_steps,
+        eval_steps=args.eval_steps,
+        report_to=args.report_to,
+        run_name=args.run_name,
+        deepspeed=args.deepspeed,
+    )
+
+    # Initialize trainer
+    rank0_print("Initializing DFlash trainer...")
+    trainer = Eagle3TrainerFactory.create(
+        training_mode=args.training_mode,
+        modal_type=args.modal_type,
+        draft_model=draft_model,
+        target_model=target_model,
+        length=args.training_time_test_length,
+        draft_model_config=draft_model_config,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=data_collator,
+    )
+
+    # Start training
+    if any(Path(training_args.output_dir).glob("checkpoint-*")):
+        rank0_print("Resuming training from checkpoint...")
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        rank0_print("Starting DFlash training...")
+        trainer.train()
+    rank0_print("DFlash training completed!")
+
+
+if __name__ == "__main__":
+    train()