Skip to content

Commit 1b88965

Browse files
authored
Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (vllm-project#34683)
1 parent 6b8cae9 commit 1b88965

3 files changed

Lines changed: 182 additions & 87 deletions

File tree

vllm/model_executor/layers/linear.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -685,13 +685,8 @@ def weight_loader(
685685
self,
686686
param: Parameter,
687687
loaded_weight: torch.Tensor,
688-
loaded_shard_id: tuple[int, ...] | int | None = None,
688+
loaded_shard_id: int | None = None,
689689
):
690-
if isinstance(loaded_shard_id, tuple):
691-
raise NotImplementedError(
692-
"Shard id with multiple indices is not supported in weight_loader, "
693-
"please use weight_loader_v2 instead."
694-
)
695690
# Special case for GGUF
696691
# initialize GGUF param after we know the quantize type
697692
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -830,10 +825,7 @@ def weight_loader(
830825
param_data.copy_(loaded_weight)
831826

832827
def _load_fused_module_from_checkpoint(
833-
self,
834-
param: BasevLLMParameter,
835-
loaded_weight: torch.Tensor,
836-
output_sizes: list[int] | None = None,
828+
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
837829
):
838830
"""
839831
Handle special case for models where MLP layers are already
@@ -847,8 +839,7 @@ def _load_fused_module_from_checkpoint(
847839

848840
current_shard_offset = 0
849841
shard_offsets: list[tuple[int, int, int]] = []
850-
output_sizes = output_sizes or self.output_sizes
851-
for i, output_size in enumerate(output_sizes):
842+
for i, output_size in enumerate(self.output_sizes):
852843
shard_offsets.append((i, current_shard_offset, output_size))
853844
current_shard_offset += output_size
854845

@@ -873,30 +864,17 @@ def weight_loader_v2(
873864
self,
874865
param: BasevLLMParameter,
875866
loaded_weight: torch.Tensor,
876-
loaded_shard_id: tuple[int, ...] | int | None = None,
867+
loaded_shard_id: int | None = None,
877868
):
878-
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
869+
if loaded_shard_id is None:
879870
if isinstance(param, PerTensorScaleParameter):
880871
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
881872
return
882873
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
883874
param.load_merged_column_weight(loaded_weight=loaded_weight)
884875
return
885-
output_sizes = (
886-
[self.output_sizes[idx] for idx in loaded_shard_id]
887-
if loaded_shard_id
888-
else None
889-
)
890-
if isinstance(param, BlockQuantScaleParameter):
891-
weight_block_size = getattr(self, "weight_block_size", None)
892-
output_sizes = [
893-
adjust_block_scale_shard(weight_block_size, size, 0)[0]
894-
for size in (output_sizes or self.output_sizes)
895-
]
896876
# TODO: @dsikka - move to parameter.py
897-
self._load_fused_module_from_checkpoint(
898-
param, loaded_weight, output_sizes=output_sizes
899-
)
877+
self._load_fused_module_from_checkpoint(param, loaded_weight)
900878
return
901879

902880
assert loaded_shard_id < len(self.output_sizes)

vllm/model_executor/models/qwen3_5.py

Lines changed: 166 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,36 @@
3030
import torch
3131
from einops import rearrange
3232
from torch import nn
33+
from transformers.activations import ACT2FN
3334

3435
from vllm.compilation.decorators import support_torch_compile
3536
from vllm.config import (
37+
CacheConfig,
38+
ModelConfig,
39+
SpeculativeConfig,
3640
VllmConfig,
41+
get_current_vllm_config,
3742
)
3843
from vllm.distributed import (
44+
divide,
3945
get_pp_group,
46+
get_tensor_model_parallel_rank,
47+
get_tensor_model_parallel_world_size,
4048
)
4149
from vllm.logger import init_logger
4250
from vllm.model_executor.layers.layernorm import (
4351
GemmaRMSNorm as Qwen3_5RMSNorm,
4452
)
45-
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
53+
from vllm.model_executor.layers.layernorm import RMSNormGated
54+
from vllm.model_executor.layers.linear import (
55+
ColumnParallelLinear,
56+
MergedColumnParallelLinear,
57+
RowParallelLinear,
58+
)
4659
from vllm.model_executor.layers.logits_processor import LogitsProcessor
60+
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
61+
mamba_v2_sharded_weight_loader,
62+
)
4763
from vllm.model_executor.layers.mamba.mamba_utils import (
4864
MambaStateCopyFunc,
4965
MambaStateCopyFuncCalculator,
@@ -57,8 +73,11 @@
5773
)
5874
from vllm.model_executor.model_loader.weight_utils import (
5975
default_weight_loader,
76+
sharded_weight_loader,
6077
)
78+
from vllm.model_executor.utils import set_weight_attrs
6179
from vllm.multimodal import MULTIMODAL_REGISTRY
80+
from vllm.platforms import current_platform
6281
from vllm.sequence import IntermediateTensors
6382
from vllm.transformers_utils.configs.qwen3_5 import (
6483
Qwen3_5Config,
@@ -80,6 +99,7 @@
8099
)
81100
from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
82101
from .qwen3_next import (
102+
ChunkGatedDeltaRule,
83103
Qwen3NextAttention,
84104
Qwen3NextDecoderLayer,
85105
Qwen3NextGatedDeltaNet,
@@ -119,29 +139,152 @@ def get_hf_config(self):
119139

120140

121141
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
122-
def fix_query_key_value_ordering(
142+
def __init__(
123143
self,
124-
mixed_qkvz: torch.Tensor,
125-
mixed_ba: torch.Tensor,
126-
):
127-
raise NotImplementedError(
128-
"Qwen3.5 Series dont need to fix query key value ordering"
144+
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
145+
model_config: ModelConfig | None = None,
146+
cache_config: CacheConfig | None = None,
147+
quant_config: QuantizationConfig | None = None,
148+
speculative_config: SpeculativeConfig | None = None,
149+
prefix: str = "",
150+
) -> None:
151+
super(Qwen3NextGatedDeltaNet, self).__init__()
152+
self.tp_size = get_tensor_model_parallel_world_size()
153+
self.tp_rank = get_tensor_model_parallel_rank()
154+
self.hidden_size = config.hidden_size
155+
self.num_v_heads = config.linear_num_value_heads
156+
self.num_k_heads = config.linear_num_key_heads
157+
self.head_k_dim = config.linear_key_head_dim
158+
self.head_v_dim = config.linear_value_head_dim
159+
self.key_dim = self.head_k_dim * self.num_k_heads
160+
self.value_dim = self.head_v_dim * self.num_v_heads
161+
162+
self.conv_kernel_size = config.linear_conv_kernel_dim
163+
self.layer_idx = extract_layer_index(prefix)
164+
self.activation = config.hidden_act
165+
self.act = ACT2FN[config.hidden_act]
166+
self.layer_norm_epsilon = config.rms_norm_eps
167+
self.prefix = prefix
168+
169+
self.config = config
170+
self.model_config = model_config
171+
self.cache_config = cache_config
172+
self.quant_config = quant_config
173+
self.speculative_config = speculative_config
174+
self.num_spec = (
175+
self.speculative_config.num_speculative_tokens
176+
if self.speculative_config
177+
else 0
129178
)
130179

131-
def create_qkvz_proj(
132-
self,
133-
hidden_size: int,
134-
key_dim: int,
135-
value_dim: int,
136-
quant_config: QuantizationConfig | None,
137-
prefix: str,
138-
) -> MergedColumnParallelLinear:
139-
return MergedColumnParallelLinear(
140-
input_size=hidden_size,
141-
output_sizes=[key_dim, key_dim, value_dim, value_dim],
180+
# QKV
181+
self.conv_dim = self.key_dim * 2 + self.value_dim
182+
self.conv1d = ColumnParallelLinear(
183+
input_size=self.conv_kernel_size,
184+
output_size=self.conv_dim,
185+
bias=False,
186+
prefix=f"{prefix}.conv1d",
187+
)
188+
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
189+
190+
self.in_proj_qkv = MergedColumnParallelLinear(
191+
input_size=self.hidden_size,
192+
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
193+
bias=False,
194+
quant_config=quant_config,
195+
prefix=f"{prefix}.in_proj_qkv",
196+
)
197+
self.in_proj_z = ColumnParallelLinear(
198+
input_size=self.hidden_size,
199+
output_size=self.value_dim,
200+
bias=False,
201+
quant_config=quant_config,
202+
prefix=f"{prefix}.in_proj_z",
203+
)
204+
self.in_proj_b = ColumnParallelLinear(
205+
input_size=self.hidden_size,
206+
output_size=self.num_v_heads,
207+
bias=False,
208+
quant_config=quant_config,
209+
prefix=f"{prefix}.in_proj_b",
210+
)
211+
self.in_proj_a = ColumnParallelLinear(
212+
input_size=self.hidden_size,
213+
output_size=self.num_v_heads,
142214
bias=False,
143215
quant_config=quant_config,
144-
prefix=prefix,
216+
prefix=f"{prefix}.in_proj_a",
217+
)
218+
219+
query_key_settings = (self.key_dim, 0, False)
220+
value_settings = (self.value_dim, 0, False)
221+
222+
delattr(self.conv1d.weight, "weight_loader")
223+
set_weight_attrs(
224+
self.conv1d.weight,
225+
{
226+
"weight_loader": mamba_v2_sharded_weight_loader(
227+
[
228+
query_key_settings,
229+
query_key_settings,
230+
value_settings,
231+
],
232+
self.tp_size,
233+
self.tp_rank,
234+
)
235+
},
236+
)
237+
238+
# selective projection used to make dt, B and C input dependant
239+
240+
# time step projection (discretization)
241+
# instantiate once and copy inv_dt in init_weights of PretrainedModel
242+
self.dt_bias = nn.Parameter(
243+
torch.ones(self.num_v_heads // self.tp_size),
244+
)
245+
self.A_log = nn.Parameter(
246+
torch.empty(
247+
divide(self.num_v_heads, self.tp_size),
248+
)
249+
)
250+
251+
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
252+
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
253+
254+
self.norm = RMSNormGated(
255+
self.head_v_dim,
256+
eps=self.layer_norm_epsilon,
257+
group_size=None,
258+
norm_before_gate=True,
259+
device=current_platform.current_device(),
260+
dtype=config.dtype,
261+
)
262+
263+
self.out_proj = RowParallelLinear(
264+
self.value_dim,
265+
self.hidden_size,
266+
bias=False,
267+
input_is_parallel=True,
268+
quant_config=quant_config,
269+
prefix=f"{prefix}.out_proj",
270+
)
271+
272+
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
273+
274+
compilation_config = get_current_vllm_config().compilation_config
275+
if prefix in compilation_config.static_forward_context:
276+
raise ValueError(f"Duplicate layer name: {prefix}")
277+
compilation_config.static_forward_context[prefix] = self
278+
279+
def fix_query_key_value_ordering(
280+
self,
281+
mixed_qkv,
282+
z,
283+
b,
284+
a,
285+
):
286+
raise NotImplementedError(
287+
"Qwen3.5 Series dont need to fix query key value ordering"
145288
)
146289

147290
def forward(
@@ -160,13 +303,11 @@ def forward(
160303
# ============================================================
161304
# Part 1: Input Projection
162305
# ============================================================
163-
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
164-
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
165-
z_size = self.value_dim // self.tp_size
166-
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
306+
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
307+
z, _ = self.in_proj_z(hidden_states)
167308
z = z.reshape(z.size(0), -1, self.head_v_dim)
168-
ba, _ = self.in_proj_ba(hidden_states)
169-
b, a = ba.chunk(2, dim=-1)
309+
b, _ = self.in_proj_b(hidden_states)
310+
a, _ = self.in_proj_a(hidden_states)
170311

171312
b = b.contiguous()
172313
a = a.contiguous()
@@ -365,18 +506,11 @@ def load_fused_expert_weights(
365506
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
366507
stacked_params_mapping = [
367508
# (param_name, shard_name, shard_id)
368-
# self attention
369509
("qkv_proj", "q_proj", "q"),
370510
("qkv_proj", "k_proj", "k"),
371511
("qkv_proj", "v_proj", "v"),
372-
# mlp
373512
("gate_up_proj", "gate_proj", 0),
374513
("gate_up_proj", "up_proj", 1),
375-
# GDN
376-
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
377-
("in_proj_qkvz", "in_proj_z", 3),
378-
("in_proj_ba", "in_proj_b", 0),
379-
("in_proj_ba", "in_proj_a", 1),
380514
]
381515

382516
params_dict = dict(self.named_parameters())

0 commit comments

Comments
 (0)