diff --git a/README.md b/README.md index de3c3f1e..7d952b65 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress

## 📣Latest News +- [26/06/01] We have released **DFlare**, a block-diffusion speculative decoding framework with layer-wise fusion that achieves up to **5.52× end-to-end speedup**. [[Docs]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/dflare.html) - [26/05/27] We have released **D-Cut**, an adaptive verification depth pruning technique for speculative decoding. [[Docs]](https://angelslim.readthedocs.io/zh-cn/latest/dcut.html) - [26/05/20] We support Distillation for full-precision HuggingFace models and **quantized QAT-style** models, as detailed in the [distillation documentation](https://angelslim.readthedocs.io/zh-cn/latest/features/distill/index.html). - [26/05/08] We have released STQ1_0 kernel for 1.25-bit model and given a PR to llama.cpp [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836) ! If you have any questions or suggestions for STQ_0, welcome to comment under the PR !🔥🔥🔥 @@ -92,6 +93,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress diff --git a/README_cn.md b/README_cn.md index 079ab305..ec72132c 100644 --- a/README_cn.md +++ b/README_cn.md @@ -22,6 +22,7 @@

## 📣最新进展 +- [26/06/01] 我们发布了 **DFlare**,一种基于 layer-wise fusion 的块扩散投机解码框架,端到端加速比可达 **5.52×**。[[文档]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/dflare.html) - [26/05/27] 我们发布了 **D-Cut**,一种用于投机解码的自适应验证深度裁剪技术。[[文档]](https://angelslim.readthedocs.io/zh-cn/latest/dcut.html) - [26/05/20] 我们支持了模型蒸馏功能,适用于huggingface 全精度或者**QAT量化**模型,详细步骤可以参考[文档](https://angelslim.readthedocs.io/zh-cn/latest/features/distill/index.html).🔥🔥🔥 - [26/05/08] 我们发布了用于 1.25-bit 模型的 STQ1_0 内核,并向 llama.cpp 提交了 [PR #22836](https://github.com/ggml-org/llama.cpp/pull/22836)!如果您对 STQ_0 有任何疑问或建议,欢迎在该 PR 下留言!🔥🔥🔥 @@ -93,6 +94,7 @@ diff --git a/angelslim/compressor/speculative/train/data/data_utils.py b/angelslim/compressor/speculative/train/data/data_utils.py index a452d5b3..583588cd 100644 --- a/angelslim/compressor/speculative/train/data/data_utils.py +++ b/angelslim/compressor/speculative/train/data/data_utils.py @@ -91,7 +91,6 @@ def convert_ultrachat_data(row, dataset_column="messages"): return {"conversations": converted_messages, "id": row["prompt_id"]} -# Copied from https://github.com/sgl-project/SpecForge/blob/main/specforge/data/preprocessing.py # noqa: E501 def process_token_dict_to_mappings( token_dict, draft_vocab_size: int, diff --git a/angelslim/compressor/speculative/train/models/draft/__init__.py b/angelslim/compressor/speculative/train/models/draft/__init__.py index c056ce23..58940d9f 100644 --- a/angelslim/compressor/speculative/train/models/draft/__init__.py +++ b/angelslim/compressor/speculative/train/models/draft/__init__.py @@ -14,6 +14,7 @@ from .draft_model_factory import DraftModelConfig, create_draft_model from .llama_eagle3 import CosyVoice3Eagle3LlamaForCausalLM, Eagle3LlamaForCausalLM +from .qwen_dflare import QwenDFlareDraftModel from .qwen_dflash import QwenDFlashDraftModel __all__ = [ @@ -22,4 +23,5 @@ "Eagle3LlamaForCausalLM", "CosyVoice3Eagle3LlamaForCausalLM", "QwenDFlashDraftModel", + "QwenDFlareDraftModel", ] diff --git 
a/angelslim/compressor/speculative/train/models/draft/qwen_dflare.py b/angelslim/compressor/speculative/train/models/draft/qwen_dflare.py new file mode 100644 index 00000000..51b32912 --- /dev/null +++ b/angelslim/compressor/speculative/train/models/draft/qwen_dflare.py @@ -0,0 +1,436 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlare Draft Model for Qwen3 architecture. + +AngelSlim DFlare is an enhanced DFlash variant with two structural changes +compared to ``QwenDFlashDraftModel``: + +1. Cross-attention uses **separate** k/v projections for context (target hidden + states) vs. noise (draft tokens): ``k_proj_target/v_proj_target`` for + context, ``k_proj/v_proj`` for noise. +2. Multi-layer target hidden states are fused via a learnable + ``layer_fusion_weights[D, T]`` matrix (softmax-normalised, einsum'd) instead + of a single ``Linear(T*H -> H)`` projection. Each draft layer learns its own + weighted combination of the T captured target layers. + +Training-side logic (anchor sampling, BlockMask, weighted CE loss, accuracy) +is **identical** to DFlash, so this model is consumed by the same +``OnlineDFlashTrainer``. 
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Pick one token id per position from ``logits``.

    Args:
        logits: Scores whose last dimension is the vocabulary. The sampling
            branch unpacks exactly three dimensions (batch, sequence, vocab).
        temperature: Values below ``1e-5`` select greedy decoding; otherwise
            the logits are divided by it before sampling.

    Returns:
        Token ids with the vocabulary dimension removed.
    """
    if temperature < 1e-5:
        # Greedy path: no randomness involved.
        return torch.argmax(logits, dim=-1)
    batch, steps, vocab = logits.shape
    scaled = logits.view(-1, vocab) / temperature
    draws = torch.multinomial(torch.softmax(scaled, dim=-1), num_samples=1)
    return draws.view(batch, steps)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to queries and keys.

    The keys use the full ``cos`` / ``sin`` tables, while the queries use only
    the trailing ``q.size(-2)`` positions, because the queries are shorter
    than the keys whenever context is prepended to the key sequence.

    ``position_ids`` is accepted for signature compatibility and is not read.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    cos_full = cos.unsqueeze(unsqueeze_dim)
    sin_full = sin.unsqueeze(unsqueeze_dim)
    n_query = q.size(-2)
    cos_tail = cos_full[..., -n_query:, :]
    sin_tail = sin_full[..., -n_query:, :]
    q_rotated = (q * cos_tail) + (rotate_half(q) * sin_tail)
    k_rotated = (k * cos_full) + (rotate_half(k) * sin_full)
    return q_rotated, k_rotated


def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]:
    """Choose which target-model layers to capture hidden states from.

    With a single draft layer the middle target layer is used. Otherwise the
    ids are spread evenly (rounded to integers) over the closed interval
    ``[1, num_target_layers - 3]``, one id per draft layer.
    """
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    first = 1
    last = num_target_layers - 3
    width = last - first
    intervals = num_draft_layers - 1
    layer_ids = []
    for idx in range(num_draft_layers):
        layer_ids.append(int(round(first + (idx * width) / intervals)))
    return layer_ids


def extract_context_feature(
    hidden_states: list,
    layer_ids: Optional[List[int]],
) -> torch.Tensor:
    """Gather the selected layers' hidden states and join them feature-wise.

    Each id is shifted by one before indexing ``hidden_states``
    (presumably because index 0 holds the embedding output — TODO confirm
    against the target model's ``output_hidden_states`` layout).

    Returns:
        The selected tensors concatenated along the last dimension.
    """
    shift = 1
    picked = [hidden_states[layer_id + shift] for layer_id in layer_ids]
    return torch.cat(picked, dim=-1)
class Qwen3DFlareAttention(nn.Module):
    """Cross-attention used inside each DFlare draft layer.

    Queries are projected from the draft (noise) hidden states. Keys and
    values are the sequence-wise concatenation ``[context ; noise]``: the
    context part is projected from the target hidden states through
    ``k_proj_target`` / ``v_proj_target``, the noise part from the draft
    hidden states through ``k_proj`` / ``v_proj``.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # The fallback is computed unconditionally, exactly as a positional
        # ``getattr`` default would be.
        derived_head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", derived_head_dim)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        model_width = config.hidden_size
        with_bias = config.attention_bias
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim

        # Keep this creation order: it fixes both the parameter order and the
        # order in which random initial weights are drawn.
        self.q_proj = nn.Linear(model_width, query_width, bias=with_bias)
        self.k_proj = nn.Linear(model_width, kv_width, bias=with_bias)
        self.k_proj_target = nn.Linear(model_width, kv_width, bias=with_bias)
        self.v_proj = nn.Linear(model_width, kv_width, bias=with_bias)
        self.v_proj_target = nn.Linear(model_width, kv_width, bias=with_bias)
        self.o_proj = nn.Linear(query_width, model_width, bias=with_bias)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        if config.layer_types[layer_idx] == "sliding_attention":
            self.sliding_window = config.sliding_window
        else:
            self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_hidden: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from draft tokens to ``[context ; draft]`` keys and values.

        Args:
            hidden_states: Draft (noise) hidden states; every dimension except
                the last is read as ``(batch, q_len)``.
            target_hidden: Context hidden states; dimension 1 is the context
                length.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: Mask handed unchanged to the attention kernel.
            past_key_values: Optional cache updated with this layer's keys and
                values.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the selected
            attention kernel, with the output passed through ``o_proj``.
        """
        bsz, q_len = hidden_states.shape[:-1]
        ctx_len = target_hidden.shape[1]
        kv_len = ctx_len + q_len

        queries = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
        queries = self.q_norm(queries).transpose(1, 2)

        # Context and noise use separate projections, then are joined along
        # the sequence dimension with the context first.
        ctx_keys = self.k_proj_target(target_hidden)
        noise_keys = self.k_proj(hidden_states)
        ctx_values = self.v_proj_target(target_hidden)
        noise_values = self.v_proj(hidden_states)

        keys = torch.cat([ctx_keys, noise_keys], dim=1)
        keys = keys.view(bsz, kv_len, -1, self.head_dim)
        values = torch.cat([ctx_values, noise_values], dim=1)
        values = values.view(bsz, kv_len, -1, self.head_dim)

        # Only the keys are normalised; the values are just re-laid-out.
        keys = self.k_norm(keys).transpose(1, 2)
        values = values.transpose(1, 2)

        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            keys, values = past_key_values.update(keys, values, self.layer_idx, cache_kwargs)

        backend = self.config._attn_implementation
        attn_fn: Callable = (
            eager_attention_forward if backend == "eager" else ALL_ATTENTION_FUNCTIONS[backend]
        )

        attn_output, attn_weights = attn_fn(
            self,
            queries,
            keys,
            values,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        merged = attn_output.reshape(bsz, q_len, -1)
        return self.o_proj(merged), attn_weights
class Qwen3DFlareDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm DFlare decoder layer: cross-attention to context, then an MLP."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3DFlareAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        target_hidden: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        """Run one residual attention block followed by one residual MLP block.

        Args:
            target_hidden: Context features handed to the attention module.
            hidden_states: Draft hidden states; also the residual stream.
            attention_mask, position_ids, output_attentions, use_cache,
            cache_position, position_embeddings, **kwargs: Forwarded to the
                attention module unchanged.
            past_key_value: Cache, forwarded under the attention module's
                ``past_key_values`` keyword.

        Returns:
            The updated hidden states.
        """
        normed_input = self.input_layernorm(hidden_states)
        # The attention module returns (output, weights); only the output is used.
        attn_output = self.self_attn(
            hidden_states=normed_input,
            target_hidden=target_hidden,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )[0]
        after_attention = hidden_states + attn_output

        mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + mlp_output
@DraftModelFactory.register
class QwenDFlareDraftModel(Qwen3PreTrainedModel):
    """DFlare Draft Model for Qwen3 architecture.

    Same input/output contract as ``QwenDFlashDraftModel`` (consumed by the
    same ``OnlineDFlashTrainer``), with two structural improvements:
      * separate context/noise k,v projections inside cross-attention;
      * learnable per-draft-layer fusion weights over target layers.
    """

    config_class = Qwen3Config
    _no_split_modules = ["Qwen3DFlareDecoderLayer"]

    def __init__(self, config) -> None:
        """Build the draft layers, fusion weights and norms from ``config``.

        ``config.num_target_layers`` is only required when
        ``dflash_config["target_layer_ids"]`` is absent (or ``None``); in that
        case the ids are derived with ``build_target_layer_ids``.
        """
        super().__init__(config)
        self.config = config
        self.layers = nn.ModuleList(
            [
                Qwen3DFlareDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.num_draft_layers = config.num_hidden_layers
        # We intentionally read from ``dflash_config`` to remain compatible with
        # the existing trainer, which extracts ``mask_token_id`` etc. from the
        # same key.
        dflash_config = getattr(config, "dflash_config", {}) or {}
        # Resolve the fallback lazily: a ``dict.get`` default is evaluated even
        # when the key exists, which would demand ``config.num_target_layers``
        # from configs that already list their target layers explicitly.
        target_layer_ids = dflash_config.get("target_layer_ids")
        if target_layer_ids is None:
            target_layer_ids = build_target_layer_ids(
                config.num_target_layers, config.num_hidden_layers
            )
        self.target_layer_ids = target_layer_ids
        # Number of *captured* target layers (T), not the target model depth.
        self.num_target_layers = len(self.target_layer_ids)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config)
        self.layer_fusion_weights = nn.Parameter(
            torch.empty(self.num_draft_layers, self.num_target_layers)
        )
        self._init_fusion_weights()
        self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.block_size = config.block_size
        self.mask_token_id = dflash_config.get("mask_token_id", None)
        self.post_init()

    def _init_fusion_weights(self) -> None:
        """Initialise the fusion logits: zero everywhere, 2.0 on one target per draft layer.

        Draft layer ``d`` is biased towards captured target layer
        ``min(T - 1, int(d / D * T))``, so shallow draft layers start out
        favouring shallow target layers and deep ones favour deep layers.
        """
        nn.init.constant_(self.layer_fusion_weights, 0.0)
        for d_idx in range(self.num_draft_layers):
            t_idx = min(
                self.num_target_layers - 1,
                int((d_idx / self.num_draft_layers) * self.num_target_layers),
            )
            self.layer_fusion_weights.data[d_idx, t_idx] = 2.0

    def forward(
        self,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        noise_embedding: Optional[torch.Tensor] = None,
        target_hidden: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """Run the draft layers over ``noise_embedding`` with fused target context.

        Args:
            position_ids: Positions used to build the rotary embeddings.
            attention_mask: Mask forwarded to every layer.
            noise_embedding: Embeddings of the draft block; the initial
                hidden states.
            target_hidden: Captured target hidden states concatenated along
                the feature dimension, ``[B, S, T * H]``.
            past_key_values: Optional draft-side cache shared by all layers.
            use_cache: Forwarded to every layer.
            **kwargs: Forwarded to every layer.

        Returns:
            Final-norm hidden states of the draft block.
        """
        hidden_states = noise_embedding
        bsz, seq_len, _ = target_hidden.shape
        # target_hidden arrives as concatenation along feature dim of T target
        # layers' hidden states: [B, S, T*H]. Reshape to per-layer tensor.
        target_hidden_reshaped = target_hidden.view(
            bsz, seq_len, self.num_target_layers, self.config.hidden_size
        )
        # Each row of the fusion matrix becomes a distribution over target layers.
        fusion_probs = torch.softmax(self.layer_fusion_weights, dim=1)
        # bsth (target) x dt (per-draft-layer fusion) -> bsdh (per-draft-layer)
        fused_hidden = torch.einsum("bsth,dt->bsdh", target_hidden_reshaped, fusion_probs)
        fused_hidden = self.hidden_norm(fused_hidden)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for i, layer in enumerate(self.layers):
            # Draft layer i attends to its own weighted mix of target layers.
            layer_target_hidden = fused_hidden[:, :, i, :]
            hidden_states = layer(
                hidden_states=hidden_states,
                target_hidden=layer_target_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        return self.norm(hidden_states)

    @torch.inference_mode()
    def spec_generate(
        self,
        target: nn.Module,
        input_ids: torch.LongTensor,
        max_new_tokens: int,
        stop_token_ids: List[int],
        temperature: float,
    ):
        """Speculative generation with DFlare draft model.

        The output buffer has a single row, so this handles one sequence at a
        time.

        Args:
            target: Target causal LM; its ``model.embed_tokens`` and
                ``lm_head`` are reused for the draft model.
            input_ids: Prompt token ids.
            max_new_tokens: Upper bound on generated tokens.
            stop_token_ids: Ids that end generation, or ``None``.
            temperature: Sampling temperature for the target's tokens (draft
                proposals are always greedy).

        Returns:
            Prompt plus generated ids, truncated after the first stop token
            (when one was produced) and with mask tokens removed.

        Raises:
            ValueError: If the model has no ``mask_token_id`` configured.
        """
        self.eval()
        if self.mask_token_id is None:
            # Fail early with an actionable message instead of an opaque
            # error from ``torch.full`` below.
            raise ValueError(
                "spec_generate requires `mask_token_id` to be set in `dflash_config`."
            )
        num_input_tokens = input_ids.shape[1]
        max_length = num_input_tokens + max_new_tokens

        block_size = self.block_size
        # One extra block of room so the last speculative block never
        # indexes past the end of the buffer.
        output_ids = torch.full(
            (1, max_length + block_size),
            self.mask_token_id,
            dtype=torch.long,
            device=target.device,
        )
        position_ids = torch.arange(output_ids.shape[1], device=target.device).unsqueeze(0)

        past_key_values_target = DynamicCache()
        past_key_values_draft = DynamicCache()

        # Prefill stage
        output = target(
            input_ids,
            position_ids=position_ids[:, :num_input_tokens],
            past_key_values=past_key_values_target,
            use_cache=True,
            logits_to_keep=1,
            output_hidden_states=True,
        )

        output_ids[:, :num_input_tokens] = input_ids
        output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(output.logits, temperature)
        target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids)

        # Decode stage
        start = input_ids.shape[1]
        while start < max_length:
            # Block = one already-decided token followed by mask tokens.
            block_output_ids = output_ids[:, start : start + block_size].clone()
            block_position_ids = position_ids[:, start : start + block_size]
            noise_embedding = target.model.embed_tokens(block_output_ids)
            # The draft fills in the block_size - 1 masked positions.
            draft_logits = target.lm_head(
                self(
                    target_hidden=target_hidden,
                    noise_embedding=noise_embedding,
                    position_ids=position_ids[
                        :, past_key_values_draft.get_seq_length() : start + block_size
                    ],
                    past_key_values=past_key_values_draft,
                    use_cache=True,
                    is_causal=False,
                )[:, -block_size + 1 :, :]
            )
            # Drop the speculative block from the draft cache again.
            past_key_values_draft.crop(start)
            block_output_ids[:, 1:] = sample(draft_logits)

            # Verify the whole block with a single target forward pass.
            output = target(
                block_output_ids,
                position_ids=block_position_ids,
                past_key_values=past_key_values_target,
                use_cache=True,
                output_hidden_states=True,
            )

            posterior = sample(output.logits, temperature)
            # Number of leading draft tokens that match the target's choice.
            acceptance_length = (
                (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
            )
            output_ids[:, start : start + acceptance_length + 1] = block_output_ids[
                :, : acceptance_length + 1
            ]
            # The target's token after the accepted prefix seeds the next block.
            output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]
            start += acceptance_length + 1
            past_key_values_target.crop(start)
            target_hidden = extract_context_feature(output.hidden_states, self.target_layer_ids)[
                :, : acceptance_length + 1, :
            ]
            if stop_token_ids is not None and any(
                stop_token_id in output_ids[:, num_input_tokens:]
                for stop_token_id in stop_token_ids
            ):
                break
        output_ids = output_ids[:, :max_length]
        output_ids = output_ids[:, output_ids[0] != self.mask_token_id]
        if stop_token_ids is not None:
            stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
            stop_token_indices = torch.isin(
                output_ids[0][num_input_tokens:], stop_token_ids
            ).nonzero(as_tuple=True)[0]
            if stop_token_indices.numel() > 0:
                output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1]

        return output_ids
Actual attn_implementation: {_actual_attn}") + # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) @@ -670,8 +678,15 @@ def _prepare_model_kwargs(self, device: str) -> dict: "torch_dtype": torch.bfloat16, "device_map": device, "trust_remote_code": True, + # AngelSlim DFlash uses flash_attention_2 by default to match + # the inference kernel and avoid train/test mismatch. + "attn_implementation": "flash_attention_2", } - default_kwargs.update(self.kwargs) + # Only pass through kwargs that are valid for from_pretrained; + # filter out non-model kwargs like modal_type, target_model_type, etc. + _non_model_keys = {"modal_type", "target_model_type"} + filtered = {k: v for k, v in self.kwargs.items() if k not in _non_model_keys} + default_kwargs.update(filtered) return default_kwargs def _freeze_model_parameters(self) -> None: @@ -688,29 +703,68 @@ def get_hidden_states_and_logits( """ Extract hidden states and logits using Transformers backend. + Processes each sample INDIVIDUALLY (without padding). Batch processing + with padding causes numerical differences in hidden states even with + flash_attention_2, because padding tokens still participate in + attention and leak into other positions' representations. + Args: - input_ids: Input token IDs - attention_mask: Attention mask + input_ids: Input token IDs [batch_size, seq_len] + attention_mask: Attention mask [batch_size, seq_len] **kwargs: May contain 'aux_hidden_states_layer_ids' to specify custom layers Returns: - Tuple of (concatenated_hidden_states, logits) + Tuple of (concatenated_hidden_states, logits) padded back to + [batch_size, seq_len, ...] so downstream collator/trainer code is + unchanged. Padding positions in the returned tensors are zero; + they are masked out by ``loss_mask`` later. 
""" - with torch.no_grad(): - outputs = self.model( - input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - output_logits=True, - use_cache=False, # match SpecForge: no KV-cache during training - ) - - # Extract auxiliary hidden states aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None) - hidden_states = self._extract_auxiliary_hidden_states(outputs.hidden_states, aux_layer_ids) + bsz, seq_len = input_ids.shape - # Return hidden states and logits on the same device as input - return hidden_states, outputs.logits.to(input_ids.device) + # Process each sample individually to avoid padding-induced hidden + # state corruption (AngelSlim DFlash per-sample processing). + hidden_list = [] + logits_list = [] + + for i in range(bsz): + # Determine actual (non-padded) length for this sample + if attention_mask is not None: + actual_len = int(attention_mask[i].sum().item()) + else: + actual_len = seq_len + + # Extract the unpadded portion + single_ids = input_ids[i : i + 1, :actual_len] + + with torch.no_grad(): + outputs = self.model( + single_ids, + output_hidden_states=True, + use_cache=False, # AngelSlim DFlash: no KV-cache during training + ) + + # Extract auxiliary hidden states for this sample + # h shape: [1, actual_len, D*num_layers] + h = self._extract_auxiliary_hidden_states(outputs.hidden_states, aux_layer_ids) + + # Pad back to seq_len to maintain batch shape + if actual_len < seq_len: + pad_size = seq_len - actual_len + # Pad seq dim only (last-but-one dim); hidden dim is untouched + h = torch.nn.functional.pad(h, (0, 0, 0, pad_size)) + logits_padded = torch.nn.functional.pad(outputs.logits, (0, 0, 0, pad_size)) + else: + logits_padded = outputs.logits + + hidden_list.append(h) + logits_list.append(logits_padded) + + # Stack back to batch + hidden_states = torch.cat(hidden_list, dim=0) # [B, seq_len, D*num_layers] + logits = torch.cat(logits_list, dim=0) # [B, seq_len, vocab] + + return hidden_states, logits.to(input_ids.device) 
def get_aux_and_target_hiddens( self, diff --git a/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py b/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py index e2ff3741..ffe55eae 100755 --- a/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py +++ b/angelslim/compressor/speculative/train/trainer/online_dflash_trainer.py @@ -206,6 +206,256 @@ def _load_file_content( self.lm_head.weight.data.copy_(tensor) +class _FP32StateAdamW(torch.optim.Optimizer): + """AdamW with fp32 master weights (AngelSlim DFlash optimizer). + + Maintains fp32 master copies of all parameters (in optimizer state). + On each step: + 1. Cast bf16 gradients to fp32. + 2. Clip fp32 grad norm. + 3. Adam update on fp32 master weights. + 4. Copy fp32 master -> bf16 model params. + + Key properties: + * Accumulation in fp32 (no precision loss from bf16 quantization). + * Grad clipping on fp32 gradients. + * Only the final copy-back introduces bf16 quantization. + + Compatible with FSDP + accelerate + HF Trainer (operates on the SAME + parameter objects required for FSDP state_dict). + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + max_grad_norm=1.0, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.max_grad_norm = max_grad_norm + super().__init__(params, defaults) + + # Eagerly initialize all master parameters at construction so all + # ranks start from synchronized bf16 params before any training step. 
+ with torch.no_grad(): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["master_param"] = p.data.detach().clone().to(torch.float32) + + @torch.no_grad() + def step(self, closure=None): + """Full fp32 master-weight update step (AngelSlim DFlash optimizer).""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Phase 1: Cast all bf16 grads to fp32 (kept temporarily in state). + all_fp32_grads = [] + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + + # Ensure states are fp32 (handles resume from checkpoint). + if state["exp_avg"].dtype != torch.float32: + state["exp_avg"] = state["exp_avg"].to(torch.float32) + if state["exp_avg_sq"].dtype != torch.float32: + state["exp_avg_sq"] = state["exp_avg_sq"].to(torch.float32) + if state["master_param"].dtype != torch.float32: + state["master_param"] = state["master_param"].to(torch.float32) + + fp32_grad = p.grad.detach().to(torch.float32) + state["_fp32_grad"] = fp32_grad + all_fp32_grads.append(fp32_grad) + + # Phase 2: Clip fp32 grad norm. + # Manual clipping because all_fp32_grads holds plain tensors (not Parameters). + # In FSDP SHARD_GRAD_OP + use_orig_params=True, p.grad is the full + # all-reduced gradient (same on all ranks), so per-rank clipping is correct. + if self.max_grad_norm > 0 and all_fp32_grads: + total_norm_sq = sum(g.norm().pow(2) for g in all_fp32_grads) + total_norm = total_norm_sq.sqrt() + clip_coef = self.max_grad_norm / (total_norm + 1e-6) + clip_coef_clamped = min(clip_coef.item(), 1.0) + if clip_coef_clamped < 1.0: + for g in all_fp32_grads: + g.mul_(clip_coef_clamped) + + # Phase 3: Adam update on fp32 master weights, then copy back to bf16. 
+ for group in self.param_groups: + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + weight_decay = group["weight_decay"] + + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + grad = state.pop("_fp32_grad") + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + master_param = state["master_param"] + + state["step"] += 1 + step_t = state["step"].item() + + # Decoupled weight decay on fp32 master (AdamW style). + if weight_decay != 0: + master_param.mul_(1.0 - lr * weight_decay) + + # Adam update in fp32. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + bias_correction1 = 1 - beta1**step_t + bias_correction2 = 1 - beta2**step_t + + step_size = lr / bias_correction1 + denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(eps) + + master_param.addcdiv_(exp_avg, denom, value=-step_size) + + # Copy fp32 master -> bf16 model param (only quantization point). + p.data.copy_(master_param.to(p.dtype)) + p.grad = None + + return loss + + +class _FP32MasterWeightOptimizer(torch.optim.Optimizer): + """Thin wrapper around any torch optimizer that maintains fp32 master weights. + + AngelSlim DFlash fp32 master-weight pattern: + 1. At construction, clone bf16 model params -> fp32 master copies. + 2. On every step(): + a. Copy bf16 grads -> fp32 master grads (cast up), clear bf16 grads. + b. Clip fp32 grad norm. + c. Run inner optimizer step on fp32 params. + d. Copy updated fp32 values -> bf16 model params (cast down). + + Used as ``self.optimizer`` inside HF Trainer for the DDP code path. + HF Trainer calls ``self.optimizer.step()`` / ``self.optimizer.zero_grad()`` + directly, so placing the sync logic inside the optimizer is the only + reliable way to ensure it actually runs. + + Inherits from torch.optim.Optimizer so isinstance checks in HF Trainer's + lr_scheduler creation (LambdaLR.__init__) pass correctly. 
+ """ + + def __init__( + self, + bf16_params: List[torch.Tensor], + inner_optimizer: torch.optim.Optimizer, + max_grad_norm: float = 1.0, + ): + self._bf16_params = bf16_params + + # Build fp32 master copies and replace the optimizer's param groups. + self._fp32_params: List[torch.Tensor] = [ + p.detach().clone().to(torch.float32).requires_grad_(True) for p in bf16_params + ] + assert len(inner_optimizer.param_groups) == 1, ( + "_FP32MasterWeightOptimizer expects a single param group; " + "extend if/when multiple groups (e.g. LoRA, lr split) are needed." + ) + inner_optimizer.param_groups[0]["params"] = self._fp32_params + # Re-initialise state dict for the new param objects. + from collections import defaultdict + + inner_optimizer.state = defaultdict(dict) + + self._inner = inner_optimizer + self.max_grad_norm = max_grad_norm + + # Call torch.optim.Optimizer.__init__ so that isinstance(self, Optimizer) + # returns True. The _initializing flag prevents add_param_group from + # delegating to self._inner during super().__init__ (which would + # create a duplicate param group in the inner optimizer). + self._initializing = True + super().__init__(self._fp32_params, inner_optimizer.defaults) + self._initializing = False + + # Redirect param_groups and state to inner optimizer's versions so + # lr_scheduler / lr logging always see the correct param groups. + self.param_groups = self._inner.param_groups + self.state = self._inner.state + + # ------------------------------------------------------------------ # + # Core step / zero_grad — called directly by HF Trainer # + # ------------------------------------------------------------------ # + + def step(self, closure=None): + """Full fp32 master-weight update step.""" + with torch.no_grad(): + # (a) Copy bf16 grads -> fp32 master grads. 
+ for bf16_p, fp32_p in zip(self._bf16_params, self._fp32_params): + if bf16_p.grad is not None: + fp32_p.grad = bf16_p.grad.detach().to(torch.float32) + bf16_p.grad = None + else: + fp32_p.grad = None + + # (b) Clip fp32 grad norm. + if self.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(self._fp32_params, self.max_grad_norm) + + # (c) Optimizer step on fp32 params. + loss = self._inner.step(closure) + + # (d) Copy fp32 -> bf16 model params. + with torch.no_grad(): + for bf16_p, fp32_p in zip(self._bf16_params, self._fp32_params): + bf16_p.data.copy_(fp32_p.data.to(bf16_p.dtype)) + + return loss + + def zero_grad(self, set_to_none: bool = True): + """Zero gradients on both bf16 model params and fp32 master params.""" + for bf16_p in self._bf16_params: + if set_to_none: + bf16_p.grad = None + elif bf16_p.grad is not None: + bf16_p.grad.zero_() + for fp32_p in self._fp32_params: + if set_to_none: + fp32_p.grad = None + elif fp32_p.grad is not None: + fp32_p.grad.zero_() + + # ------------------------------------------------------------------ # + # Delegate everything else to the inner optimizer # + # ------------------------------------------------------------------ # + + def state_dict(self): + return self._inner.state_dict() + + def load_state_dict(self, state_dict): + return self._inner.load_state_dict(state_dict) + + def add_param_group(self, param_group): + # During super().__init__ the inner optimizer is not yet assigned, + # so fall back to the default Optimizer behaviour. + if getattr(self, "_initializing", True): + return super().add_param_group(param_group) + return self._inner.add_param_group(param_group) + + def __repr__(self): + return f"_FP32MasterWeightOptimizer({self._inner})" + + @Eagle3TrainerFactory.register("online", "DFlash") class OnlineDFlashTrainer(Eagle3Trainer): """Online DFlash Trainer for speculative decoding training. 
@@ -244,12 +494,30 @@ def __init__( self.block_size = getattr(draft_model_config, "block_size", 16) self.num_anchors = getattr(draft_model_config, "num_anchors", 512) self.loss_decay_gamma = getattr(draft_model_config, "loss_decay_gamma", None) + # Gamma warmup: gradually increase loss_decay_gamma per epoch + # (AngelSlim DFlash gamma-warmup schedule). + self._gamma_init = self.loss_decay_gamma + self.gamma_warmup = getattr(draft_model_config, "gamma_warmup", False) + self._gamma_step = getattr(draft_model_config, "gamma_warmup_step", 0.5) self.attention_backend = getattr(draft_model_config, "attention_backend", "flex_attention") self.mask_token_id = dflash_config.get( "mask_token_id", getattr(draft_model_config, "mask_token_id", None), ) + # Sync _attn_implementation on the draft model so its attention layers + # dispatch to the correct backend (eager vs flex_attention vs sdpa). + if self.attention_backend == "eager": + draft_model.config._attn_implementation = "eager" + elif self.attention_backend == "flex_attention": + draft_model.config._attn_implementation = "flex_attention" + else: + draft_model.config._attn_implementation = self.attention_backend + + # fp32 master weights optimizer — set by create_optimizer() (DDP path). + # FSDP path uses _FP32StateAdamW directly as self.optimizer instead. + self._fp32_optimizer: Optional["_FP32MasterWeightOptimizer"] = None + # Load target model's lm_head and embed_tokens # In offline mode target_model may be None; fall back to config path. target_model_path = None @@ -279,6 +547,165 @@ def __init__( "or target_model.model_path for DFlash training." ) + def create_optimizer(self, model=None): + """Create optimizer for DFlash training. + + Three branches: + * DeepSpeed: defer to HF Trainer's default optimizer creation. + * FSDP: AdamW with fp32 optimizer states (``_FP32StateAdamW``), + using the AngelSlim DFlash fp32-master pattern. 
Critical because + bf16 momentum and variance only have 7-bit mantissa, which causes + training quality degradation after a few thousand steps. + * DDP / single GPU: ``_FP32MasterWeightOptimizer`` wrapping AdamW for + fp32 master weight updates. + """ + if self.is_deepspeed_enabled: + return super().create_optimizer(model) + + if self.is_fsdp_enabled: + args = self.args + param_groups = [{"params": [p for p in self.model.parameters() if p.requires_grad]}] + optimizer = _FP32StateAdamW( + param_groups, + lr=args.learning_rate, + betas=( + getattr(args, "adam_beta1", 0.9), + getattr(args, "adam_beta2", 0.999), + ), + eps=getattr(args, "adam_epsilon", 1e-8), + weight_decay=args.weight_decay, + max_grad_norm=args.max_grad_norm, + ) + self.optimizer = optimizer + return self.optimizer + + bf16_params: List[torch.Tensor] = [p for p in self.model.parameters() if p.requires_grad] + if not bf16_params: + return super().create_optimizer(model) + + from torch.optim import AdamW + + args = self.args + inner_optimizer = AdamW( + # Placeholder — _FP32MasterWeightOptimizer will replace param_groups + # with fp32 copies immediately after construction. + bf16_params, + lr=args.learning_rate, + betas=( + getattr(args, "adam_beta1", 0.9), + getattr(args, "adam_beta2", 0.999), + ), + eps=getattr(args, "adam_epsilon", 1e-8), + weight_decay=args.weight_decay, + ) + + fp32_opt = _FP32MasterWeightOptimizer( + bf16_params=bf16_params, + inner_optimizer=inner_optimizer, + max_grad_norm=args.max_grad_norm, + ) + self._fp32_optimizer = fp32_opt + self.optimizer = fp32_opt + return self.optimizer + + def create_scheduler(self, num_training_steps: int, optimizer=None): + """Create LR scheduler: AngelSlim DFlash CosineAnnealingWarmupLR. + + AngelSlim warmup formula: lr = base_lr * (step + 1) / warmup_steps + HF Trainer warmup formula: lr = base_lr * step / warmup_steps + + The +1 offset means step 0 yields lr = base_lr / warmup_steps instead + of 0. 
After warmup, both use identical cosine annealing. + """ + import math + + from torch.optim.lr_scheduler import LambdaLR + + if optimizer is None: + optimizer = self.optimizer + + warmup_steps = self.args.get_warmup_steps(num_training_steps) + + def angelslim_cosine_schedule(current_step: int) -> float: + """LR multiplier for AngelSlim DFlash CosineAnnealingWarmupLR.""" + if current_step < warmup_steps: + # AngelSlim: (last_epoch + 1) / warmup_epochs * base_lr + # After N step() calls last_epoch = N, so first lr = 1/warmup_steps. + return float(current_step + 1) / float(max(1, warmup_steps)) + # Cosine decay phase — identical to HF. + progress = float(current_step - warmup_steps) / float( + max(1, num_training_steps - warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) + + self.lr_scheduler = LambdaLR(optimizer, angelslim_cosine_schedule) + return self.lr_scheduler + + def _clip_grad_norm(self, *args, **kwargs): + """Skip HF Trainer's built-in grad clipping when running our fp32 optimizers. + + Both ``_FP32MasterWeightOptimizer`` (DDP) and ``_FP32StateAdamW`` (FSDP) + clip gradients internally on fp32 grads (AngelSlim DFlash fp32-master + clipping). HF Trainer's Accelerator-based clip_grad_norm_ would + otherwise run on bf16 grads (incorrect precision for clipping), causing + DOUBLE CLIPPING that significantly slows down training. + """ + if self._fp32_optimizer is not None: + # DDP path — clipped inside _FP32MasterWeightOptimizer.step() + return torch.tensor(0.0) + + # FSDP path: _FP32StateAdamW clips internally on fp32 grads. + # self.optimizer may be wrapped by AcceleratedOptimizer, so unwrap once. 
+ optimizer = self.optimizer + if hasattr(optimizer, "optimizer"): + optimizer = optimizer.optimizer + if isinstance(optimizer, _FP32StateAdamW): + return torch.tensor(0.0) + + return super()._clip_grad_norm(*args, **kwargs) + + def save_optimizer_and_scheduler(self, output_dir, **kwargs): + """Override to handle fp32 master weight optimizer with FSDP. + + FSDP's built-in optim_state_dict() cannot handle our custom fp32 + master-weight optimizers because their fp32 params are not registered + in the FSDP module's parameter graph. Save optimizer/scheduler state + directly instead. + """ + self._save_optimizer_and_scheduler(output_dir) + + def _save_optimizer_and_scheduler(self, output_dir): + """Bypass FSDP's optim_state_dict for our custom fp32 optimizers.""" + optimizer = self.optimizer + if hasattr(optimizer, "optimizer"): + # Unwrap AcceleratedOptimizer + optimizer = optimizer.optimizer + + if isinstance(optimizer, (_FP32StateAdamW, _FP32MasterWeightOptimizer)): + if self.args.should_save: + torch.save( + optimizer.state_dict(), + os.path.join(output_dir, "optimizer.pt"), + ) + if self.lr_scheduler is not None: + torch.save( + self.lr_scheduler.state_dict(), + os.path.join(output_dir, "scheduler.pt"), + ) + else: + super()._save_optimizer_and_scheduler(output_dir) + + def _update_gamma_warmup(self): + """Update loss_decay_gamma: gamma = gamma_init + gamma_step * int(epoch). + + AngelSlim DFlash gamma-warmup schedule: + current_gamma = loss_decay_gamma + step * float(epoch) + """ + if not self.gamma_warmup or self._gamma_init is None: + return + current_epoch = int(getattr(self.state, "epoch", None) or 0)  # epoch may be None + self.loss_decay_gamma = self._gamma_init + self._gamma_step * float(current_epoch) + def prepare_data_for_draft_model(self, inputs): """Prepare data for DFlash training. @@ -501,6 +928,9 @@ def compute_loss( Unlike Eagle3's iterative multi-step loss, DFlash computes a single block-parallel cross-entropy loss over all sampled anchor positions.
""" + # Update gamma if warmup is enabled (no-op when gamma_warmup=False) + self._update_gamma_warmup() + data = self.prepare_data_for_draft_model(inputs) loss, accuracy = self._compute_dflash_loss_and_accuracy( diff --git a/configs/fsdp_config.json b/configs/fsdp_config.json new file mode 100644 index 00000000..9721e1ae --- /dev/null +++ b/configs/fsdp_config.json @@ -0,0 +1,6 @@ +{ + "fsdp_sharding_strategy": "SHARD_GRAD_OP", + "fsdp_auto_wrap_policy": "NO_WRAP", + "fsdp_backward_prefetch": "backward_pre", + "fsdp_use_orig_params": true +} diff --git a/configs/qwen3_dflare.json b/configs/qwen3_dflare.json new file mode 100644 index 00000000..5378f418 --- /dev/null +++ b/configs/qwen3_dflare.json @@ -0,0 +1,60 @@ +{ + "architectures": [ + "QwenDFlareDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 16, + "bos_token_id": 151643, + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 33 + ] + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 7, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 7, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + + "num_anchors": 512, + "loss_decay_gamma": 7.0, + "attention_backend": "flex_attention" +} diff --git a/docs/source/assets/dflare/intro.png b/docs/source/assets/dflare/intro.png new 
file mode 100644 index 00000000..38d38f11 Binary files /dev/null and b/docs/source/assets/dflare/intro.png differ diff --git a/docs/source/assets/dflare/speedup.png b/docs/source/assets/dflare/speedup.png new file mode 100644 index 00000000..860cb7f1 Binary files /dev/null and b/docs/source/assets/dflare/speedup.png differ diff --git a/docs/source/features/speculative_decoding/dflare.md b/docs/source/features/speculative_decoding/dflare.md new file mode 100644 index 00000000..596c2e62 --- /dev/null +++ b/docs/source/features/speculative_decoding/dflare.md @@ -0,0 +1,116 @@ +# DFlare + +**DFlare** is a block-diffusion speculative decoding framework that accelerates large language model inference by predicting an entire block of tokens in one shot for the target model to verify in parallel. It removes the narrow conditioning bottleneck of the prior state-of-the-art DFlash through a lightweight **layer-wise fusion** mechanism: each draft layer attends to its own learnable combination of a broad set of target layers at negligible overhead, simultaneously injecting richer target knowledge and giving every draft layer a distinct input. Combined with training-data scaling, this enhanced per-layer expressiveness allows the draft model to scale to deeper architectures with consistent gains, achieving up to **5.52× end-to-end speedup** without compromising output quality. + +This repository contains the official implementation and resources for the paper: **DFLARE: Scaling Up Draft Capacity for Block Diffusion Speculative Decoding**. + + +:::{image} /assets/dflare/intro.png +:alt: An overview of the DFlare framework. +::: + +--- + +## 🚀 Abstract + +Block diffusion speculative decoding accelerates LLM inference by predicting all tokens within a block simultaneously for the target model to verify in parallel. Predicting an entire block at once requires a sufficiently capable draft model and effective utilization of the target model's internal knowledge. 
However, the state-of-the-art method DFlash constrains all draft layers to share a single fused representation derived from only a few target layers, limiting per-layer expressiveness and hindering further scaling of draft capacity. We present **DFLARE**, which flares out the narrow conditioning bottleneck of DFlash through a lightweight layer-wise fusion mechanism: each draft layer attends to its own learnable combination of a broad set of target layers at negligible overhead, simultaneously injecting richer target knowledge and providing every draft layer with a distinct input. This enhanced per-layer expressiveness enables scaling the draft model to deeper architectures with consistent gains. We further scale training data from 800K to 2.4M samples to fully exploit the enlarged capacity. On six benchmarks spanning mathematical reasoning, code generation, and conversation, DFLARE attains average wall-clock speedups of **5.52× on Qwen3-4B**, **5.46× on Qwen3-8B**, and **3.91× on GPT-OSS-20B**, improving over DFlash by roughly 11%, 8%, and 5% respectively. + + +## ✨ Key Highlights + +- **Layer-wise Fusion for Richer Conditioning**: Replaces DFlash's single fused representation with a lightweight mechanism in which each draft layer attends to its own learnable combination of a broad set of target layers, removing the conditioning bottleneck at negligible overhead. +- **Scalable Draft Capacity**: The enriched per-layer expressiveness lets the draft model scale to deeper architectures with consistent gains, complemented by scaling training data from 800K to 2.4M samples to fully exploit the enlarged capacity. +- **Substantial End-to-End Speedups**: Across six benchmarks covering mathematical reasoning, code generation, and conversation, DFlare delivers average wall-clock speedups of 5.52× on Qwen3-4B, 5.46× on Qwen3-8B, and 3.91× on GPT-OSS-20B — roughly 11%, 8%, and 5% over DFlash respectively. 
+ + +## ⚡ Quick Start + +### Training + +DFlare reuses the DFlash training pipeline and selects the layer-wise fusion architecture via `--draft_arch dflare`. Two entry points are provided: + +**Online training** (recommended) — runs the target model on the fly to produce hidden states each step. No data pre-generation step needed. + +```shell +export TARGET_MODEL_PATH=/path/to/Qwen3-4B +export TRAIN_DATA_PATH=/path/to/train.jsonl +export OUTPUT_DIR=/path/to/output + +bash scripts/speculative/run_dflare_online.sh 8 flex_attention +``` + +**Offline training** — trains from pre-computed hidden-state `.ckpt` files. First generate the cache with `scripts/speculative/generate_dflash_data.sh` using a DFlare-compatible draft config, then: + +```shell +export TARGET_MODEL_PATH=/path/to/Qwen3-4B +export TRAIN_HIDDEN_PATH=/path/to/hidden_cache +export OUTPUT_DIR=/path/to/output + +bash scripts/speculative/run_dflare_offline.sh 8 flex_attention +``` + +Both entries use the same defaults: `block_size=16`, `num_anchors=512`, `lr=6e-4`, cosine schedule with 4% warmup, `max_length=3072`, FSDP `shard_grad_op` with FP32 master-weights optimizer, and `flash_attention_2` for the target model. The default draft model config is `configs/qwen3_dflare.json`. + +### Inference and Evaluation + +To benchmark a trained DFlare draft model on tasks such as GSM8K, MT-Bench, MATH-500, and HumanEval, use `tools/dflash_benchmark.py`. The script supports both DFlash and DFlare draft architectures via the `--draft-arch` flag — for DFlare set `--draft-arch dflare`. It loads the matching `QwenDFlareDraftModel` class, runs block-parallel speculative decoding (block-size proposal from the draft + parallel target verification + longest-prefix accept), and reports decoding speedup, average acceptance length, and the per-block acceptance-length histogram. 
+ +**Single-GPU evaluation:** + +```shell +python tools/dflash_benchmark.py \ + --model-name-or-path /path/to/Qwen3-4B \ + --draft-name-or-path /path/to/dflare_checkpoint \ + --draft-arch dflare \ + --dataset gsm8k \ + --max-samples 128 \ + --max-new-tokens 2048 \ + --temperature 0.0 \ + --block-size 16 +``` + +**Multi-GPU evaluation** (workload is sharded across ranks; results are gathered to rank 0): + +```shell +torchrun --nproc_per_node=8 --master_port=29600 \ + tools/dflash_benchmark.py \ + --model-name-or-path /path/to/Qwen3-4B \ + --draft-name-or-path /path/to/dflare_checkpoint \ + --draft-arch dflare \ + --dataset gsm8k \ + --max-samples 128 \ + --max-new-tokens 2048 \ + --temperature 0.0 \ + --block-size 16 +``` + +Notes: + +- `--block-size` is optional; if omitted, the script reads `block_size` directly from the loaded draft checkpoint's config. +- The script runs each prompt twice — once with `block_size=1` (vanilla AR decoding) and once with the speculative `block_size` — so the reported `Decoding speedup` is a self-contained ratio. No external baseline run is required. +- Both target and draft are loaded in `bfloat16` with `flash_attention_2` when `flash-attn` is installed (otherwise it falls back to PyTorch SDPA, which reduces wall-clock speedup but does not affect acceptance length). +- Supported datasets out of the box: `gsm8k`, `math500`, `aime24`, `aime25`, `alpaca`, `mt-bench`, `humaneval`, `mbpp`, `lbpp`, `swe-bench`, `livecodebench`. +- To compare DFlash and DFlare on the same checkpoint format, switch to `--draft-arch dflash` and point `--draft-name-or-path` to a DFlash checkpoint — the rest of the command stays identical. + + +## 📈 Results + +We evaluate DFlare on six benchmarks spanning mathematical reasoning (GSM8K, MATH-500, AIME), code generation (HumanEval, MBPP, LiveCodeBench), and open-domain conversation (MT-Bench, Alpaca), against DFlash and EAGLE-3 baselines on Qwen3-4B, Qwen3-8B, and GPT-OSS-20B target models.
+ +:::{image} /assets/dflare/speedup.png +:alt: DFlare end-to-end speedup vs DFlash and EAGLE-3 across six benchmarks. +::: + + +## 📜 Citation + +If you find our work useful in your research, please consider citing our paper: + +```bibtex +@article{DFlare2026, + title={DFlare: Scaling Up Draft Capacity for Block Diffusion Speculative Decoding}, + author={Jiebin Zhang and Zhenghan Yu and Song Liu and Eugene J. Yu and Zheng Li and Dawei Zhu and Jiangshan Duo and Weimin Xiong and Yifan Song and Guanghua Yu and Jianchen Zhu and Sujian Li}, + journal={arXiv preprint arXiv}, + year={2026} +} +``` \ No newline at end of file diff --git a/docs/source/features/speculative_decoding/index.md b/docs/source/features/speculative_decoding/index.md index b8b8b641..263d4631 100644 --- a/docs/source/features/speculative_decoding/index.md +++ b/docs/source/features/speculative_decoding/index.md @@ -9,4 +9,5 @@ eagle/index spec_exit dcut +dflare ::: diff --git a/scripts/speculative/run_dflare_offline.sh b/scripts/speculative/run_dflare_offline.sh new file mode 100755 index 00000000..f4a44d50 --- /dev/null +++ b/scripts/speculative/run_dflare_offline.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# ========================================================================== +# AngelSlim DFlare Offline Training — Fully Aligned Configuration +# ========================================================================== +# +# Trains a DFlare draft model from pre-computed hidden-state .ckpt files. +# DFlare is the enhanced DFlash variant with separate context/noise k/v +# projections and learnable per-layer fusion weights. Training-side logic +# is identical to DFlash, so this script reuses tools/train_dflash_offline.py +# and selects DFlare via --draft_arch. +# +# Prerequisite: run scripts/speculative/generate_dflash_data.sh with a +# DFlare-compatible draft config first to produce the .ckpt files at +# $TRAIN_HIDDEN_PATH. 
Specifically the .ckpt's hidden_states must contain +# the SAME number of target layers that the DFlare config expects (i.e. +# len(dflash_config.target_layer_ids) in qwen3_dflare.json). +# +# Enables all AngelSlim DFlash alignment features (same as the online entry, +# minus the on-the-fly target-model forward): +# +# - block_size: 16, num_anchors: 512. +# - batch_size: 2, lr: 6e-4, cosine schedule, warmup_ratio: 0.04. +# - max_length: 3072. +# - dataloader_drop_last=True (avoids FSDP shape mismatches on the +# trailing batch). +# - FP32 master weights optimizer (fp32 accumulation + fp32 grad clip; +# only the final copy-back introduces bf16 quantization). +# - FSDP shard_grad_op + auto_wrap (with configs/fsdp_config.json: +# NO_WRAP, use_orig_params=True). +# - loss_decay_gamma: 7 (fixed by default; pass --gamma_warmup to enable +# per-epoch increment via --gamma_warmup_step). +# +# Usage: +# bash scripts/speculative/run_dflare_offline.sh [NUM_GPUS] [ATTENTION_BACKEND] +# +# ========================================================================== + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +# Use local source code instead of installed site-packages +export PYTHONPATH=$ROOT_DIR:${PYTHONPATH:-} + +NUM_GPUS=${1:-8} +ATTENTION_BACKEND=${2:-flex_attention} + +# ========================================================================== +# Paths — set these before running (left empty by default for portability) +# ========================================================================== +TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-""} +DRAFT_CONFIG_PATH=${DRAFT_CONFIG_PATH:-"${ROOT_DIR}/configs/qwen3_dflare.json"} +TRAIN_HIDDEN_PATH=${TRAIN_HIDDEN_PATH:-""} +EVAL_HIDDEN_PATH=${EVAL_HIDDEN_PATH:-""} +OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/qwen3-4b-dflare-offline"} + +if [ -z "$TARGET_MODEL_PATH" ]; then + echo "[ERROR] TARGET_MODEL_PATH is empty. 
Export it to your local Qwen3 (or other) HF model dir." + exit 1 +fi +if [ -z "$TRAIN_HIDDEN_PATH" ]; then + echo "[ERROR] TRAIN_HIDDEN_PATH is empty. Set it to the directory holding " + echo " pre-computed .ckpt files (output of generate_dflash_data.sh" + echo " run with a DFlare-compatible draft config so target_layer_ids match)." + exit 1 +fi + +# ========================================================================== +# torch.compile / inductor kernel cache +# ========================================================================== +export TORCHINDUCTOR_CACHE_DIR=${TORCHINDUCTOR_CACHE_DIR:-${ROOT_DIR}/cache/compiled_kernels} + +# ========================================================================== +# WandB configuration +# ========================================================================== +export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflare"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"angelslim-qwen3-4b-dflare-offline-fp32master"} + +# ========================================================================== +# Multi-node configuration (optional) +# ========================================================================== +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-12347} + +if [ "$NNODES" -gt 1 ]; then + DISTRIBUTED_ARGS="--nproc_per_node $NUM_GPUS --nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT" + echo "[INFO] Multi-node training: nnodes=$NNODES, node_rank=$NODE_RANK, master=$MASTER_ADDR:$MASTER_PORT" +else + DISTRIBUTED_ARGS="--standalone --nproc_per_node $NUM_GPUS" + echo "[INFO] Single-node training: $NUM_GPUS GPUs" +fi + +# ========================================================================== +# NCCL multi-node communication (for H20 + RoCE 400Gbps); harmless on single node +# ========================================================================== +if [ "$NNODES" -gt 1 ]; then + export NCCL_IB_DISABLE=0 + 
export NCCL_SOCKET_IFNAME=bond1 + export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7,mlx5_bond_8 + export NCCL_IB_GID_INDEX=3 + export NCCL_IB_TIMEOUT=23 + export NCCL_IB_RETRY_CNT=7 + export NCCL_NET_GDR_LEVEL=2 + export NCCL_IB_QPS_PER_CONNECTION=4 + export NCCL_CROSS_NIC=1 + export NCCL_ALGO=Ring + export NCCL_PROTO=Simple + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} + export CUDA_DEVICE_MAX_CONNECTIONS=1 + export NCCL_TIMEOUT=1800 +fi + +# Optional eval-data argument (only added if path provided) +EVAL_FLAGS="" +if [ -n "$EVAL_HIDDEN_PATH" ]; then + EVAL_FLAGS="--eval_hidden_path $EVAL_HIDDEN_PATH" +fi + +echo "[INFO] Draft config: $DRAFT_CONFIG_PATH" +echo "[INFO] Target model: $TARGET_MODEL_PATH" +echo "[INFO] Train hidden path: $TRAIN_HIDDEN_PATH" +echo "[INFO] Eval hidden path: ${EVAL_HIDDEN_PATH:-}" +echo "[INFO] Output dir: $OUTPUT_DIR" +echo "[INFO] Attention backend (draft): $ATTENTION_BACKEND" +echo "[INFO] Draft architecture: dflare (--draft_arch dflare)" +echo "[INFO] WandB project: $WANDB_PROJECT, run: $WANDB_RUN_NAME" +echo "" + +# ========================================================================== +# Launch training +# ========================================================================== +torchrun $DISTRIBUTED_ARGS \ + $ROOT_DIR/tools/train_dflash_offline.py \ + --target_model_name_or_path $TARGET_MODEL_PATH \ + --draft_model_config_path $DRAFT_CONFIG_PATH \ + --draft_arch dflare \ + --train_hidden_path $TRAIN_HIDDEN_PATH \ + $EVAL_FLAGS \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 12 \ + --per_device_train_batch_size 2 \ + --learning_rate 6e-4 \ + --warmup_ratio 0.04 \ + --max_grad_norm 1.0 \ + --model_max_length 3072 \ + --chat_template_type qwen3 \ + --attention_backend $ATTENTION_BACKEND \ + --block_size 16 \ + --num_anchors 512 \ + --loss_decay_gamma 7 \ + --logging_steps 50 \ + --save_strategy steps \ + --save_steps 5000 \ + --bf16 \ + --lr_scheduler_type cosine \ + 
--dataloader_drop_last \ + --fsdp "shard_grad_op auto_wrap" \ + --fsdp_config ${ROOT_DIR}/configs/fsdp_config.json \ + --report_to wandb \ + --wandb_project $WANDB_PROJECT \ + --wandb_run_name $WANDB_RUN_NAME diff --git a/scripts/speculative/run_dflare_online.sh b/scripts/speculative/run_dflare_online.sh new file mode 100755 index 00000000..c0ee041a --- /dev/null +++ b/scripts/speculative/run_dflare_online.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +# ========================================================================== +# AngelSlim DFlare Online Training — Fully Aligned Configuration +# ========================================================================== +# +# Recommended training entry for DFlare. DFlare is the enhanced DFlash +# variant with separate context/noise k/v projections and learnable +# per-layer fusion weights. Training-side logic is identical to DFlash, so +# all alignment features below apply unchanged: +# +# - loss_decay_gamma: 7 (fixed by default; pass --gamma_warmup to enable +# per-epoch increment via --gamma_warmup_step). +# - block_size: 16, num_anchors: 512. +# - batch_size: 2, lr: 6e-4, cosine schedule, warmup_ratio: 0.04. +# - max_length: 3072, num_epochs: 6. +# - num_proc: 64 for data preprocessing. +# - Target model uses flash_attention_2 (matches the inference kernel and +# avoids train/test attention-backend mismatch). +# - dataloader_drop_last=True (avoids FSDP shape mismatches on the +# trailing batch). +# - FP32 master weights optimizer (fp32 accumulation + fp32 grad clip; +# only the final copy-back introduces bf16 quantization). +# - FSDP shard_grad_op + auto_wrap (with configs/fsdp_config.json: +# NO_WRAP, use_orig_params=True). 
+# +# Usage: +# bash scripts/speculative/run_dflare_online.sh [NUM_GPUS] [ATTENTION_BACKEND] +# +# ========================================================================== + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +# Use local source code instead of installed site-packages +export PYTHONPATH=$ROOT_DIR:${PYTHONPATH:-} + +NUM_GPUS=${1:-8} +ATTENTION_BACKEND=${2:-flex_attention} + +# ========================================================================== +# Paths — set these before running (left empty by default for portability) +# ========================================================================== +TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-""} +DRAFT_CONFIG_PATH=${DRAFT_CONFIG_PATH:-"${ROOT_DIR}/configs/qwen3_dflare.json"} +TRAIN_DATA_PATH=${TRAIN_DATA_PATH:-""} +OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/qwen3-4b-dflare-aligned"} + +if [ -z "$TARGET_MODEL_PATH" ]; then + echo "[ERROR] TARGET_MODEL_PATH is empty. Export it to your local Qwen3 (or other) HF model dir." + exit 1 +fi +if [ -z "$TRAIN_DATA_PATH" ]; then + echo "[ERROR] TRAIN_DATA_PATH is empty. Export it to a JSON/JSONL conversation dataset file." 
+ exit 1 +fi + +# ========================================================================== +# torch.compile / inductor kernel cache +# ========================================================================== +export TORCHINDUCTOR_CACHE_DIR=${TORCHINDUCTOR_CACHE_DIR:-${ROOT_DIR}/cache/compiled_kernels} + +# ========================================================================== +# Data preprocessing parallelism +# ========================================================================== +DATA_NUM_PROC=${DATA_NUM_PROC:-64} + +# ========================================================================== +# WandB configuration +# ========================================================================== +export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflare"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"angelslim-qwen3-4b-dflare-fp32master-aligned"} + +# ========================================================================== +# Multi-node configuration (optional) +# ========================================================================== +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-12347} + +if [ "$NNODES" -gt 1 ]; then + DISTRIBUTED_ARGS="--nproc_per_node $NUM_GPUS --nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT" + echo "[INFO] Multi-node training: nnodes=$NNODES, node_rank=$NODE_RANK, master=$MASTER_ADDR:$MASTER_PORT" +else + DISTRIBUTED_ARGS="--standalone --nproc_per_node $NUM_GPUS" + echo "[INFO] Single-node training: $NUM_GPUS GPUs" +fi + +# ========================================================================== +# NCCL multi-node communication (for H20 + RoCE 400Gbps); harmless on single node +# ========================================================================== +if [ "$NNODES" -gt 1 ]; then + export NCCL_IB_DISABLE=0 + export NCCL_SOCKET_IFNAME=bond1 + export 
NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7,mlx5_bond_8 + export NCCL_IB_GID_INDEX=3 + export NCCL_IB_TIMEOUT=23 + export NCCL_IB_RETRY_CNT=7 + export NCCL_NET_GDR_LEVEL=2 + export NCCL_IB_QPS_PER_CONNECTION=4 + export NCCL_CROSS_NIC=1 + export NCCL_ALGO=Ring + export NCCL_PROTO=Simple + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} + export CUDA_DEVICE_MAX_CONNECTIONS=1 + export NCCL_TIMEOUT=1800 +fi + +echo "[INFO] Draft config: $DRAFT_CONFIG_PATH" +echo "[INFO] Target model: $TARGET_MODEL_PATH" +echo "[INFO] Train data: $TRAIN_DATA_PATH" +echo "[INFO] Output dir: $OUTPUT_DIR" +echo "[INFO] Attention backend (draft): $ATTENTION_BACKEND" +echo "[INFO] Target model attn: flash_attention_2 (set in target_model_wrapper.py)" +echo "[INFO] Draft architecture: dflare (--draft_arch dflare)" +echo "[INFO] WandB project: $WANDB_PROJECT, run: $WANDB_RUN_NAME" +echo "" + +# ========================================================================== +# Launch training +# ========================================================================== +torchrun $DISTRIBUTED_ARGS \ + $ROOT_DIR/tools/train_dflash_online.py \ + --target_model_name_or_path $TARGET_MODEL_PATH \ + --draft_model_config_path $DRAFT_CONFIG_PATH \ + --draft_arch dflare \ + --train_data_path $TRAIN_DATA_PATH \ + --output_dir $OUTPUT_DIR \ + --modal_type DFlash \ + --training_mode online \ + --num_train_epochs 6 \ + --per_device_train_batch_size 2 \ + --learning_rate 6e-4 \ + --warmup_ratio 0.04 \ + --max_grad_norm 1.0 \ + --model_max_length 3072 \ + --chat_template_type qwen3 \ + --attention_backend $ATTENTION_BACKEND \ + --block_size 16 \ + --num_anchors 512 \ + --loss_decay_gamma 7 \ + --num_proc $DATA_NUM_PROC \ + --logging_steps 50 \ + --save_strategy steps \ + --save_steps 5000 \ + --bf16 \ + --lr_scheduler_type cosine \ + --dataloader_drop_last \ + --fsdp "shard_grad_op auto_wrap" \ + --fsdp_config ${ROOT_DIR}/configs/fsdp_config.json \ + --report_to wandb \ + 
--wandb_project $WANDB_PROJECT \ + --wandb_run_name $WANDB_RUN_NAME diff --git a/scripts/speculative/run_dflash_offline.sh b/scripts/speculative/run_dflash_offline.sh index 05524719..8cee87b1 100644 --- a/scripts/speculative/run_dflash_offline.sh +++ b/scripts/speculative/run_dflash_offline.sh @@ -1,38 +1,134 @@ #!/bin/bash -# ============================================================================= -# Step 2: Train DFlash draft model in OFFLINE mode. + +# ========================================================================== +# AngelSlim DFlash Offline Training — Fully Aligned Configuration +# ========================================================================== +# +# Trains a DFlash draft model from pre-computed hidden-state .ckpt files. +# Prerequisite: run scripts/speculative/generate_dflash_data.sh first to +# produce the .ckpt files at $TRAIN_HIDDEN_PATH. +# +# Enables all AngelSlim DFlash alignment features (same as the online entry, +# minus the on-the-fly target-model forward): # -# Prerequisites: -# Run generate_qwen3_dflash_data.sh first to produce the .ckpt files. +# - block_size: 16, num_anchors: 512. +# - batch_size: 2, lr: 6e-4, cosine schedule, warmup_ratio: 0.04. +# - max_length: 3072. +# - dataloader_drop_last=True (avoids FSDP shape mismatches on the +# trailing batch). +# - FP32 master weights optimizer (fp32 accumulation + fp32 grad clip; +# only the final copy-back introduces bf16 quantization). +# - FSDP shard_grad_op + auto_wrap (with configs/fsdp_config.json: +# NO_WRAP, use_orig_params=True). +# - loss_decay_gamma: 7 (fixed by default; pass --gamma_warmup to enable +# per-epoch increment via --gamma_warmup_step). 
# # Usage: -# bash scripts/speculative/run_qwen3_dflash_offline.sh [NUM_GPUS] -# ============================================================================= +# bash scripts/speculative/run_dflash_offline.sh [NUM_GPUS] [ATTENTION_BACKEND] +# +# ========================================================================== + +set -euo pipefail SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) # Use local source code instead of installed site-packages -export PYTHONPATH=$ROOT_DIR:$PYTHONPATH +export PYTHONPATH=$ROOT_DIR:${PYTHONPATH:-} NUM_GPUS=${1:-8} +ATTENTION_BACKEND=${2:-flex_attention} + +# ========================================================================== +# Paths — set these before running (left empty by default for portability) +# ========================================================================== +TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-""} +DRAFT_CONFIG_PATH=${DRAFT_CONFIG_PATH:-"${ROOT_DIR}/configs/qwen3_dflash.json"} +TRAIN_HIDDEN_PATH=${TRAIN_HIDDEN_PATH:-""} +EVAL_HIDDEN_PATH=${EVAL_HIDDEN_PATH:-""} +OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/qwen3-4b-dflash-offline"} -# ---- Paths -- modify these to match your environment ---- -TARGET_MODEL_PATH="" -TRAIN_HIDDEN_PATH="" -OUTPUT_DIR="${ROOT_DIR}/outputs/" +if [ -z "$TARGET_MODEL_PATH" ]; then + echo "[ERROR] TARGET_MODEL_PATH is empty. Export it to your local Qwen3 (or other) HF model dir." + exit 1 +fi +if [ -z "$TRAIN_HIDDEN_PATH" ]; then + echo "[ERROR] TRAIN_HIDDEN_PATH is empty. Set it to the directory holding " + echo " pre-computed .ckpt files (output of generate_dflash_data.sh)." 
+ exit 1 +fi +# ========================================================================== +# torch.compile / inductor kernel cache +# ========================================================================== +export TORCHINDUCTOR_CACHE_DIR=${TORCHINDUCTOR_CACHE_DIR:-${ROOT_DIR}/cache/compiled_kernels} + +# ========================================================================== # WandB configuration +# ========================================================================== export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflash"} -WANDB_RUN_NAME=${WANDB_RUN_NAME:-"qwen3-4b-dflash-offline"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"angelslim-qwen3-4b-dflash-offline-fp32master"} + +# ========================================================================== +# Multi-node configuration (optional) +# ========================================================================== +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-12347} + +if [ "$NNODES" -gt 1 ]; then + DISTRIBUTED_ARGS="--nproc_per_node $NUM_GPUS --nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT" + echo "[INFO] Multi-node training: nnodes=$NNODES, node_rank=$NODE_RANK, master=$MASTER_ADDR:$MASTER_PORT" +else + DISTRIBUTED_ARGS="--standalone --nproc_per_node $NUM_GPUS" + echo "[INFO] Single-node training: $NUM_GPUS GPUs" +fi + +# ========================================================================== +# NCCL multi-node communication (for H20 + RoCE 400Gbps); harmless on single node +# ========================================================================== +if [ "$NNODES" -gt 1 ]; then + export NCCL_IB_DISABLE=0 + export NCCL_SOCKET_IFNAME=bond1 + export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7,mlx5_bond_8 + export NCCL_IB_GID_INDEX=3 + export NCCL_IB_TIMEOUT=23 + export NCCL_IB_RETRY_CNT=7 + export NCCL_NET_GDR_LEVEL=2 + export 
NCCL_IB_QPS_PER_CONNECTION=4 + export NCCL_CROSS_NIC=1 + export NCCL_ALGO=Ring + export NCCL_PROTO=Simple + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} + export CUDA_DEVICE_MAX_CONNECTIONS=1 + export NCCL_TIMEOUT=1800 +fi + +# Optional eval-data argument (only added if path provided) +EVAL_FLAGS="" +if [ -n "$EVAL_HIDDEN_PATH" ]; then + EVAL_FLAGS="--eval_hidden_path $EVAL_HIDDEN_PATH" +fi + +echo "[INFO] Draft config: $DRAFT_CONFIG_PATH" +echo "[INFO] Target model: $TARGET_MODEL_PATH" +echo "[INFO] Train hidden path: $TRAIN_HIDDEN_PATH" +echo "[INFO] Eval hidden path: ${EVAL_HIDDEN_PATH:-}" +echo "[INFO] Output dir: $OUTPUT_DIR" +echo "[INFO] Attention backend (draft): $ATTENTION_BACKEND" +echo "[INFO] WandB project: $WANDB_PROJECT, run: $WANDB_RUN_NAME" +echo "" -torchrun \ - --standalone \ - --nproc_per_node $NUM_GPUS \ +# ========================================================================== +# Launch training +# ========================================================================== +torchrun $DISTRIBUTED_ARGS \ $ROOT_DIR/tools/train_dflash_offline.py \ --target_model_name_or_path $TARGET_MODEL_PATH \ - --draft_model_config_path $ROOT_DIR/configs/qwen3_dflash.json \ + --draft_model_config_path $DRAFT_CONFIG_PATH \ --train_hidden_path $TRAIN_HIDDEN_PATH \ + $EVAL_FLAGS \ --output_dir $OUTPUT_DIR \ --num_train_epochs 12 \ --per_device_train_batch_size 2 \ @@ -41,14 +137,18 @@ torchrun \ --max_grad_norm 1.0 \ --model_max_length 3072 \ --chat_template_type qwen3 \ - --attention_backend flex_attention \ + --attention_backend $ATTENTION_BACKEND \ --block_size 16 \ --num_anchors 512 \ - --loss_decay_gamma 7.0 \ + --loss_decay_gamma 7 \ --logging_steps 50 \ --save_strategy steps \ - --save_steps 2500 \ + --save_steps 5000 \ --bf16 \ --lr_scheduler_type cosine \ + --dataloader_drop_last \ + --fsdp "shard_grad_op auto_wrap" \ + --fsdp_config ${ROOT_DIR}/configs/fsdp_config.json \ --report_to wandb \ - --run_name $WANDB_RUN_NAME + --wandb_project $WANDB_PROJECT \ + 
--wandb_run_name $WANDB_RUN_NAME diff --git a/scripts/speculative/run_dflash_online.sh b/scripts/speculative/run_dflash_online.sh index 1d4a1397..62d3f4ae 100644 --- a/scripts/speculative/run_dflash_online.sh +++ b/scripts/speculative/run_dflash_online.sh @@ -1,34 +1,128 @@ #!/bin/bash -# DFlash Online Training Script for Qwen3 -# Usage: bash scripts/speculative/run_qwen3_dflash_online.sh [NUM_GPUS] [ATTENTION_BACKEND] +# ========================================================================== +# AngelSlim DFlash Online Training — Fully Aligned Configuration +# ========================================================================== +# +# Recommended training entry for DFlash. Enables all AngelSlim DFlash +# alignment features: +# +# - loss_decay_gamma: 7 (fixed by default; pass --gamma_warmup to enable +# per-epoch increment via --gamma_warmup_step). +# - block_size: 16, num_anchors: 512. +# - batch_size: 2, lr: 6e-4, cosine schedule, warmup_ratio: 0.04. +# - max_length: 3072, num_epochs: 6. +# - num_proc: 64 for data preprocessing. +# - Target model uses flash_attention_2 (matches the inference kernel and +# avoids train/test attention-backend mismatch). +# - dataloader_drop_last=True (avoids FSDP shape mismatches on the +# trailing batch). +# - FP32 master weights optimizer (fp32 accumulation + fp32 grad clip; +# only the final copy-back introduces bf16 quantization). +# - FSDP shard_grad_op + auto_wrap (with configs/fsdp_config.json: +# NO_WRAP, use_orig_params=True). 
+# +# Usage: +# bash scripts/speculative/run_dflash_online.sh [NUM_GPUS] [ATTENTION_BACKEND] +# +# ========================================================================== + +set -euo pipefail SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) ROOT_DIR=$(dirname $(dirname $SCRIPT_DIR)) # Use local source code instead of installed site-packages -export PYTHONPATH=$ROOT_DIR:$PYTHONPATH +export PYTHONPATH=$ROOT_DIR:${PYTHONPATH:-} NUM_GPUS=${1:-8} ATTENTION_BACKEND=${2:-flex_attention} -# Set paths - modify these to match your environment -TARGET_MODEL_PATH="" -TRAIN_DATA_PATH="" -OUTPUT_DIR="${ROOT_DIR}/outputs/" +# ========================================================================== +# Paths — set these before running (left empty by default for portability) +# ========================================================================== +TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-""} +DRAFT_CONFIG_PATH=${DRAFT_CONFIG_PATH:-"${ROOT_DIR}/configs/qwen3_dflash.json"} +TRAIN_DATA_PATH=${TRAIN_DATA_PATH:-""} +OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/qwen3-4b-dflash-online"} + +if [ -z "$TARGET_MODEL_PATH" ]; then + echo "[ERROR] TARGET_MODEL_PATH is empty. Export it to your local Qwen3 (or other) HF model dir." + exit 1 +fi +if [ -z "$TRAIN_DATA_PATH" ]; then + echo "[ERROR] TRAIN_DATA_PATH is empty. Export it to a JSON/JSONL conversation dataset file." 
+ exit 1 +fi -export CONFIG_DIR=${ROOT_DIR}/angelslim/compressor/speculative/train/configs +# ========================================================================== +# torch.compile / inductor kernel cache +# ========================================================================== +export TORCHINDUCTOR_CACHE_DIR=${TORCHINDUCTOR_CACHE_DIR:-${ROOT_DIR}/cache/compiled_kernels} -# WandB configuration (mirrors SpecForge's --wandb-project / --wandb-name) +# ========================================================================== +# Data preprocessing parallelism +# ========================================================================== +DATA_NUM_PROC=${DATA_NUM_PROC:-64} + +# ========================================================================== +# WandB configuration +# ========================================================================== export WANDB_PROJECT=${WANDB_PROJECT:-"angelslim-qwen3-4b-dflash"} -WANDB_RUN_NAME=${WANDB_RUN_NAME:-"qwen3-4b-dflash"} +WANDB_RUN_NAME=${WANDB_RUN_NAME:-"angelslim-qwen3-4b-dflash-online-fp32master"} + +# ========================================================================== +# Multi-node configuration (optional) +# ========================================================================== +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-12347} -torchrun \ - --standalone \ - --nproc_per_node $NUM_GPUS \ +if [ "$NNODES" -gt 1 ]; then + DISTRIBUTED_ARGS="--nproc_per_node $NUM_GPUS --nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT" + echo "[INFO] Multi-node training: nnodes=$NNODES, node_rank=$NODE_RANK, master=$MASTER_ADDR:$MASTER_PORT" +else + DISTRIBUTED_ARGS="--standalone --nproc_per_node $NUM_GPUS" + echo "[INFO] Single-node training: $NUM_GPUS GPUs" +fi + +# ========================================================================== +# NCCL multi-node communication (for H20 + RoCE 400Gbps); harmless 
on single node +# ========================================================================== +if [ "$NNODES" -gt 1 ]; then + export NCCL_IB_DISABLE=0 + export NCCL_SOCKET_IFNAME=bond1 + export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7,mlx5_bond_8 + export NCCL_IB_GID_INDEX=3 + export NCCL_IB_TIMEOUT=23 + export NCCL_IB_RETRY_CNT=7 + export NCCL_NET_GDR_LEVEL=2 + export NCCL_IB_QPS_PER_CONNECTION=4 + export NCCL_CROSS_NIC=1 + export NCCL_ALGO=Ring + export NCCL_PROTO=Simple + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} + export CUDA_DEVICE_MAX_CONNECTIONS=1 + export NCCL_TIMEOUT=1800 +fi + +echo "[INFO] Draft config: $DRAFT_CONFIG_PATH" +echo "[INFO] Target model: $TARGET_MODEL_PATH" +echo "[INFO] Train data: $TRAIN_DATA_PATH" +echo "[INFO] Output dir: $OUTPUT_DIR" +echo "[INFO] Attention backend (draft): $ATTENTION_BACKEND" +echo "[INFO] Target model attn: flash_attention_2 (set in target_model_wrapper.py)" +echo "[INFO] WandB project: $WANDB_PROJECT, run: $WANDB_RUN_NAME" +echo "" + +# ========================================================================== +# Launch training +# ========================================================================== +torchrun $DISTRIBUTED_ARGS \ $ROOT_DIR/tools/train_dflash_online.py \ --target_model_name_or_path $TARGET_MODEL_PATH \ - --draft_model_config_path $ROOT_DIR/configs/qwen3_dflash.json \ + --draft_model_config_path $DRAFT_CONFIG_PATH \ --train_data_path $TRAIN_DATA_PATH \ --output_dir $OUTPUT_DIR \ --modal_type DFlash \ @@ -43,12 +137,16 @@ torchrun \ --attention_backend $ATTENTION_BACKEND \ --block_size 16 \ --num_anchors 512 \ - --loss_decay_gamma 7.0 \ + --loss_decay_gamma 7 \ + --num_proc $DATA_NUM_PROC \ --logging_steps 50 \ --save_strategy steps \ - --save_steps 2500 \ + --save_steps 5000 \ --bf16 \ --lr_scheduler_type cosine \ + --dataloader_drop_last \ + --fsdp "shard_grad_op auto_wrap" \ + --fsdp_config ${ROOT_DIR}/configs/fsdp_config.json \ --report_to wandb 
\ - --run_name $WANDB_RUN_NAME - + --wandb_project $WANDB_PROJECT \ + --wandb_run_name $WANDB_RUN_NAME diff --git a/tools/dflash_benchmark.py b/tools/dflash_benchmark.py new file mode 100644 index 00000000..376311d1 --- /dev/null +++ b/tools/dflash_benchmark.py @@ -0,0 +1,526 @@ +"""DFlash / DFlare end-to-end speculative decoding benchmark. + +A self-contained evaluation entry point for AngelSlim's draft model classes. +Selects the draft architecture via ``--draft-arch``: + + --draft-arch dflash -> angelslim.compressor.speculative.train.models.draft + .qwen_dflash.QwenDFlashDraftModel + --draft-arch dflare -> angelslim.compressor.speculative.train.models.draft + .qwen_dflare.QwenDFlareDraftModel + +Reports decoding speedup vs single-token decoding and per-block acceptance +length distribution. Supports torchrun for multi-GPU sharded evaluation. + +Usage (single GPU):: + + python tools/dflash_benchmark.py \\ + --model-name-or-path /path/to/Qwen3-4B \\ + --draft-name-or-path /path/to/dflash_or_dflare_ckpt \\ + --draft-arch dflare \\ + --dataset gsm8k --max-samples 128 + +Usage (8 GPUs):: + + torchrun --nproc_per_node=8 --master_port=29600 \\ + tools/dflash_benchmark.py \\ + --model-name-or-path /path/to/Qwen3-4B \\ + --draft-name-or-path /path/to/dflare_ckpt \\ + --draft-arch dflare \\ + --dataset gsm8k --max-samples 128 +""" + +from __future__ import annotations + +import argparse +import os +import random +import time +import warnings +from itertools import chain +from types import SimpleNamespace +from typing import Any, List, Optional + +import numpy as np +import torch +from datasets import Features, Sequence, Value, load_dataset +from loguru import logger +from rich import print +from torch import distributed as torch_dist +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache + + +# --------------------------------------------------------------------------- +# Distributed helpers (small wrapper over 
# torch.distributed; no extra package
# dependency on AngelSlim's side).
# ---------------------------------------------------------------------------
def _dist_init() -> None:
    """Initialise the NCCL process group when launched under torchrun.

    When the ``RANK`` environment variable is absent the function only emits
    a warning and returns, leaving the script in single-process mode.
    """
    if "RANK" not in os.environ:
        warnings.warn(
            "Environment variable `RANK` is not set; running single-process.",
            stacklevel=2,
        )
        return
    torch_dist.init_process_group(backend="nccl", init_method="env://")


def _dist_is_initialized() -> bool:
    """Return True when a torch.distributed process group is initialised."""
    return torch_dist.is_initialized()


def _dist_size() -> int:
    """Return ``WORLD_SIZE`` from the environment (1 when unset)."""
    return int(os.environ.get("WORLD_SIZE", 1))


def _dist_rank() -> int:
    """Return ``RANK`` from the environment (0 when unset)."""
    return int(os.environ.get("RANK", 0))


def _dist_local_rank() -> int:
    """Return ``LOCAL_RANK`` from the environment (0 when unset)."""
    return int(os.environ.get("LOCAL_RANK", 0))


def _dist_is_main() -> bool:
    """Return True on global rank 0."""
    return _dist_rank() == 0


def _dist_gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather one picklable object per rank onto rank ``dst``.

    Args:
        obj: The object contributed by the calling rank.
        dst: Global rank that receives the gathered list.

    Returns:
        ``[obj]`` when no process group is initialised; on rank ``dst`` a list
        with one entry per rank; ``None`` on every other rank.
    """
    if not _dist_is_initialized():
        return [obj]
    # Only the destination rank may supply the output list. Comparing against
    # ``dst`` (rather than "is rank 0") keeps the helper correct when the
    # caller gathers onto a non-zero rank.
    if _dist_rank() == dst:
        objs: List[Any] = [None for _ in range(_dist_size())]
        torch_dist.gather_object(obj, objs, dst=dst)
        return objs
    torch_dist.gather_object(obj, dst=dst)
    return None


# ---------------------------------------------------------------------------
# Dataset loader. Each loaded item must expose a ``turns`` field that is a
# list of user messages (one entry per turn for multi-turn datasets like
# mt-bench).
# ---------------------------------------------------------------------------
def load_and_process_dataset(data_name: str):
    """Load a benchmark dataset and add a ``turns`` column of user prompts.

    Supported names: ``gsm8k``, ``math500``, ``aime24``, ``aime25``,
    ``alpaca``, ``mt-bench``, ``humaneval``, ``mbpp``, ``lbpp``,
    ``swe-bench`` and ``livecodebench``.

    Args:
        data_name: One of the dataset names listed above.

    Returns:
        The mapped dataset; every row carries ``turns``, a list of user
        messages (a single entry for all datasets except ``mt-bench``, which
        forwards the dataset's own ``prompt`` list).

    Raises:
        ValueError: If ``data_name`` is not a supported dataset.
    """
    if data_name == "gsm8k":
        ds = load_dataset("openai/gsm8k", "main", split="test")
        fmt = (
            "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        )
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "math500":
        ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
        fmt = (
            "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        )
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "aime24":
        ds = load_dataset("HuggingFaceH4/aime_2024", split="train")
        fmt = (
            "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        )
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "aime25":
        ds = load_dataset("MathArena/aime_2025", split="train")
        fmt = (
            "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
        )
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "alpaca":
        ds = load_dataset("tatsu-lab/alpaca", split="train")
        # Two passes on purpose: the intermediate ``formatted_input`` column is
        # kept in the returned dataset, exactly as before.
        ds = ds.map(
            lambda x: {
                "formatted_input": (
                    f"{x['instruction']}\n\nInput:\n{x['input']}"
                    if x["input"]
                    else x["instruction"]
                )
            }
        )
        return ds.map(lambda x: {"turns": [x["formatted_input"]]})

    if data_name == "mt-bench":
        ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        return ds.map(lambda x: {"turns": x["prompt"]})

    if data_name == "humaneval":
        ds = load_dataset("openai/openai_humaneval", split="test")
        fmt = (
            "Write a solution to the following problem and make sure that it passes the tests:\n"
            "```python\n{prompt}\n```"
        )
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "mbpp":
        ds = load_dataset("google-research-datasets/mbpp", "sanitized", split="test")
        return ds.map(lambda x: {"turns": [x["prompt"]]})

    if data_name == "lbpp":
        url = "https://huggingface.co/datasets/CohereLabs/lbpp/resolve/main/python/test.parquet"
        ds = load_dataset("parquet", data_files={"test": url})["test"]
        return ds.map(lambda x: {"turns": [x["instruction"]]})

    if data_name == "swe-bench":
        ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
        fmt = "Problem Statement:\n{problem_statement}\nPlease fix the issue described above."
        return ds.map(lambda x: {"turns": [fmt.format(**x)]})

    if data_name == "livecodebench":
        base = "https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/main/"
        files = [
            "test.jsonl",
            "test2.jsonl",
            "test3.jsonl",
            "test4.jsonl",
            "test5.jsonl",
            "test6.jsonl",
        ]
        ds = load_dataset("json", data_files={"test": [base + fn for fn in files]})["test"]

        def _fmt(doc):
            """Build the single-turn prompt for one livecodebench problem."""
            # Named ``system_prompt`` (not ``sys``) to avoid shadowing the
            # standard-library module name.
            system_prompt = (
                "You are an expert Python programmer. You will be given a question "
                "(problem specification) and will generate a correct Python program "
                "that matches the specification and passes all tests. "
                "You will NOT return anything except for the program"
            )
            q = f"### Question:\n{doc['question_content']}"
            if doc.get("starter_code"):
                fmt_msg = "### Format: Use the following code structure:"
                code = f"```python\n{doc['starter_code']}\n```"
            else:
                fmt_msg = "### Format: Write your code in the following format:"
                code = "```python\n# YOUR CODE HERE\n```"
            tail = "### Answer: (use the provided format with backticks)"
            return f"{system_prompt}\n\n{q}\n\n{fmt_msg}\n{code}\n\n{tail}"

        # All original columns are dropped; only ``turns`` is kept, with an
        # explicit feature schema.
        target_features = Features({"turns": Sequence(Value("large_string"))})
        return ds.map(
            lambda x: {"turns": [_fmt(x)]},
            remove_columns=ds.column_names,
            features=target_features,
        )

    raise ValueError(f"Unknown dataset: {data_name}")


# ---------------------------------------------------------------------------
# Draft architecture dispatch.
# ---------------------------------------------------------------------------
def _resolve_draft_arch(arch: str):
    """Return (DraftModelClass, sample_fn, extract_context_feature_fn).

    Args:
        arch: ``"dflash"`` or ``"dflare"`` (matched case-insensitively).

    Raises:
        ValueError: If ``arch`` is neither of the two supported names.
    """
    arch = arch.lower()
    # Imports are deferred so only the selected architecture's module is loaded.
    if arch == "dflash":
        from angelslim.compressor.speculative.train.models.draft.qwen_dflash import (
            QwenDFlashDraftModel,
            extract_context_feature,
            sample,
        )

        return QwenDFlashDraftModel, sample, extract_context_feature
    if arch == "dflare":
        from angelslim.compressor.speculative.train.models.draft.qwen_dflare import (
            QwenDFlareDraftModel,
            extract_context_feature,
            sample,
        )

        return QwenDFlareDraftModel, sample, extract_context_feature
    raise ValueError(f"--draft-arch must be one of {{dflash, dflare}}, got: {arch}")


# ---------------------------------------------------------------------------
# Speculative-decoding loop: block-parallel draft proposal, target
# verification, longest-prefix accept.
# ---------------------------------------------------------------------------
def cuda_time() -> float:
    """Return ``time.perf_counter()`` after synchronising CUDA."""
    torch.cuda.synchronize()
    return time.perf_counter()


@torch.inference_mode()
def dflash_generate(
    model,
    target,
    input_ids: torch.Tensor,
    mask_token_id: int,
    max_new_tokens: int,
    block_size: int,
    stop_token_ids: list,
    sample_fn,
    extract_context_feature_fn,
    temperature: float = 0.0,
) -> SimpleNamespace:
    """Generate with block-wise speculative decoding and time the run.

    With ``block_size == 1`` the draft model is never called and the loop
    degenerates to one target forward per token (the baseline measurement).

    Args:
        model: Draft model. Must expose ``device`` and, when
            ``block_size > 1``, ``target_layer_ids``; it is called with
            ``target_hidden``, ``noise_embedding``, ``position_ids``,
            ``past_key_values``, ``use_cache`` and ``is_causal``.
        target: Target causal LM exposing ``model.embed_tokens`` and
            ``lm_head``; its output must carry ``logits`` and
            ``hidden_states``.
        input_ids: Prompt token ids; indexed as ``(1, num_input_tokens)``.
        mask_token_id: Id used to pre-fill not-yet-generated positions.
        max_new_tokens: Generation budget beyond the prompt.
        block_size: Number of positions verified per target forward.
        stop_token_ids: Ids that end generation, or ``None`` to disable.
        sample_fn: ``sample_fn(logits, temperature)`` -> token ids. The draft
            proposal calls it without a temperature.
        extract_context_feature_fn: Maps ``(hidden_states, target_layer_ids)``
            to the features fed to the draft model.
        temperature: Sampling temperature for target prefill / verification.

    Returns:
        SimpleNamespace with ``output_ids``, ``num_input_tokens``,
        ``num_output_tokens``, ``time_to_first_token``,
        ``time_per_output_token`` and ``acceptance_lengths`` (tokens committed
        per verification step: matched draft tokens + 1).
    """
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    use_draft = block_size > 1

    # Over-allocate by one block so the last block slice never runs past the
    # buffer; unused positions keep ``mask_token_id``.
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=model.device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=model.device).unsqueeze(0)
    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill stage
    prefill_start = cuda_time()
    output = target(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=use_draft,
    )

    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens : num_input_tokens + 1] = sample_fn(output.logits, temperature)
    if use_draft:
        target_hidden = extract_context_feature_fn(output.hidden_states, model.target_layer_ids)

    time_to_first_token = cuda_time() - prefill_start

    # Decode stage
    decode_start = cuda_time()
    start = input_ids.shape[1]
    acceptance_lengths = []
    draft_prefill = True

    while start < max_length:
        block_output_ids = output_ids[:, start : start + block_size].clone()
        block_position_ids = position_ids[:, start : start + block_size]
        if use_draft:
            noise_embedding = target.model.embed_tokens(block_output_ids)
            # The draft fills positions 1..block_size-1 of the block; position
            # 0 already holds a token committed by the target.
            draft_logits = target.lm_head(
                model(
                    target_hidden=target_hidden,
                    noise_embedding=noise_embedding,
                    position_ids=position_ids[
                        :, past_key_values_draft.get_seq_length() : start + block_size
                    ],
                    past_key_values=past_key_values_draft,
                    use_cache=True,
                    is_causal=False,
                )[:, -block_size + 1 :, :]
            )
            past_key_values_draft.crop(start)
            block_output_ids[:, 1:] = sample_fn(draft_logits)
            # NOTE(review): restarting the decode clock after the first draft
            # call presumably excludes the draft's one-off prefill from the
            # per-token latency; placement inside this branch was reconstructed
            # from flattened source -- confirm against the upstream file.
            if draft_prefill:
                draft_prefill = False
                decode_start = cuda_time()

        output = target(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=use_draft,
        )

        posterior = sample_fn(output.logits, temperature)
        # Longest prefix on which the draft agrees with the target: cumprod
        # zeroes everything after the first mismatch.
        acceptance_length = (
            (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
        )
        output_ids[:, start : start + acceptance_length + 1] = block_output_ids[
            :, : acceptance_length + 1
        ]
        # The target's own token at the first unverified position is committed too.
        output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]

        acceptance_lengths.append(acceptance_length + 1)
        start += acceptance_length + 1
        past_key_values_target.crop(start)
        if use_draft:
            target_hidden = extract_context_feature_fn(
                output.hidden_states, model.target_layer_ids
            )[:, : acceptance_length + 1, :]

        if stop_token_ids is not None and any(
            stop_token_id in output_ids[:, num_input_tokens:] for stop_token_id in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_tensor = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_indices = torch.isin(output_ids[0][num_input_tokens:], stop_tensor).nonzero(
            as_tuple=True
        )[0]
        if stop_indices.numel() > 0:
            # Truncate right after the first stop token.
            output_ids = output_ids[:, : num_input_tokens + stop_indices[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    # Guard the division: nothing is generated when ``max_new_tokens`` is 0,
    # which previously raised ZeroDivisionError.
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )


# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
def _acceptance_histogram(acceptance_lengths, block_size: int) -> list:
    """Return the fraction of steps per acceptance length ``0..block_size``.

    An empty ``acceptance_lengths`` yields all zeros instead of dividing by
    zero.
    """
    total = len(acceptance_lengths)
    if total == 0:
        return [0.0] * (block_size + 1)
    # Single counting pass instead of one ``list.count`` scan per bucket.
    counts = np.bincount(
        np.asarray(acceptance_lengths, dtype=np.int64), minlength=block_size + 1
    )[: block_size + 1]
    return (counts / total).tolist()


def _parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name-or-path", type=str, required=True, help="Path or HF id of the target model."
    )
    parser.add_argument(
        "--draft-name-or-path",
        type=str,
        required=True,
        help="Path of the trained DFlash/DFlare draft checkpoint.",
    )
    parser.add_argument(
        "--draft-arch",
        type=str,
        choices=["dflash", "dflare"],
        required=True,
        help="Which AngelSlim draft architecture to load.",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=None,
        help="Speculative block size. Defaults to draft model's config value.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name; see load_and_process_dataset() for the supported list.",
    )
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=16384)
    parser.add_argument("--temperature", type=float, default=0.0)
    return parser.parse_args()


def _run_benchmark(args: argparse.Namespace) -> None:
    """Load models and data, run both decoding modes, and print the report.

    Every prompt is decoded twice -- ``block_size=1`` (baseline) and the
    speculative block size -- and the ratio of mean per-token latencies is
    reported as the speedup. Only rank 0 prints.
    """
    torch.cuda.set_device(_dist_local_rank())
    device = torch.device(f"cuda:{_dist_local_rank()}")

    DraftModelCls, sample_fn, extract_context_feature_fn = _resolve_draft_arch(args.draft_arch)

    def has_flash_attn() -> bool:
        try:
            import flash_attn  # noqa: F401

            return True
        except ImportError:
            logger.warning(
                "flash_attn is not installed; falling back to torch.sdpa. "
                "End-to-end speedup will be lower."
            )
            return False

    installed_flash_attn = has_flash_attn()
    attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"

    target = (
        AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            attn_implementation=attn_impl,
            dtype=torch.bfloat16,
        )
        .to(device)
        .eval()
    )

    draft_model = (
        DraftModelCls.from_pretrained(
            args.draft_name_or_path,
            attn_implementation=attn_impl,
            dtype=torch.bfloat16,
            local_files_only=True,
        )
        .to(device)
        .eval()
    )

    block_size = args.block_size if args.block_size is not None else draft_model.block_size

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = load_and_process_dataset(args.dataset)

    if args.max_samples is not None and len(dataset) > args.max_samples:
        dataset = dataset.shuffle(seed=0).select(range(args.max_samples))

    responses = []
    # Round-robin sharding of samples across ranks.
    indices = range(_dist_rank(), len(dataset), _dist_size())
    for idx in tqdm(indices, disable=not _dist_is_main()):
        instance = dataset[idx]
        messages = []
        for user_content in instance["turns"]:
            messages.append({"role": "user", "content": user_content})
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(target.device)

            response = {}
            for bs in [1, block_size]:
                response[bs] = dflash_generate(
                    model=draft_model,
                    target=target,
                    input_ids=input_ids,
                    mask_token_id=draft_model.mask_token_id,
                    max_new_tokens=args.max_new_tokens,
                    block_size=bs,
                    stop_token_ids=[tokenizer.eos_token_id],
                    sample_fn=sample_fn,
                    extract_context_feature_fn=extract_context_feature_fn,
                    temperature=args.temperature,
                )

            # The speculative output becomes the assistant turn that the next
            # user turn (multi-turn datasets) is conditioned on.
            spec_response = response[block_size]
            generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens :]
            output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            messages.append({"role": "assistant", "content": output_text})
            responses.append(response)

    if _dist_size() > 1:
        gathered = _dist_gather(responses, dst=0)
        if not _dist_is_main():
            return
        responses = list(chain(*gathered))

    if not responses:
        return

    t1 = np.mean([r[1].time_per_output_token for r in responses])
    tb = np.mean([r[block_size].time_per_output_token for r in responses])
    print(f"[draft_arch={args.draft_arch}] Decoding speedup: {t1 / tb:.2f}")

    tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
    print(f"[draft_arch={args.draft_arch}] Average Acceptance length: {tau:.2f}")

    acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
    histogram = _acceptance_histogram(acceptance_lengths, block_size)
    print(
        f"[draft_arch={args.draft_arch}] Acceptance length histogram: "
        f"{[f'{x * 100:.1f}%' for x in histogram]}"
    )


def main() -> None:
    """Parse arguments, seed all RNGs, run the benchmark, and clean up."""
    args = _parse_args()

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    _dist_init()
    _run_benchmark(args)
    # Release the process group on the normal exit path (all ranks, including
    # those that returned early after the gather). Deliberately not in a
    # ``finally``: after a failure on one rank, teardown could block on peers.
    if torch_dist.is_initialized():
        torch_dist.destroy_process_group()


if __name__ == "__main__":
    main()
+ ), + ) # DFlash-specific (override values in config JSON) d = parser.add_argument_group("DFlash Arguments") @@ -160,6 +171,44 @@ def parse_args(): t.add_argument("--fp16", action="store_true") t.add_argument("--bf16", action="store_true") t.add_argument("--deepspeed", type=str, default=None) + t.add_argument( + "--fsdp", + type=str, + default="", + help="FSDP configuration string passed to TrainingArguments " + "(e.g. 'shard_grad_op auto_wrap'). Empty disables FSDP.", + ) + t.add_argument( + "--fsdp_config", + type=str, + default=None, + help="Path to FSDP config JSON file (consumed by TrainingArguments).", + ) + t.add_argument( + "--dataloader_drop_last", + action="store_true", + default=False, + help=( + "Drop last incomplete batch. Note: when using DFlash trainer this " + "is forced True internally to avoid FSDP shape mismatches on the " + "trailing batch." + ), + ) + t.add_argument( + "--gamma_warmup", + action="store_true", + default=False, + help=( + "Enable gamma warmup. When set, loss_decay_gamma is increased " + "per epoch as: gamma = loss_decay_gamma + gamma_warmup_step * epoch." + ), + ) + t.add_argument( + "--gamma_warmup_step", + type=float, + default=0.5, + help="Per-epoch increment for gamma warmup. Default 0.5.", + ) t.add_argument("--report_to", type=str, default="none") t.add_argument("--run_name", type=str, default=None) t.add_argument("--training_time_test_length", type=int, default=7) @@ -216,6 +265,22 @@ def train(): draft_model_config.embed_weight_key = args.embed_weight_key draft_model_config.trust_remote_code = args.trust_remote_code + # Optionally override draft architecture from CLI. Both DFlash and DFlare + # share the same Qwen3Config schema (block_size, dflash_config, etc.), so + # swapping the architectures field is sufficient to route create_draft_model + # to the desired class. 
+ if args.draft_arch is not None: + arch_map = { + "dflash": "QwenDFlashDraftModel", + "dflare": "QwenDFlareDraftModel", + } + new_arch = arch_map[args.draft_arch] + rank0_print( + f"Overriding draft architecture: " + f"{getattr(draft_model_config, 'architectures', None)} -> [{new_arch}]" + ) + draft_model_config.architectures = [new_arch] + # Override DFlash params from CLI if specified if args.block_size is not None: draft_model_config.block_size = args.block_size @@ -223,6 +288,10 @@ def train(): draft_model_config.num_anchors = args.num_anchors if args.loss_decay_gamma is not None: draft_model_config.loss_decay_gamma = args.loss_decay_gamma + # Always propagate gamma_warmup flags to the draft model config so the + # trainer can pick them up regardless of CLI defaults. + draft_model_config.gamma_warmup = args.gamma_warmup + draft_model_config.gamma_warmup_step = args.gamma_warmup_step if args.attention_backend is not None: draft_model_config.attention_backend = args.attention_backend draft_model_config._attn_implementation = args.attention_backend @@ -264,7 +333,7 @@ def train(): # ------------------------------------------------------------------ # 4. TrainingArguments # ------------------------------------------------------------------ - training_args = transformers.TrainingArguments( + ta_kwargs = dict( output_dir=args.output_dir, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, @@ -288,8 +357,16 @@ def train(): report_to=args.report_to, run_name=args.run_name, deepspeed=args.deepspeed, + fsdp=args.fsdp, + # Force drop_last=True (AngelSlim default) to avoid FSDP shape + # mismatches on the trailing batch. + dataloader_drop_last=True, remove_unused_columns=False, ) + if args.fsdp_config: + ta_kwargs["fsdp_config"] = args.fsdp_config + + training_args = transformers.TrainingArguments(**ta_kwargs) # ------------------------------------------------------------------ # 5. 
Trainer -- use Eagle3TrainerFactory diff --git a/tools/train_dflash_online.py b/tools/train_dflash_online.py index 5a33a16d..1d2648a8 100755 --- a/tools/train_dflash_online.py +++ b/tools/train_dflash_online.py @@ -93,6 +93,17 @@ def parse_args(): default="model.embed_tokens.weight", help="Key for embedding weights in model config", ) + model_group.add_argument( + "--draft_arch", + type=str, + default=None, + choices=["dflash", "dflare"], + help=( + "Override draft model architecture. If unset, uses the " + "'architectures' field from the draft_model_config JSON. " + "'dflash' -> QwenDFlashDraftModel, 'dflare' -> QwenDFlareDraftModel." + ), + ) # DFlash-specific arguments dflash_group = parser.add_argument_group("DFlash Arguments") @@ -266,6 +277,45 @@ def parse_args(): ) training_group.add_argument("--fp16", action="store_true", help="Whether to use fp16 training") training_group.add_argument("--bf16", action="store_true", help="Whether to use bf16 training") + training_group.add_argument( + "--fsdp", + type=str, + default="", + help="FSDP configuration string passed to TrainingArguments " + "(e.g. 'shard_grad_op auto_wrap'). Empty disables FSDP.", + ) + training_group.add_argument( + "--fsdp_config", + type=str, + default=None, + help="Path to FSDP config JSON file (consumed by TrainingArguments).", + ) + training_group.add_argument( + "--dataloader_drop_last", + action="store_true", + default=False, + help=( + "Drop last incomplete batch. Note: when using DFlash trainer this " + "is forced True internally to match AngelSlim's drop_last=True " + "and avoid FSDP shape mismatches on the trailing batch." + ), + ) + training_group.add_argument( + "--gamma_warmup", + action="store_true", + default=False, + help=( + "Enable gamma warmup. When set, loss_decay_gamma is increased per " + "epoch as: gamma = loss_decay_gamma + gamma_warmup_step * epoch " + "(AngelSlim gamma warmup formula)." 
+ ), + ) + training_group.add_argument( + "--gamma_warmup_step", + type=float, + default=0.5, + help="Per-epoch increment for gamma warmup. Default 0.5.", + ) training_group.add_argument( "--save_strategy", type=str, default="no", help="Save strategy for checkpoints" ) @@ -286,7 +336,7 @@ def parse_args(): help="The list of integrations to report the results and logs to (e.g. 'wandb')", ) - # WandB arguments (mirrors SpecForge's --wandb-project / --wandb-name) + # WandB arguments wandb_group = parser.add_argument_group("WandB Arguments") wandb_group.add_argument( "--wandb_project", @@ -307,7 +357,7 @@ def parse_args(): def _setup_wandb(args) -> None: """Set up WandB environment variables and initialize wandb run on rank 0. - Mirrors the --wandb-project / --wandb-name pattern from SpecForge. + Sets up WandB project / run name from CLI or env vars. Priority: CLI args > env vars > defaults. """ if args.report_to not in ("wandb", "all"): @@ -355,6 +405,22 @@ def train(): draft_model_config = DraftModelConfig.from_file(args.draft_model_config_path) target_model_type = getattr(draft_model_config, "target_model_type", None) + # Optionally override draft architecture from CLI. Both DFlash and DFlare + # share the same Qwen3Config schema (block_size, dflash_config, etc.), so + # swapping the architectures field is sufficient to route create_draft_model + # to the desired class via DraftModelFactory._get_model_class. 
+ if args.draft_arch is not None: + arch_map = { + "dflash": "QwenDFlashDraftModel", + "dflare": "QwenDFlareDraftModel", + } + new_arch = arch_map[args.draft_arch] + rank0_print( + f"Overriding draft architecture: " + f"{getattr(draft_model_config, 'architectures', None)} -> [{new_arch}]" + ) + draft_model_config.architectures = [new_arch] + # Inject DFlash-specific config into the draft model config # so the trainer can access them draft_model_config.target_model_name_or_path = args.target_model_name_or_path @@ -368,6 +434,10 @@ def train(): draft_model_config.num_anchors = args.num_anchors if args.loss_decay_gamma is not None: draft_model_config.loss_decay_gamma = args.loss_decay_gamma + # Always propagate gamma_warmup flags to the draft model config so the + # trainer can pick them up regardless of CLI defaults. + draft_model_config.gamma_warmup = args.gamma_warmup + draft_model_config.gamma_warmup_step = args.gamma_warmup_step if args.attention_backend is not None: draft_model_config.attention_backend = args.attention_backend if args.mask_token_id is not None: @@ -442,6 +512,10 @@ def train(): "per_device_eval_batch_size": args.per_device_eval_batch_size, "gradient_accumulation_steps": args.gradient_accumulation_steps, "remove_unused_columns": False, + # Force drop_last=True (AngelSlim default) to avoid FSDP shape + # mismatches on the trailing batch. CLI --dataloader_drop_last is + # accepted for compatibility but currently overridden here. + "dataloader_drop_last": True, } optimizer_args = { @@ -475,7 +549,10 @@ def train(): distributed_args = { "deepspeed": args.deepspeed, + "fsdp": args.fsdp, } + if args.fsdp_config: + distributed_args["fsdp_config"] = args.fsdp_config training_args = transformers.TrainingArguments( **basic_args,