@@ -176,6 +176,19 @@ def maybe_pad_for_mxfp4(weight: torch.Tensor,
176176 return weight
177177
178178
179+ def _pad_tensor_to_shape (tensor : torch .Tensor , shape : tuple ) -> torch .Tensor :
180+ """Pad tensor to match target shape. Used for post-shard alignment."""
181+ if tensor .numel () == 0 :
182+ return tensor
183+ if tensor .shape == shape :
184+ return tensor
185+ if len (tensor .shape ) == 1 :
186+ return F .pad (tensor , (0 , shape [0 ] - tensor .shape [0 ])).contiguous ()
187+ row_pad = shape [0 ] - tensor .shape [0 ]
188+ col_pad = shape [1 ] - tensor .shape [1 ]
189+ return F .pad (tensor , (0 , col_pad , 0 , row_pad )).contiguous ()
190+
191+
179192def interleave_linear_and_gate (x : torch .Tensor ,
180193 group_size : int = 64 ,
181194 dim : int = - 1 ) -> torch .Tensor :
@@ -2915,6 +2928,9 @@ def round_up(x, alignment):
29152928 return (w3_w1_weight_shape , w2_weight_shape , w3_w1_bias_shape ,
29162929 w2_bias_shape , w3_w1_weight_scale_shape , w2_weight_scale_shape )
29172930
2931+ def _round_up (self , x , alignment ):
2932+ return (x + alignment - 1 ) // alignment * alignment
2933+
29182934 def create_weights (self , module : torch .nn .Module ):
29192935 # Here we only enable padding for hidden_size > 1024 since there are small unit tests that expect no padding.
29202936 if module .hidden_size > 1024 and module .hidden_size % 256 != 0 :
@@ -2923,6 +2939,15 @@ def create_weights(self, module: torch.nn.Module):
29232939 # See the comment in MXFP4WeightTRTLLMGenFusedMoEMethod for more details.
29242940 self .input_hidden_alignment = 256
29252941
2942+ else :
2943+ # Weight scales require M % 128 in get_shuffle_matrix_sf_a_row_indices.
2944+ # Check if intermediate_size after padding satisfies this requirement.
2945+ # If not, set weight_alignment to 128.
2946+ intermediate_size_padded = self ._round_up (
2947+ module .intermediate_size_per_partition , self .weight_alignment )
2948+ if intermediate_size_padded % 128 != 0 :
2949+ self .weight_alignment = 128
2950+
29262951 super ().create_weights (module , bias_dtype = torch .float32 )
29272952
29282953 def setup_quant_scales (self , module : torch .nn .Module ):
@@ -2981,6 +3006,8 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
29813006 dst_w3_weight .copy_ (w3_weight_shard .view (dst_w3_weight .dtype ))
29823007 dst_w1_weight .copy_ (w1_weight_shard .view (dst_w1_weight .dtype ))
29833008 else :
3009+ w1_weight_shard = _pad_tensor_to_shape (w1_weight_shard ,
3010+ dst_w3_w1_weight_gpu .shape )
29843011 dst_w3_w1_weight_gpu .copy_ (
29853012 w1_weight_shard .view (dst_w3_w1_weight_gpu .dtype ))
29863013
@@ -3038,6 +3065,8 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
30383065 epilogue_tile_m = 128
30393066
30403067 # Keep weights in device buffer
3068+ w2_weight_shard = _pad_tensor_to_shape (w2_weight_shard ,
3069+ dst_w2_weight_gpu .shape )
30413070 dst_w2_weight_gpu .copy_ (w2_weight_shard .view (dst_w2_weight_gpu .dtype ),
30423071 non_blocking = dst_on_gpu )
30433072 # Get permuted indices
@@ -3071,7 +3100,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
30713100 alignment = _get_weight_alignment (self .weight_alignment ,
30723101 module .scaling_vector_size ,
30733102 module .tp_size ,
3074- w3_weight_scale .shape [0 ])
3103+ w1_weight_scale .shape [0 ])
30753104 w1_weight_scale = maybe_pad_for_mxfp4 (
30763105 w1_weight_scale ,
30773106 self .input_hidden_alignment // module .scaling_vector_size ,
@@ -3113,6 +3142,8 @@ def load_expert_w3_w1_weight_scale_nvfp4(
31133142 w1_weight_scale .view (dst_w1_weight_scale .dtype ))
31143143 else :
31153144 # Non-gated activation (e.g., ReLU2): buffer only contains w1 scale
3145+ w1_weight_scale = _pad_tensor_to_shape (
3146+ w1_weight_scale , dst_w3_w1_weight_scale_gpu .shape )
31163147 dst_w3_w1_weight_scale_gpu .copy_ (
31173148 w1_weight_scale .view (dst_w3_w1_weight_scale_gpu .dtype ))
31183149
@@ -3170,6 +3201,8 @@ def load_expert_w2_weight_scale_nvfp4(self,
31703201 TensorParallelMode .ROW ,
31713202 device = device )
31723203 # Keep weights in device buffer
3204+ w2_weight_scale = _pad_tensor_to_shape (w2_weight_scale ,
3205+ dst_w2_weight_scale_gpu .shape )
31733206 dst_w2_weight_scale_gpu .copy_ (
31743207 w2_weight_scale .view (dst_w2_weight_scale_gpu .dtype ))
31753208
0 commit comments