
Commit aaec5a1

danisereb authored and meenchen committed
Add support and tests for 3D MoE in MXFP8QTensor
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent 8c28bb5 commit aaec5a1

2 files changed: 69 additions & 10 deletions

modelopt/torch/quantization/qtensor/mxfp8_tensor.py

Lines changed: 22 additions & 10 deletions
@@ -71,9 +71,12 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor:
 
         Args:
             weight: The weight tensor to compute scale for. Must be at least 2D.
+                Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim).
 
         Returns:
             torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
+                For 2D input: (out_dim, in_dim // 32)
+                For 3D MoE input: (num_experts, out_dim, in_dim // 32)
         """
         assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
 
@@ -83,7 +86,7 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor:
             f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})"
         )
 
-        # Compute amax per block (reduce_block_amax handles reshaping internally)
+        # Compute amax per block (reduce_block_amax handles N-dimensional tensors)
         amax = reduce_block_amax(weight, block_sizes={-1: cls.BLOCK_SIZE})
 
         # Compute E8M0 exponent and convert to biased uint8 (bias = 127)
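The amax reduction is what makes the 3D case fall out for free: only the innermost dimension is grouped into blocks of BLOCK_SIZE, so any leading dimensions (experts, output channels) pass through untouched. Below is a minimal self-contained sketch of the idea; the reshape-and-amax stands in for reduce_block_amax, and the E8M0 rounding shown is illustrative only, not necessarily ModelOpt's exact formula.

```python
import torch

BLOCK_SIZE = 32  # MXFP8 block size, per cls.BLOCK_SIZE in the diff

def block_scale_sketch(weight: torch.Tensor) -> torch.Tensor:
    """Per-block E8M0-style scale for any >=2D weight (illustrative only)."""
    *lead, in_dim = weight.shape
    assert in_dim % BLOCK_SIZE == 0, "inner dim must be divisible by the block size"
    # Stand-in for reduce_block_amax(weight, block_sizes={-1: BLOCK_SIZE}):
    # group the last dim into blocks of 32 and take the abs-max per block.
    amax = weight.abs().reshape(*lead, in_dim // BLOCK_SIZE, BLOCK_SIZE).amax(dim=-1)
    # Illustrative E8M0 conversion: power-of-two exponent, biased by 127 into
    # uint8. The library's exact rounding may differ.
    exponent = torch.floor(torch.log2(amax.clamp(min=torch.finfo(torch.float32).tiny)))
    return (exponent + 127).clamp(0, 255).to(torch.uint8)

print(block_scale_sketch(torch.randn(64, 128)).shape)     # torch.Size([64, 4])
print(block_scale_sketch(torch.randn(4, 64, 128)).shape)  # torch.Size([4, 64, 4])
```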
@@ -102,11 +105,12 @@ def get_weights_scaling_factor_from_quantizer(
         with proper format conversion and shape correction.
 
         Args:
-            weight: The weight tensor.
+            weight: The weight tensor. Can be 2D (out_dim, in_dim) or
+                3D for MoE (num_experts, out_dim, in_dim).
             weight_quantizer: The weight quantizer with block_sizes and optional _scale.
 
         Returns:
-            torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // 32].
+            torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
         """
         assert hasattr(weight_quantizer, "block_sizes"), (
             "weight_quantizer must have 'block_sizes' attribute"
@@ -116,8 +120,11 @@ def get_weights_scaling_factor_from_quantizer(
         )
         assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
 
-        out_dim, in_dim = weight.shape[-2], weight.shape[-1]
-        expected_shape = (out_dim, in_dim // cls.BLOCK_SIZE)
+        in_dim = weight.shape[-1]
+        # Expected scale shape: all dims except last, with last dim reduced by block size
+        # For 2D: (out_dim, in_dim // 32)
+        # For 3D MoE: (num_experts, out_dim, in_dim // 32)
+        expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE)
 
         if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
             scale = weight_quantizer._scale
@@ -127,11 +134,16 @@ def get_weights_scaling_factor_from_quantizer(
             )
 
             # Reshape if needed (same number of elements but wrong shape)
-            if (
-                scale.shape != expected_shape
-                and scale.numel() == expected_shape[0] * expected_shape[1]
-            ):
-                scale = scale.reshape(expected_shape)
+            if scale.shape != expected_shape:
+                expected_numel = 1
+                for dim in expected_shape:
+                    expected_numel *= dim
+                if scale.numel() == expected_numel:
+                    scale = scale.reshape(expected_shape)
+
+            assert scale.shape == expected_shape, (
+                f"Scale shape {scale.shape} does not match expected shape {expected_shape}"
+            )
             return scale
 
         # No scale in quantizer, compute from weight
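Taken together, the new branch is shape-agnostic: every leading dimension of the weight is carried into expected_shape, and a stored scale is only reshaped when its element count already matches. A condensed sketch of that branch, with math.prod standing in for the explicit accumulation loop in the diff (check_or_reshape_scale is a hypothetical name for illustration, not the library API):

```python
import math
import torch

BLOCK_SIZE = 32

def check_or_reshape_scale(scale: torch.Tensor, weight_shape: tuple) -> torch.Tensor:
    # Expected scale shape: every dim but the last, last dim reduced by BLOCK_SIZE.
    expected_shape = (*weight_shape[:-1], weight_shape[-1] // BLOCK_SIZE)
    # The reshape fallback is only legal when the element counts already agree.
    if scale.shape != expected_shape and scale.numel() == math.prod(expected_shape):
        scale = scale.reshape(expected_shape)
    assert scale.shape == expected_shape, (
        f"Scale shape {tuple(scale.shape)} does not match expected shape {expected_shape}"
    )
    return scale

# A flattened scale with the right element count is silently reshaped back.
flat = torch.zeros(4 * 64 * 4, dtype=torch.uint8)
print(check_or_reshape_scale(flat, (4, 64, 128)).shape)  # torch.Size([4, 64, 4])
```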

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 47 additions & 0 deletions
@@ -850,3 +850,50 @@ def test_mxfp8_quantize_with_scale_asserts(self, device):
         scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device)
         with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
             MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale)
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    def test_mxfp8_get_weights_scaling_factor_from_quantizer_3d_moe(self, device):
+        """Test get_weights_scaling_factor_from_quantizer handles 3D MoE tensors."""
+        input_shape = (4, 64, 128)  # (num_experts, out_dim, in_dim)
+        weight = torch.randn(input_shape, dtype=torch.float32, device=device)
+
+        class MockQuantizer:
+            block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE}
+            _scale = None
+
+        quantizer = MockQuantizer()
+
+        # Test when _scale is None (should compute from weight)
+        scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer)
+
+        expected_shape = (
+            input_shape[0],
+            input_shape[1],
+            input_shape[2] // MXFP8QTensor.BLOCK_SIZE,
+        )
+        assert scale.shape == expected_shape
+
+        # Test when _scale is provided with correct 3D shape
+        quantizer._scale = torch.randint(0, 255, expected_shape, dtype=torch.uint8, device=device)
+        scale_from_quantizer = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
+            weight, quantizer
+        )
+        assert torch.equal(scale_from_quantizer, quantizer._scale)
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    def test_mxfp8_get_weights_scaling_factor_from_quantizer_scale_shape_mismatch(self, device):
+        """Test get_weights_scaling_factor_from_quantizer raises assertion on shape mismatch."""
+        input_shape = (4, 64, 128)  # (num_experts, out_dim, in_dim)
+        weight = torch.randn(input_shape, dtype=torch.float32, device=device)
+
+        class MockQuantizer:
+            block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE}
+            # Wrong shape: 2D instead of 3D (missing num_experts dimension)
+            _scale = torch.randint(
+                0, 255, (64, 4), dtype=torch.uint8, device=device
+            )
+
+        quantizer = MockQuantizer()
+
+        with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"):
+            MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer)
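Why the second test must hit the assertion rather than the reshape fallback: the injected (64, 4) scale has the wrong element count, not just the wrong shape. The arithmetic, with shapes taken from the test:

```python
import math

expected_shape = (4, 64, 128 // 32)  # (num_experts, out_dim, in_dim // BLOCK_SIZE)
wrong_shape = (64, 4)                # the 2D scale the test injects

print(math.prod(expected_shape))  # 1024
print(math.prod(wrong_shape))     # 256 -> numel mismatch, reshape is skipped
# With different element counts the reshape branch cannot fire, so the
# "Scale shape ... does not match expected shape" assertion is raised.
```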
