Skip to content

Commit 8a1f82b

Browse files
committed
Fix minimax tp8
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
1 parent 968ed02 commit 8a1f82b

3 files changed

Lines changed: 250 additions & 23 deletions

File tree

tests/kernels/moe/test_moe_weight_loading_padded.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
correctly handles this mismatch.
1010
"""
1111

12+
import math
13+
from unittest.mock import MagicMock
14+
1215
import pytest
1316
import torch
1417

@@ -290,3 +293,182 @@ def test_bnb_shape_mismatch_raises(self):
290293
shard_id="w2",
291294
expert_id=0,
292295
)
296+
297+
298+
def _make_fused_moe_mock(*, is_act_and_mul: bool = True):
299+
"""Build a FusedMoE mock for weight loading tests."""
300+
moe_module = MagicMock(spec=FusedMoE)
301+
moe_module.moe_config = MagicMock()
302+
moe_module.moe_config.is_act_and_mul = is_act_and_mul
303+
304+
moe_module._get_hidden_dim = FusedMoE._get_hidden_dim
305+
moe_module._narrow_expert_data_for_padding = (
306+
FusedMoE._narrow_expert_data_for_padding
307+
)
308+
return moe_module
309+
310+
311+
class TestBlockQuantPaddedHiddenAndIntermediateSize:
312+
"""Tests weight loading with padded hidden_size and intermediate_size
313+
across TP ranks.
314+
315+
hidden_size: 192 -> 256 (DeepEP-style round-up)
316+
intermediate_size_per_partition: 448 -> 512 (block_n=128 alignment)
317+
"""
318+
319+
BLOCK_N = 128
320+
HIDDEN_UNPADDED = 192
321+
HIDDEN_PADDED = math.ceil(HIDDEN_UNPADDED / BLOCK_N) * BLOCK_N
322+
INTERMEDIATE_UNPADDED = 448
323+
INTERMEDIATE_PADDED = math.ceil(INTERMEDIATE_UNPADDED / BLOCK_N) * BLOCK_N
324+
TP_SIZE = 4
325+
GLOBAL_INTER = INTERMEDIATE_UNPADDED * TP_SIZE
326+
327+
def _make_fused_moe(self):
328+
return _make_fused_moe_mock()
329+
330+
def test_load_w1_weight_all_tp_ranks(self):
331+
"""Each TP rank loads block-aligned rows into the w1 half.
332+
The last rank gets fewer rows; the rest is padding."""
333+
moe_module = self._make_fused_moe()
334+
checkpoint = torch.randn(self.GLOBAL_INTER, self.HIDDEN_UNPADDED)
335+
336+
for tp_rank in range(self.TP_SIZE):
337+
expert_data = torch.zeros(2 * self.INTERMEDIATE_PADDED, self.HIDDEN_PADDED)
338+
FusedMoE._load_w13(
339+
moe_module,
340+
expert_data=expert_data,
341+
shard_dim=0,
342+
shard_id="w1",
343+
loaded_weight=checkpoint.clone(),
344+
tp_rank=tp_rank,
345+
)
346+
w1 = expert_data[: self.INTERMEDIATE_PADDED]
347+
start = tp_rank * self.INTERMEDIATE_PADDED
348+
n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start)
349+
expected = checkpoint[start : start + n_available]
350+
351+
assert torch.equal(w1[:n_available, : self.HIDDEN_UNPADDED], expected)
352+
assert torch.all(w1[n_available:] == 0)
353+
assert torch.all(w1[:n_available, self.HIDDEN_UNPADDED :] == 0)
354+
assert torch.all(expert_data[self.INTERMEDIATE_PADDED :] == 0)
355+
356+
def test_load_w3_weight_into_second_half(self):
357+
"""w3 weight is written into the second half of the w13 allocation."""
358+
moe_module = self._make_fused_moe()
359+
checkpoint = torch.randn(self.GLOBAL_INTER, self.HIDDEN_UNPADDED)
360+
tp_rank = 2
361+
362+
expert_data = torch.zeros(2 * self.INTERMEDIATE_PADDED, self.HIDDEN_PADDED)
363+
FusedMoE._load_w13(
364+
moe_module,
365+
expert_data=expert_data,
366+
shard_dim=0,
367+
shard_id="w3",
368+
loaded_weight=checkpoint.clone(),
369+
tp_rank=tp_rank,
370+
)
371+
assert torch.all(expert_data[: self.INTERMEDIATE_PADDED] == 0)
372+
373+
w3 = expert_data[self.INTERMEDIATE_PADDED :]
374+
start = tp_rank * self.INTERMEDIATE_PADDED
375+
n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start)
376+
assert torch.equal(
377+
w3[:n_available, : self.HIDDEN_UNPADDED],
378+
checkpoint[start : start + n_available],
379+
)
380+
assert torch.all(w3[n_available:] == 0)
381+
382+
def test_load_w2_weight_all_tp_ranks(self):
383+
"""Each TP rank loads block-aligned columns of w2."""
384+
moe_module = self._make_fused_moe()
385+
checkpoint = torch.randn(self.HIDDEN_UNPADDED, self.GLOBAL_INTER)
386+
387+
for tp_rank in range(self.TP_SIZE):
388+
expert_data = torch.zeros(self.HIDDEN_PADDED, self.INTERMEDIATE_PADDED)
389+
FusedMoE._load_w2(
390+
moe_module,
391+
expert_data=expert_data,
392+
shard_dim=1,
393+
loaded_weight=checkpoint.clone(),
394+
tp_rank=tp_rank,
395+
)
396+
start = tp_rank * self.INTERMEDIATE_PADDED
397+
n_available = min(self.INTERMEDIATE_PADDED, self.GLOBAL_INTER - start)
398+
expected = checkpoint[:, start : start + n_available]
399+
assert torch.equal(
400+
expert_data[: self.HIDDEN_UNPADDED, :n_available], expected
401+
)
402+
assert torch.all(expert_data[:, n_available:] == 0)
403+
assert torch.all(expert_data[self.HIDDEN_UNPADDED :] == 0)
404+
405+
def test_load_w1_scale_all_tp_ranks(self):
406+
"""Each TP rank loads block-aligned scale rows for w1."""
407+
moe_module = self._make_fused_moe()
408+
n_rows_global = math.ceil(self.GLOBAL_INTER / self.BLOCK_N)
409+
n_cols_ckpt = math.ceil(self.HIDDEN_UNPADDED / self.BLOCK_N)
410+
n_rows_local = math.ceil(self.INTERMEDIATE_PADDED / self.BLOCK_N)
411+
n_cols_alloc = math.ceil(self.HIDDEN_PADDED / self.BLOCK_N)
412+
413+
checkpoint_scale = torch.randn(n_rows_global, n_cols_ckpt)
414+
415+
for tp_rank in range(self.TP_SIZE):
416+
expert_data = torch.zeros(2 * n_rows_local, n_cols_alloc)
417+
FusedMoE._load_w13(
418+
moe_module,
419+
expert_data=expert_data,
420+
shard_dim=0,
421+
shard_id="w1",
422+
loaded_weight=checkpoint_scale.clone(),
423+
tp_rank=tp_rank,
424+
)
425+
w1_scale = expert_data[:n_rows_local]
426+
start = n_rows_local * tp_rank
427+
loaded = min(n_rows_local, n_rows_global - start)
428+
expected = checkpoint_scale[start : start + loaded]
429+
assert torch.equal(w1_scale[:loaded, :n_cols_ckpt], expected)
430+
431+
def test_load_w2_scale_all_tp_ranks(self):
432+
"""Each TP rank loads block-aligned scale columns for w2."""
433+
moe_module = self._make_fused_moe()
434+
n_rows = math.ceil(self.HIDDEN_UNPADDED / self.BLOCK_N)
435+
n_cols_global = math.ceil(self.GLOBAL_INTER / self.BLOCK_N)
436+
n_cols_local = math.ceil(self.INTERMEDIATE_PADDED / self.BLOCK_N)
437+
438+
checkpoint_scale = torch.randn(n_rows, n_cols_global)
439+
440+
for tp_rank in range(self.TP_SIZE):
441+
expert_data = torch.zeros(n_rows, n_cols_local)
442+
FusedMoE._load_w2(
443+
moe_module,
444+
expert_data=expert_data,
445+
shard_dim=1,
446+
loaded_weight=checkpoint_scale.clone(),
447+
tp_rank=tp_rank,
448+
)
449+
start = n_cols_local * tp_rank
450+
loaded = min(n_cols_local, n_cols_global - start)
451+
expected = checkpoint_scale[:, start : start + loaded]
452+
assert torch.equal(expert_data[:, :loaded], expected)
453+
454+
def test_no_padding_matches_simple_shard(self):
455+
"""When sizes are already block-aligned, loading is a simple
456+
shard_size * tp_rank partition."""
457+
intermediate = 512
458+
hidden = 256
459+
moe_module = _make_fused_moe_mock()
460+
checkpoint = torch.randn(intermediate * self.TP_SIZE, hidden)
461+
462+
for tp_rank in range(self.TP_SIZE):
463+
expert_data = torch.zeros(2 * intermediate, hidden)
464+
FusedMoE._load_w13(
465+
moe_module,
466+
expert_data=expert_data,
467+
shard_dim=0,
468+
shard_id="w1",
469+
loaded_weight=checkpoint.clone(),
470+
tp_rank=tp_rank,
471+
)
472+
w1 = expert_data[:intermediate]
473+
start = tp_rank * intermediate
474+
assert torch.equal(w1, checkpoint[start : start + intermediate])

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,10 @@ def _load_per_channel_weight_scale(
842842
if shard_id == "w2":
843843
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
844844
expert_data = self._narrow_expert_data_for_padding(
845-
expert_data, loaded_weight, hidden_dim=hidden_dim
845+
expert_data,
846+
loaded_weight,
847+
hidden_dim=hidden_dim,
848+
intermediate_dim=shard_dim,
846849
)
847850
expert_data.copy_(loaded_weight)
848851
elif shard_id in ("w1", "w3"):
@@ -882,29 +885,32 @@ def _narrow_expert_data_for_padding(
882885
expert_data: torch.Tensor,
883886
loaded_weight: torch.Tensor,
884887
hidden_dim: int,
888+
intermediate_dim: int = -1,
885889
) -> torch.Tensor:
886-
"""Narrow expert_data hidden dim to match loaded_weight for padded
887-
hidden_size.
890+
"""Narrow expert_data to match loaded_weight for padded dimensions.
888891
889892
When backends (e.g., DeepEP) round up hidden_size, weight parameters
890-
are larger than checkpoint weights. Narrow the padded hidden dimension
891-
before copying.
893+
are larger than checkpoint weights. Similarly, on the last TP rank the
894+
intermediate dimension of loaded_weight may be smaller than the padded
895+
allocation. Narrow both dimensions before copying.
892896
893897
Args:
894898
expert_data: The (possibly padded) parameter tensor to narrow.
895899
loaded_weight: The checkpoint weight tensor with original size.
896900
hidden_dim: The dimension index corresponding to hidden_size.
897901
Must be non-negative.
902+
intermediate_dim: The dimension index corresponding to the
903+
intermediate size. When >= 0, expert_data is also narrowed
904+
along this axis if it is larger than loaded_weight.
898905
"""
899-
if (
900-
loaded_weight.ndim > 0
901-
and 0 <= hidden_dim < expert_data.ndim
902-
and hidden_dim < loaded_weight.ndim
903-
and expert_data.shape[hidden_dim] > loaded_weight.shape[hidden_dim]
904-
):
905-
expert_data = expert_data.narrow(
906-
hidden_dim, 0, loaded_weight.shape[hidden_dim]
907-
)
906+
for dim in (hidden_dim, intermediate_dim):
907+
if (
908+
loaded_weight.ndim > 0
909+
and 0 <= dim < expert_data.ndim
910+
and dim < loaded_weight.ndim
911+
and expert_data.shape[dim] > loaded_weight.shape[dim]
912+
):
913+
expert_data = expert_data.narrow(dim, 0, loaded_weight.shape[dim])
908914
return expert_data
909915

910916
def _load_w13(
@@ -922,6 +928,7 @@ def _load_w13(
922928
shard_size = expert_data.shape[shard_dim] // 2
923929
else:
924930
shard_size = expert_data.shape[shard_dim]
931+
925932
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
926933
# and we're not loading the full weight
927934
if not load_full and loaded_weight.ndim > 0:
@@ -946,7 +953,10 @@ def _load_w13(
946953
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
947954
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
948955
expert_data = self._narrow_expert_data_for_padding(
949-
expert_data, loaded_weight, hidden_dim=hidden_dim
956+
expert_data,
957+
loaded_weight,
958+
hidden_dim=hidden_dim,
959+
intermediate_dim=shard_dim,
950960
)
951961
expert_data.copy_(loaded_weight)
952962

@@ -962,6 +972,7 @@ def _load_w2(
962972
# down_proj: "RowParallel" so tp sharding on input_dim
963973
# Narrow parameter and load.
964974
shard_size = expert_data.shape[shard_dim]
975+
965976
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
966977
# and we're not loading the full weight
967978
if not load_full and loaded_weight.ndim > 0:
@@ -979,7 +990,10 @@ def _load_w2(
979990
# w2, down_proj: Load into only logical weight of w2.
980991
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
981992
expert_data = self._narrow_expert_data_for_padding(
982-
expert_data, loaded_weight, hidden_dim=hidden_dim
993+
expert_data,
994+
loaded_weight,
995+
hidden_dim=hidden_dim,
996+
intermediate_dim=shard_dim,
983997
)
984998
expert_data.copy_(loaded_weight)
985999

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
FusedMoeWeightScaleSupported,
2525
)
2626
from vllm.model_executor.layers.fused_moe.config import (
27+
FusedMoEParallelConfig,
2728
FusedMoEQuantConfig,
2829
)
2930
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
@@ -608,6 +609,36 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
608609
allow_vllm_cutlass=False,
609610
)
610611

612+
def maybe_roundup_sizes(
613+
self,
614+
hidden_size: int,
615+
intermediate_size_per_partition: int,
616+
act_dtype: torch.dtype,
617+
moe_parallel_config: "FusedMoEParallelConfig",
618+
) -> tuple[int, int]:
619+
hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes(
620+
hidden_size=hidden_size,
621+
intermediate_size_per_partition=intermediate_size_per_partition,
622+
act_dtype=act_dtype,
623+
moe_parallel_config=moe_parallel_config,
624+
)
625+
if self.block_quant:
626+
assert self.weight_block_size is not None
627+
block_n = self.weight_block_size[0]
628+
if intermediate_size_per_partition % block_n != 0:
629+
padded = (
630+
(intermediate_size_per_partition + block_n - 1) // block_n * block_n
631+
)
632+
logger.info_once(
633+
"Padding MoE intermediate size per partition from %d to "
634+
"%d for FP8 block quantization alignment (block_n=%d).",
635+
intermediate_size_per_partition,
636+
padded,
637+
block_n,
638+
)
639+
intermediate_size_per_partition = padded
640+
return hidden_size, intermediate_size_per_partition
641+
611642
def create_weights(
612643
self,
613644
layer: Module,
@@ -635,13 +666,13 @@ def create_weights(
635666
# NOTE: To ensure proper alignment of the block-wise quantization
636667
# scales, the output_size of the weights for both the gate and up
637668
# layers must be divisible by block_n.
638-
# Required by column parallel or enabling merged weights
639-
if intermediate_size_per_partition % block_n != 0:
640-
raise ValueError(
641-
f"The output_size of gate's and up's weight = "
642-
f"{intermediate_size_per_partition} is not divisible by "
643-
f"weight quantization block_n = {block_n}."
644-
)
669+
# Required by column parallel or enabling merged weights.
670+
# This is guaranteed by maybe_roundup_sizes() which pads
671+
# intermediate_size_per_partition to the next block_n multiple.
672+
assert intermediate_size_per_partition % block_n == 0, (
673+
f"intermediate_size_per_partition={intermediate_size_per_partition} "
674+
f"should have been padded to a multiple of block_n={block_n}"
675+
)
645676
if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
646677
# Required by row parallel
647678
raise ValueError(

0 commit comments

Comments
 (0)