3030import torch
3131from einops import rearrange
3232from torch import nn
33+ from transformers .activations import ACT2FN
3334
3435from vllm .compilation .decorators import support_torch_compile
3536from vllm .config import (
37+ CacheConfig ,
38+ ModelConfig ,
39+ SpeculativeConfig ,
3640 VllmConfig ,
41+ get_current_vllm_config ,
3742)
3843from vllm .distributed import (
44+ divide ,
3945 get_pp_group ,
46+ get_tensor_model_parallel_rank ,
47+ get_tensor_model_parallel_world_size ,
4048)
4149from vllm .logger import init_logger
4250from vllm .model_executor .layers .layernorm import (
4351 GemmaRMSNorm as Qwen3_5RMSNorm ,
4452)
45- from vllm .model_executor .layers .linear import MergedColumnParallelLinear
53+ from vllm .model_executor .layers .layernorm import RMSNormGated
54+ from vllm .model_executor .layers .linear import (
55+ ColumnParallelLinear ,
56+ MergedColumnParallelLinear ,
57+ RowParallelLinear ,
58+ )
4659from vllm .model_executor .layers .logits_processor import LogitsProcessor
60+ from vllm .model_executor .layers .mamba .mamba_mixer2 import (
61+ mamba_v2_sharded_weight_loader ,
62+ )
4763from vllm .model_executor .layers .mamba .mamba_utils import (
4864 MambaStateCopyFunc ,
4965 MambaStateCopyFuncCalculator ,
5773)
5874from vllm .model_executor .model_loader .weight_utils import (
5975 default_weight_loader ,
76+ sharded_weight_loader ,
6077)
78+ from vllm .model_executor .utils import set_weight_attrs
6179from vllm .multimodal import MULTIMODAL_REGISTRY
80+ from vllm .platforms import current_platform
6281from vllm .sequence import IntermediateTensors
6382from vllm .transformers_utils .configs .qwen3_5 import (
6483 Qwen3_5Config ,
8099)
81100from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
82101from .qwen3_next import (
102+ ChunkGatedDeltaRule ,
83103 Qwen3NextAttention ,
84104 Qwen3NextDecoderLayer ,
85105 Qwen3NextGatedDeltaNet ,
@@ -119,29 +139,152 @@ def get_hf_config(self):
119139
120140
121141class Qwen3_5GatedDeltaNet (Qwen3NextGatedDeltaNet ):
122- def fix_query_key_value_ordering (
142+ def __init__ (
123143 self ,
124- mixed_qkvz : torch .Tensor ,
125- mixed_ba : torch .Tensor ,
126- ):
127- raise NotImplementedError (
128- "Qwen3.5 Series dont need to fix query key value ordering"
144+ config : Qwen3_5TextConfig | Qwen3_5MoeTextConfig ,
145+ model_config : ModelConfig | None = None ,
146+ cache_config : CacheConfig | None = None ,
147+ quant_config : QuantizationConfig | None = None ,
148+ speculative_config : SpeculativeConfig | None = None ,
149+ prefix : str = "" ,
150+ ) -> None :
151+ super (Qwen3NextGatedDeltaNet , self ).__init__ ()
152+ self .tp_size = get_tensor_model_parallel_world_size ()
153+ self .tp_rank = get_tensor_model_parallel_rank ()
154+ self .hidden_size = config .hidden_size
155+ self .num_v_heads = config .linear_num_value_heads
156+ self .num_k_heads = config .linear_num_key_heads
157+ self .head_k_dim = config .linear_key_head_dim
158+ self .head_v_dim = config .linear_value_head_dim
159+ self .key_dim = self .head_k_dim * self .num_k_heads
160+ self .value_dim = self .head_v_dim * self .num_v_heads
161+
162+ self .conv_kernel_size = config .linear_conv_kernel_dim
163+ self .layer_idx = extract_layer_index (prefix )
164+ self .activation = config .hidden_act
165+ self .act = ACT2FN [config .hidden_act ]
166+ self .layer_norm_epsilon = config .rms_norm_eps
167+ self .prefix = prefix
168+
169+ self .config = config
170+ self .model_config = model_config
171+ self .cache_config = cache_config
172+ self .quant_config = quant_config
173+ self .speculative_config = speculative_config
174+ self .num_spec = (
175+ self .speculative_config .num_speculative_tokens
176+ if self .speculative_config
177+ else 0
129178 )
130179
131- def create_qkvz_proj (
132- self ,
133- hidden_size : int ,
134- key_dim : int ,
135- value_dim : int ,
136- quant_config : QuantizationConfig | None ,
137- prefix : str ,
138- ) -> MergedColumnParallelLinear :
139- return MergedColumnParallelLinear (
140- input_size = hidden_size ,
141- output_sizes = [key_dim , key_dim , value_dim , value_dim ],
180+ # QKV
181+ self .conv_dim = self .key_dim * 2 + self .value_dim
182+ self .conv1d = ColumnParallelLinear (
183+ input_size = self .conv_kernel_size ,
184+ output_size = self .conv_dim ,
185+ bias = False ,
186+ prefix = f"{ prefix } .conv1d" ,
187+ )
188+ self .conv1d .weight .data = self .conv1d .weight .data .unsqueeze (1 )
189+
190+ self .in_proj_qkv = MergedColumnParallelLinear (
191+ input_size = self .hidden_size ,
192+ output_sizes = [self .key_dim , self .key_dim , self .value_dim ],
193+ bias = False ,
194+ quant_config = quant_config ,
195+ prefix = f"{ prefix } .in_proj_qkv" ,
196+ )
197+ self .in_proj_z = ColumnParallelLinear (
198+ input_size = self .hidden_size ,
199+ output_size = self .value_dim ,
200+ bias = False ,
201+ quant_config = quant_config ,
202+ prefix = f"{ prefix } .in_proj_z" ,
203+ )
204+ self .in_proj_b = ColumnParallelLinear (
205+ input_size = self .hidden_size ,
206+ output_size = self .num_v_heads ,
207+ bias = False ,
208+ quant_config = quant_config ,
209+ prefix = f"{ prefix } .in_proj_b" ,
210+ )
211+ self .in_proj_a = ColumnParallelLinear (
212+ input_size = self .hidden_size ,
213+ output_size = self .num_v_heads ,
142214 bias = False ,
143215 quant_config = quant_config ,
144- prefix = prefix ,
216+ prefix = f"{ prefix } .in_proj_a" ,
217+ )
218+
219+ query_key_settings = (self .key_dim , 0 , False )
220+ value_settings = (self .value_dim , 0 , False )
221+
222+ delattr (self .conv1d .weight , "weight_loader" )
223+ set_weight_attrs (
224+ self .conv1d .weight ,
225+ {
226+ "weight_loader" : mamba_v2_sharded_weight_loader (
227+ [
228+ query_key_settings ,
229+ query_key_settings ,
230+ value_settings ,
231+ ],
232+ self .tp_size ,
233+ self .tp_rank ,
234+ )
235+ },
236+ )
237+
238+ # selective projection used to make dt, B and C input dependent
239+
240+ # time step projection (discretization)
241+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
242+ self .dt_bias = nn .Parameter (
243+ torch .ones (self .num_v_heads // self .tp_size ),
244+ )
245+ self .A_log = nn .Parameter (
246+ torch .empty (
247+ divide (self .num_v_heads , self .tp_size ),
248+ )
249+ )
250+
251+ set_weight_attrs (self .A_log , {"weight_loader" : sharded_weight_loader (0 )})
252+ set_weight_attrs (self .dt_bias , {"weight_loader" : sharded_weight_loader (0 )})
253+
254+ self .norm = RMSNormGated (
255+ self .head_v_dim ,
256+ eps = self .layer_norm_epsilon ,
257+ group_size = None ,
258+ norm_before_gate = True ,
259+ device = current_platform .current_device (),
260+ dtype = config .dtype ,
261+ )
262+
263+ self .out_proj = RowParallelLinear (
264+ self .value_dim ,
265+ self .hidden_size ,
266+ bias = False ,
267+ input_is_parallel = True ,
268+ quant_config = quant_config ,
269+ prefix = f"{ prefix } .out_proj" ,
270+ )
271+
272+ self .chunk_gated_delta_rule = ChunkGatedDeltaRule ()
273+
274+ compilation_config = get_current_vllm_config ().compilation_config
275+ if prefix in compilation_config .static_forward_context :
276+ raise ValueError (f"Duplicate layer name: { prefix } " )
277+ compilation_config .static_forward_context [prefix ] = self
278+
279+ def fix_query_key_value_ordering (
280+ self ,
281+ mixed_qkv ,
282+ z ,
283+ b ,
284+ a ,
285+ ):
286+ raise NotImplementedError (
287+ "Qwen3.5 Series dont need to fix query key value ordering"
145288 )
146289
147290 def forward (
@@ -160,13 +303,11 @@ def forward(
160303 # ============================================================
161304 # Part 1: Input Projection
162305 # ============================================================
163- mixed_qkvz , _ = self .in_proj_qkvz (hidden_states )
164- qkv_size = (self .key_dim * 2 + self .value_dim ) // self .tp_size
165- z_size = self .value_dim // self .tp_size
166- mixed_qkv , z = mixed_qkvz .split ([qkv_size , z_size ], dim = - 1 )
306+ mixed_qkv , _ = self .in_proj_qkv (hidden_states )
307+ z , _ = self .in_proj_z (hidden_states )
167308 z = z .reshape (z .size (0 ), - 1 , self .head_v_dim )
168- ba , _ = self .in_proj_ba (hidden_states )
169- b , a = ba . chunk ( 2 , dim = - 1 )
309+ b , _ = self .in_proj_b (hidden_states )
310+ a , _ = self . in_proj_a ( hidden_states )
170311
171312 b = b .contiguous ()
172313 a = a .contiguous ()
@@ -365,18 +506,11 @@ def load_fused_expert_weights(
365506 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]) -> set [str ]:
366507 stacked_params_mapping = [
367508 # (param_name, shard_name, shard_id)
368- # self attention
369509 ("qkv_proj" , "q_proj" , "q" ),
370510 ("qkv_proj" , "k_proj" , "k" ),
371511 ("qkv_proj" , "v_proj" , "v" ),
372- # mlp
373512 ("gate_up_proj" , "gate_proj" , 0 ),
374513 ("gate_up_proj" , "up_proj" , 1 ),
375- # GDN
376- ("in_proj_qkvz" , "in_proj_qkv" , (0 , 1 , 2 )),
377- ("in_proj_qkvz" , "in_proj_z" , 3 ),
378- ("in_proj_ba" , "in_proj_b" , 0 ),
379- ("in_proj_ba" , "in_proj_a" , 1 ),
380514 ]
381515
382516 params_dict = dict (self .named_parameters ())
0 commit comments