@@ -71,9 +71,12 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor:
7171
7272 Args:
7373 weight: The weight tensor to compute scale for. Must be at least 2D.
74+ Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim).
7475
7576 Returns:
7677 torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
78+ For 2D input: (out_dim, in_dim // 32)
79+ For 3D MoE input: (num_experts, out_dim, in_dim // 32)
7780 """
7881 assert weight .dim () >= 2 , f"Weight must be at least 2D, got { weight .dim ()} D"
7982
@@ -83,7 +86,7 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor:
8386 f"Weight inner dimension ({ in_dim } ) must be divisible by MXFP8 block size ({ cls .BLOCK_SIZE } )"
8487 )
8588
86- # Compute amax per block (reduce_block_amax handles reshaping internally )
89+ # Compute amax per block (reduce_block_amax handles N-dimensional tensors )
8790 amax = reduce_block_amax (weight , block_sizes = {- 1 : cls .BLOCK_SIZE })
8891
8992 # Compute E8M0 exponent and convert to biased uint8 (bias = 127)
@@ -102,11 +105,12 @@ def get_weights_scaling_factor_from_quantizer(
102105 with proper format conversion and shape correction.
103106
104107 Args:
105- weight: The weight tensor.
108+ weight: The weight tensor. Can be 2D (out_dim, in_dim) or
109+ 3D for MoE (num_experts, out_dim, in_dim).
106110 weight_quantizer: The weight quantizer with block_sizes and optional _scale.
107111
108112 Returns:
109- torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // 32].
113+ torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
110114 """
111115 assert hasattr (weight_quantizer , "block_sizes" ), (
112116 "weight_quantizer must have 'block_sizes' attribute"
@@ -116,8 +120,11 @@ def get_weights_scaling_factor_from_quantizer(
116120 )
117121 assert weight .dim () >= 2 , f"Weight must be at least 2D, got { weight .dim ()} D"
118122
119- out_dim , in_dim = weight .shape [- 2 ], weight .shape [- 1 ]
120- expected_shape = (out_dim , in_dim // cls .BLOCK_SIZE )
123+ in_dim = weight .shape [- 1 ]
124+ # Expected scale shape: all dims except last, with last dim reduced by block size
125+ # For 2D: (out_dim, in_dim // 32)
126+ # For 3D MoE: (num_experts, out_dim, in_dim // 32)
127+ expected_shape = (* weight .shape [:- 1 ], in_dim // cls .BLOCK_SIZE )
121128
122129 if hasattr (weight_quantizer , "_scale" ) and weight_quantizer ._scale is not None :
123130 scale = weight_quantizer ._scale
@@ -127,11 +134,16 @@ def get_weights_scaling_factor_from_quantizer(
127134 )
128135
129136 # Reshape if needed (same number of elements but wrong shape)
130- if (
131- scale .shape != expected_shape
132- and scale .numel () == expected_shape [0 ] * expected_shape [1 ]
133- ):
134- scale = scale .reshape (expected_shape )
137+ if scale .shape != expected_shape :
138+ expected_numel = 1
139+ for dim in expected_shape :
140+ expected_numel *= dim
141+ if scale .numel () == expected_numel :
142+ scale = scale .reshape (expected_shape )
143+
144+ assert scale .shape == expected_shape , (
145+ f"Scale shape { scale .shape } does not match expected shape { expected_shape } "
146+ )
135147 return scale
136148
137149 # No scale in quantizer, compute from weight