@@ -251,6 +251,8 @@ class FusedMoEMethodBase(ABC):
251251 to work with online EPLB should override this to SUPPORTED; those that
252252 have not yet been tested may set it to NOT_VERIFIED.
253253 """
254+ needs_post_load_processing_for_dummy : bool = False
255+ """Whether LoadFormat.DUMMY must finish weight processing in post_load_weights()."""
254256
255257 @classmethod
256258 def supports_online_eplb (cls ) -> bool :
@@ -326,6 +328,8 @@ def create_weights(
326328 module .w2_bias = None
327329
328330 module .rebuild_tensor_metadata = {}
331+ module ._needs_post_load_weight_processing = True
332+ module ._weights_loaded_via_load_weights = False
329333
330334 def load_expert_weights_to_dst (
331335 self ,
@@ -437,6 +441,7 @@ def load_weights(self,
437441 weights : List [Dict ],
438442 weight_loading_mode : MoEWeightLoadingMode ,
439443 allow_partial_loading : bool = False ):
444+ module ._weights_loaded_via_load_weights = True
440445 if allow_partial_loading :
441446 if not isinstance (self ,
442447 (UnquantizedFusedMoEMethod , FP8QDQFusedMoEMethod ,
@@ -519,8 +524,23 @@ def load_weights(self,
519524
520525 if not allow_partial_loading :
521526 self .process_weights_after_loading (module )
527+ module ._needs_post_load_weight_processing = False
522528
523529 def post_load_weights (self , module : torch .nn .Module ):
530+ # LoadFormat.DUMMY initializes parameters in-place without calling
531+ # load_weights(). Finish pending weight processing here only if
532+ # load_weights() ran but deferred it, or the method opts in for DUMMY.
533+ needs_post_load_processing = getattr (
534+ module , "_needs_post_load_weight_processing" , True )
535+ loaded_via_load_weights = getattr (module ,
536+ "_weights_loaded_via_load_weights" ,
537+ False )
538+ if needs_post_load_processing and (
539+ loaded_via_load_weights
540+ or self .needs_post_load_processing_for_dummy ):
541+ self .process_weights_after_loading (module )
542+ module ._needs_post_load_weight_processing = False
543+
524544 if self .need_load_shared_weights (module ):
525545 weight_fns = {
526546 'w3_w1_weight' : getattr (module , 'local_shared_w3_w1_tensors' ),
@@ -676,6 +696,7 @@ class BF16TRTLLMGenFusedMoEMethod(UnquantizedFusedMoEMethod):
676696 block_k = 64
677697 use_shuffled_weight = True
678698 weight_layout = TRTLLM_GEN_WEIGHT_LAYOUT_BLOCK_MAJOR_K
699+ needs_post_load_processing_for_dummy = True
679700 _cache_permute_indices : Dict [tuple [tuple [int , ...], str , int ],
680701 torch .Tensor ] = {}
681702
0 commit comments