Skip to content

Commit 8a2c936

Browse files
committed
process dummy bf16 trtllm moe weights
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent 5b0a3fb commit 8a2c936

1 file changed

Lines changed: 21 additions & 0 deletions

File tree

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ class FusedMoEMethodBase(ABC):
251251
to work with online EPLB should override this to SUPPORTED; those that
252252
have not yet been tested may set it to NOT_VERIFIED.
253253
"""
254+
needs_post_load_processing_for_dummy: bool = False
255+
"""Whether LoadFormat.DUMMY must finish weight processing in post_load_weights()."""
254256

255257
@classmethod
256258
def supports_online_eplb(cls) -> bool:
@@ -326,6 +328,8 @@ def create_weights(
326328
module.w2_bias = None
327329

328330
module.rebuild_tensor_metadata = {}
331+
module._needs_post_load_weight_processing = True
332+
module._weights_loaded_via_load_weights = False
329333

330334
def load_expert_weights_to_dst(
331335
self,
@@ -437,6 +441,7 @@ def load_weights(self,
437441
weights: List[Dict],
438442
weight_loading_mode: MoEWeightLoadingMode,
439443
allow_partial_loading: bool = False):
444+
module._weights_loaded_via_load_weights = True
440445
if allow_partial_loading:
441446
if not isinstance(self,
442447
(UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod,
@@ -519,8 +524,23 @@ def load_weights(self,
519524

520525
if not allow_partial_loading:
521526
self.process_weights_after_loading(module)
527+
module._needs_post_load_weight_processing = False
522528

523529
def post_load_weights(self, module: torch.nn.Module):
530+
# LoadFormat.DUMMY initializes parameters in-place without calling
531+
# load_weights(), so only methods that explicitly opt in should finish
532+
# their processing here unless load_weights() left work unfinished.
533+
needs_post_load_processing = getattr(
534+
module, "_needs_post_load_weight_processing", True)
535+
loaded_via_load_weights = getattr(module,
536+
"_weights_loaded_via_load_weights",
537+
False)
538+
if needs_post_load_processing and (
539+
loaded_via_load_weights
540+
or self.needs_post_load_processing_for_dummy):
541+
self.process_weights_after_loading(module)
542+
module._needs_post_load_weight_processing = False
543+
524544
if self.need_load_shared_weights(module):
525545
weight_fns = {
526546
'w3_w1_weight': getattr(module, 'local_shared_w3_w1_tensors'),
@@ -676,6 +696,7 @@ class BF16TRTLLMGenFusedMoEMethod(UnquantizedFusedMoEMethod):
676696
block_k = 64
677697
use_shuffled_weight = True
678698
weight_layout = TRTLLM_GEN_WEIGHT_LAYOUT_BLOCK_MAJOR_K
699+
needs_post_load_processing_for_dummy = True
679700
_cache_permute_indices: Dict[tuple[tuple[int, ...], str, int],
680701
torch.Tensor] = {}
681702

0 commit comments

Comments (0)