@@ -380,9 +380,12 @@ def load_weights(self,
380380 if not allow_partial_loading :
381381 self .process_weights_after_loading (module )
382382
383- def post_load_weights (self , module : Linear ):
383+ def transform_weights (self , module : Linear ) -> None :
384384 pass
385385
386+ def post_load_weights (self , module : Linear ) -> None :
387+ self .transform_weights (module )
388+
386389 def load_weight_scales (self , weights : List [Dict ], * args , ** kwargs ):
387390 """
388391 Load quantized weight scales from the checkpoint.
@@ -1241,8 +1244,8 @@ def load_weights_fused_gate_up_linear(
12411244 copy_weight_shard (module .weight_scale , scale , shard_offset ,
12421245 shard_size )
12431246
1244- def post_load_weights (self , module : Linear ):
1245- super ().post_load_weights (module )
1247+ def transform_weights (self , module : Linear ) -> None :
1248+ super ().transform_weights (module )
12461249 if (is_sm_100f () and not (module .use_cute_dsl_blockscaling_mm
12471250 or module .disable_deep_gemm )) or \
12481251 get_sm_version () == 120 :
@@ -1821,9 +1824,9 @@ def process_weights_after_loading_fused_gate_up_linear(
18211824 torch .ops .trtllm .block_scale_interleave (ws_swapped ),
18221825 requires_grad = False )
18231826
1824- def post_load_weights (self , module : Linear ):
1827+ def transform_weights (self , module : Linear ) -> None :
18251828 """Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements."""
1826- super ().post_load_weights (module )
1829+ super ().transform_weights (module )
18271830 row_alignment , col_alignment = 32 , 16
18281831 row_pad_size = (row_alignment - module .weight .size (0 )) % row_alignment
18291832 col_pad_size = (col_alignment - module .weight .size (1 )) % col_alignment
@@ -1873,10 +1876,10 @@ class W4A16NVFP4LinearMethod(NVFP4LinearMethod):
18731876 its fused path is SM>=100-gated upstream.
18741877 """
18751878
1876- def post_load_weights (self , module : Linear ):
1879+ def transform_weights (self , module : Linear ) -> None :
18771880 # Skip parent's 32x16 weight padding (apply() accepts [N, K/2] as-is)
18781881 # and un-swizzle per-block scale once at load.
1879- LinearMethodBase .post_load_weights (self , module )
1882+ LinearMethodBase .transform_weights (self , module )
18801883 pad_rows = fp4_utils .pad_up (module .out_features , 128 )
18811884 pad_cols = fp4_utils .pad_up (
18821885 module .in_features // module .scaling_vector_size , 4 )
@@ -2914,6 +2917,7 @@ def __init__(
29142917 dtype = self .dtype ) if reduce_output else None
29152918
29162919 self ._weights_created = False
2920+ self ._weights_transformed = False
29172921 self .reduce_output = reduce_output
29182922 self .use_custom_cublas_mm = use_custom_cublas_mm
29192923 self .use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm
@@ -2966,6 +2970,7 @@ def create_weights(self):
29662970 self .dtype )
29672971
29682972 self ._weights_created = True
2973+ self ._weights_transformed = False
29692974
29702975 @property
29712976 def has_any_quant (self ):
@@ -3127,6 +3132,7 @@ def load_weights(self,
31273132 assert allow_partial_loading is False , (
31283133 f"{ type (self .quant_method ).__name__ } does not support "
31293134 "allow_partial_loading" )
3135+ self ._weights_transformed = False
31303136 self .quant_method .load_weights (
31313137 self ,
31323138 weights ,
@@ -3136,8 +3142,14 @@ def load_weights(self,
31363142 def process_weights_after_loading (self ):
31373143 self .quant_method .process_weights_after_loading (self )
31383144
3139- def post_load_weights (self ):
3140- self .quant_method .post_load_weights (self )
3145+ def transform_weights (self ) -> None :
3146+ if self ._weights_transformed :
3147+ return
3148+ self .quant_method .transform_weights (self )
3149+ self ._weights_transformed = True
3150+
3151+ def post_load_weights (self ) -> None :
3152+ self .transform_weights ()
31413153
31423154 def pre_reload_weights (self ):
31433155 assert hasattr (
0 commit comments