|
54 | 54 |
|
55 | 55 | import torch |
56 | 56 | import torch.nn.functional as F |
57 | | -from torch import nn |
58 | 57 | from transformers import PreTrainedModel |
59 | | -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
60 | 58 | from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config |
61 | | -from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: N814 |
62 | | -from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: N814 |
63 | | -from transformers.models.qwen3.modeling_qwen3 import ( |
64 | | - Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814 |
65 | | -) |
66 | | -from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half |
67 | 59 | from transformers.trainer_pt_utils import LabelSmoother |
68 | 60 | from transformers.utils import ModelOutput |
69 | 61 |
|
70 | 62 | from ..dflash.conversion import DFlashDMRegistry |
71 | 63 | from ..dflash.dflash_model import DFlashModel |
| 64 | +from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids # noqa: F401 |
72 | 65 | from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS |
73 | 66 |
|
74 | 67 | logger = logging.getLogger(__name__) |
75 | 68 |
|
76 | 69 | __all__ = ["HFDFlashModel"] |
77 | 70 |
|
78 | 71 |
|
79 | | -def build_target_layer_ids(num_target_layers, num_draft_layers): |
80 | | - """Select layers uniformly from the target model for feature extraction.""" |
81 | | - if num_target_layers < num_draft_layers: |
82 | | - raise ValueError( |
83 | | - f"num_target_layers ({num_target_layers}) must be >= num_draft_layers ({num_draft_layers})" |
84 | | - ) |
85 | | - if num_draft_layers == 1: |
86 | | - return [num_target_layers // 2] |
87 | | - start = min(1, num_target_layers - 1) |
88 | | - end = max(start, num_target_layers - 3) |
89 | | - span = end - start |
90 | | - return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] |
91 | | - |
92 | | - |
93 | | -def apply_rotary_pos_emb(q, k, cos, sin): |
94 | | - """Apply RoPE. Q uses last q_len positions, K uses all positions.""" |
95 | | - cos = cos.unsqueeze(1) # [B, 1, seq, dim] |
96 | | - sin = sin.unsqueeze(1) |
97 | | - q_len = q.size(2) |
98 | | - q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) |
99 | | - k_embed = (k * cos) + (_rotate_half(k) * sin) |
100 | | - return q_embed, k_embed |
101 | | - |
102 | | - |
103 | | -class DFlashAttention(nn.Module): |
104 | | - """Attention with KV injection, using HF's attention dispatch.""" |
105 | | - |
106 | | - def __init__(self, config, layer_idx): |
107 | | - """Initialize DFlash attention with KV injection projections and QK-norm.""" |
108 | | - super().__init__() |
109 | | - self.config = config |
110 | | - self.layer_idx = layer_idx |
111 | | - self.head_dim = getattr( |
112 | | - config, "head_dim", config.hidden_size // config.num_attention_heads |
113 | | - ) |
114 | | - self.num_heads = config.num_attention_heads |
115 | | - self.num_kv_heads = config.num_key_value_heads |
116 | | - self.num_key_value_groups = self.num_heads // self.num_kv_heads |
117 | | - self.scaling = self.head_dim**-0.5 |
118 | | - self.attention_dropout = getattr(config, "attention_dropout", 0.0) |
119 | | - self.is_causal = False |
120 | | - |
121 | | - attn_bias = getattr(config, "attention_bias", False) |
122 | | - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) |
123 | | - self.k_proj = nn.Linear( |
124 | | - config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias |
125 | | - ) |
126 | | - self.v_proj = nn.Linear( |
127 | | - config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias |
128 | | - ) |
129 | | - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) |
130 | | - |
131 | | - self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) |
132 | | - self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) |
133 | | - |
134 | | - # Resolve HF attention function |
135 | | - self._attn_fn = None |
136 | | - # Qwen3 uses sliding window attention on some layers (config.layer_types) |
137 | | - if hasattr(config, "layer_types") and hasattr(config, "sliding_window"): |
138 | | - is_sliding = config.layer_types[layer_idx] == "sliding_attention" |
139 | | - self.sliding_window = config.sliding_window if is_sliding else None |
140 | | - else: |
141 | | - self.sliding_window = None |
142 | | - |
143 | | - def _get_attn_fn(self): |
144 | | - """Lazily resolve the HF attention function (default: sdpa).""" |
145 | | - if self._attn_fn is not None: |
146 | | - return self._attn_fn |
147 | | - impl = self.config._attn_implementation # default set in dflash/default_config.py |
148 | | - self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"]) |
149 | | - return self._attn_fn |
150 | | - |
151 | | - def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): |
152 | | - """Forward with KV injection. |
153 | | -
|
154 | | - Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]). |
155 | | - K and V are projected from the concatenation of target hidden states (context from the |
156 | | - base model) and noise block, so the draft can attend to both context and its own block. |
157 | | - """ |
158 | | - bsz, q_len, _ = hidden_states.shape |
159 | | - ctx_len = target_hidden.shape[1] |
160 | | - |
161 | | - # Q from noise block only (the draft tokens being predicted), with QK-norm |
162 | | - q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) |
163 | | - q = self.q_norm(q).transpose(1, 2) |
164 | | - |
165 | | - # K from context + noise, with QK-norm |
166 | | - k_ctx = self.k_proj(target_hidden) |
167 | | - k_noise = self.k_proj(hidden_states) |
168 | | - k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) |
169 | | - k = self.k_norm(k).transpose(1, 2) |
170 | | - |
171 | | - # V from context + noise (no norm) |
172 | | - v_ctx = self.v_proj(target_hidden) |
173 | | - v_noise = self.v_proj(hidden_states) |
174 | | - v = ( |
175 | | - torch.cat([v_ctx, v_noise], dim=1) |
176 | | - .view(bsz, ctx_len + q_len, -1, self.head_dim) |
177 | | - .transpose(1, 2) |
178 | | - ) |
179 | | - |
180 | | - # RoPE |
181 | | - cos, sin = position_embeddings |
182 | | - q, k = apply_rotary_pos_emb(q, k, cos, sin) |
183 | | - |
184 | | - # Use HF's attention dispatch (handles GQA internally) |
185 | | - attn_fn = self._get_attn_fn() |
186 | | - attn_output, _ = attn_fn( |
187 | | - self, |
188 | | - q, |
189 | | - k, |
190 | | - v, |
191 | | - attention_mask, |
192 | | - dropout=0.0 if not self.training else self.attention_dropout, |
193 | | - scaling=self.scaling, |
194 | | - sliding_window=self.sliding_window, |
195 | | - ) |
196 | | - attn_output = attn_output.reshape(bsz, q_len, -1) |
197 | | - return self.o_proj(attn_output) |
198 | | - |
199 | | - |
200 | | -class DFlashDecoderLayer(nn.Module): |
201 | | - """Draft decoder layer with KV injection.""" |
202 | | - |
203 | | - def __init__(self, config, layer_idx): |
204 | | - """Initialize decoder layer with attention, MLP, and layer norms.""" |
205 | | - super().__init__() |
206 | | - self.self_attn = DFlashAttention(config, layer_idx) |
207 | | - self.mlp = _MLP_CLS(config) |
208 | | - self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) |
209 | | - self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) |
210 | | - |
211 | | - def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): |
212 | | - """Forward pass with residual connections.""" |
213 | | - residual = hidden_states |
214 | | - hidden_states = self.input_layernorm(hidden_states) |
215 | | - hidden_states = self.self_attn( |
216 | | - hidden_states, target_hidden, position_embeddings, attention_mask |
217 | | - ) |
218 | | - hidden_states = residual + hidden_states |
219 | | - |
220 | | - residual = hidden_states |
221 | | - hidden_states = self.post_attention_layernorm(hidden_states) |
222 | | - hidden_states = self.mlp(hidden_states) |
223 | | - hidden_states = residual + hidden_states |
224 | | - return hidden_states |
225 | | - |
226 | | - |
227 | | -class DFlashModule(nn.Module): |
228 | | - """DFlash draft module using Qwen3 components (MLP, RMSNorm, RotaryEmbedding).""" |
229 | | - |
230 | | - def __init__(self, config): |
231 | | - """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" |
232 | | - super().__init__() |
233 | | - self.config = config |
234 | | - self.block_size = config.block_size |
235 | | - |
236 | | - # Feature fusion |
237 | | - num_fused_layers = len(config.target_layer_ids) |
238 | | - self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) |
239 | | - self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) |
240 | | - |
241 | | - # Decoder layers |
242 | | - self.layers = nn.ModuleList( |
243 | | - [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
244 | | - ) |
245 | | - self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) |
246 | | - self._rotary_config = config # Used by _maybe_init_rotary_emb |
247 | | - |
248 | | - # Explicit weight init is needed because DFlashModule is instantiated via |
249 | | - # mtsp.convert() AFTER the base model's post_init() has already run, so HF's |
250 | | - # automatic _init_weights walk doesn't reach these new layers. |
251 | | - self._init_weights(config) |
252 | | - |
253 | | - def _maybe_init_rotary_emb(self, device=None): |
254 | | - """Lazily initialize rotary embeddings on first forward call. |
255 | | -
|
256 | | - Same pattern as EAGLE3's _maybe_init_rope. Avoids creating rotary_emb |
257 | | - during __init__ (which runs on meta device during from_pretrained), |
258 | | - preventing the meta-tensor inv_freq issue on checkpoint resume. |
259 | | - """ |
260 | | - if not hasattr(self, "rotary_emb"): |
261 | | - self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=device) |
262 | | - |
263 | | - def _init_weights(self, config): |
264 | | - """Initialize weights matching HF PreTrainedModel._init_weights.""" |
265 | | - std = getattr(config, "initializer_range", 0.02) |
266 | | - for module in self.modules(): |
267 | | - if isinstance(module, nn.Linear): |
268 | | - nn.init.normal_(module.weight, mean=0.0, std=std) |
269 | | - if module.bias is not None: |
270 | | - nn.init.zeros_(module.bias) |
271 | | - |
272 | | - def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): |
273 | | - """Forward with feature fusion, KV injection, and position embeddings.""" |
274 | | - hidden_states = noise_embedding |
275 | | - target_hidden = self.hidden_norm(self.fc(target_hidden)) |
276 | | - self._maybe_init_rotary_emb(device=hidden_states.device) |
277 | | - position_embeddings = self.rotary_emb(hidden_states, position_ids) |
278 | | - |
279 | | - for layer in self.layers: |
280 | | - hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) |
281 | | - |
282 | | - return self.norm(hidden_states) |
283 | | - |
284 | | - |
285 | 72 | @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) |
286 | 73 | class HFDFlashModel(DFlashModel): |
287 | 74 | """DFlash Model for HuggingFace transformers.""" |