Skip to content

Commit b2646d3

Browse files
committed
[TRTLLM-13248][feat] Wave 3 migrate MoE staged hooks
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent d011a67 commit b2646d3

16 files changed

Lines changed: 302 additions & 47 deletions

File tree

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
"""DeepSeek Sparse Attention (DSA) backend for TRT-LLM with indexer-based TopK selection."""
216
import math
317
import threading
@@ -1504,13 +1518,16 @@ def __init__(self,
15041518
# attribute queries do not end up frozen into a captured graph.
15051519
warmup_heuristic_topk_decode(top_k=self.index_topk)
15061520

1507-
def post_load_weights(self):
1521+
def cache_derived_state(self):
15081522
"""Fuse wk + weights_proj into a single FP32 weight for the F.linear GEMM under allow_tf32 (TF32 tensor cores on Ampere+)."""
15091523
# wk: [head_dim, hidden_size] + weights_proj: [n_heads, hidden_size]
15101524
# → fused: [head_dim + n_heads, hidden_size]
15111525
self._fused_wk_wp_weight = torch.cat(
15121526
[self.wk.weight.data, self.weights_proj.weight.data], dim=0)
15131527

1528+
def post_load_weights(self):
1529+
self.cache_derived_state()
1530+
15141531
@staticmethod
15151532
def prepare_one_prefill_chunk(
15161533
metadata: DSAtrtllmAttentionMetadata,
@@ -2404,7 +2421,7 @@ def pre_indexer_proj(
24042421
split in MLA.forward_dsa_proj sees a stable signature.
24052422
"""
24062423
assert self._fused_wk_wp_weight is not None, \
2407-
"post_load_weights() must be called before forward()"
2424+
"cache_derived_state() must be called before forward()"
24082425
hidden_float = _to_float(hidden_states)
24092426
with _tf32_matmul_enabled():
24102427
# F.linear computes input @ weight.T internally; no explicit .t() needed.

tensorrt_llm/_torch/models/modeling_llama_min_latency.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
from collections.abc import Callable
217
from typing import Dict, List, Optional, Tuple, Union
318

@@ -308,7 +323,7 @@ def __init__(self,
308323

309324
# After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
310325
# the trtllm-gen gemm+swiglu kernel.
311-
def post_load_weights(self):
326+
def cache_derived_state(self):
312327
if self.gate_up_proj.has_fp8_qdq:
313328
# For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
314329
# of down_proj's combined input scale.
@@ -317,6 +332,9 @@ def post_load_weights(self):
317332
# combined input scale times inv_output_scale.
318333
self.gate_up_proj.trtllm_gen_global_scale = self.gate_up_proj.combined_scale * self.gate_up_proj.inv_output_scale
319334

335+
def post_load_weights(self):
336+
self.cache_derived_state()
337+
320338
def forward(
321339
self,
322340
x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -566,7 +584,7 @@ def __init__(
566584
dtype=model_config.pretrained_config.torch_dtype,
567585
quant_config=None)
568586

569-
def post_load_weights(self):
587+
def cache_derived_state(self):
570588
# Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
571589
# This is because the routed experts' input scale is after the score multiplication, so we must use the
572590
# pre-score scaling input scale, which happens to be shared expert's input scale.
@@ -582,6 +600,9 @@ def post_load_weights(self):
582600
fc1_input_dequant=pre_score_scaling_input_scale,
583601
)
584602

603+
def post_load_weights(self):
604+
self.cache_derived_state()
605+
585606
def compute_routed_output(
586607
self,
587608
hidden_states,

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,15 +652,36 @@ def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False)
652652
)
653653
return self.backend.load_weights(weights, allow_partial_loading)
654654

655-
def post_load_weights(self):
655+
def transform_weights(self):
656+
"""
657+
Transform weights - delegated to backend
658+
659+
"""
660+
if getattr(self, "_weights_transformed", False):
661+
return
662+
assert hasattr(self.backend, "transform_weights"), (
663+
f"Backend {self.backend.__class__.__name__} must implement transform_weights()"
664+
)
665+
self.backend.transform_weights()
666+
self._weights_transformed = True
667+
668+
def cache_derived_state(self):
656669
"""
657-
Post load weights processing - delegated to backend
670+
Cache derived state - delegated to backend
658671
659672
"""
660-
assert hasattr(self.backend, "post_load_weights"), (
661-
f"Backend {self.backend.__class__.__name__} must implement post_load_weights()"
673+
assert hasattr(self.backend, "cache_derived_state"), (
674+
f"Backend {self.backend.__class__.__name__} must implement cache_derived_state()"
662675
)
663-
return self.backend.post_load_weights()
676+
return self.backend.cache_derived_state()
677+
678+
def post_load_weights(self):
679+
"""
680+
Backward-compatible post-load entry point - runs transform_weights() then cache_derived_state(), each delegated to backend
681+
682+
"""
683+
self.transform_weights()
684+
self.cache_derived_state()
664685

665686
def process_weights_after_loading(self):
666687
"""

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class CuteDslB12xFusedMoE(CuteDslFusedMoE):
6868
``_get_quant_method``). The inherited CUTLASS NVFP4 layout is finalised
6969
by the base class, and the b12x-shaped tensors (un-normalised FP8 SF,
7070
``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instance) are
71-
materialised on top by the quant method's ``post_load_weights``. Both
71+
materialised on top by the quant method's ``transform_weights``. Both
7272
layouts coexist in memory and the dispatcher picks per call based on
7373
``x.shape[0]``.
7474
@@ -173,7 +173,7 @@ def _route_to_cutlass(self, x) -> bool:
173173
return isinstance(x, torch.Tensor) and x.shape[0] >= self._PREFILL_VIA_CUTLASS_THRESHOLD
174174

175175
# ``post_load_weights`` is inherited from ``CutlassFusedMoE`` and
176-
# dispatches to ``self.quant_method.post_load_weights(self)`` — for this
176+
# dispatches (via ``transform_weights``) to ``self.quant_method.transform_weights(self)`` — for this
177177
# backend ``self.quant_method`` is ``NVFP4CuteDslB12xFusedMoEMethod``
178178
# (see ``_get_quant_method`` override), which performs the SF un-normalization,
179179
# ``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instantiation,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import inspect
217
import os
318
from functools import cached_property
@@ -1577,4 +1592,5 @@ def load_weights(self,
15771592
**kargs)
15781593

15791594
def post_load_weights(self):
1580-
self.quant_method.post_load_weights(self)
1595+
self.transform_weights()
1596+
self.cache_derived_state()

tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False)
291291
)
292292

293293
def post_load_weights(self):
294-
self.quant_method.post_load_weights(self)
294+
self.transform_weights()
295+
self.cache_derived_state()
295296

296297
def _transform_w2_weight_scale_for_min_latency(self):
297298
"""Transform w2_weight_scale for minimum latency path optimization."""

tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ def _maybe_remove_padding(gemm_output, expected_size):
13681368

13691369
return gemm2_output
13701370

1371-
def post_load_weights(self, module: torch.nn.Module):
1371+
def transform_weights(self, module: torch.nn.Module):
13721372
if 'w3_w1_weight' in module._parameters:
13731373
w31_scale = shuffle_weight_for_activation_kernel(
13741374
module.fc31_dequant.data)
@@ -1382,7 +1382,7 @@ def post_load_weights(self, module: torch.nn.Module):
13821382
module.fc31_input_dequant = None
13831383
module.fc2_input_dequant = None
13841384

1385-
super().post_load_weights(module)
1385+
super().transform_weights(module)
13861386

13871387

13881388
class TritonFusedMoE(MoE):
@@ -1586,4 +1586,5 @@ def load_weights(self,
15861586
self.quant_method.load_weights(self, weights, self.weight_loading_mode)
15871587

15881588
def post_load_weights(self):
1589-
self.quant_method.post_load_weights(self)
1589+
self.transform_weights()
1590+
self.cache_derived_state()

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,8 @@ def load_weights(self,
525525
**kargs)
526526

527527
def post_load_weights(self):
528-
self.quant_method.post_load_weights(self)
528+
self.transform_weights()
529+
self.cache_derived_state()
529530

530531
def quantize_input(self, x, post_quant_comm: bool = True):
531532
"""Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import inspect
217
import os
318
from typing import Dict, List, Optional, Tuple, Union
@@ -950,7 +965,8 @@ def load_weights(self,
950965
**kargs)
951966

952967
def post_load_weights(self):
953-
self.quant_method.post_load_weights(self)
968+
self.transform_weights()
969+
self.cache_derived_state()
954970

955971
def forward_fake(
956972
self,

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,8 +827,18 @@ def load_weights(self,
827827
"""
828828
raise NotImplementedError
829829

830+
def transform_weights(self):
831+
if getattr(self, "_weights_transformed", False):
832+
return
833+
self.quant_method.transform_weights(self)
834+
self._weights_transformed = True
835+
836+
def cache_derived_state(self):
837+
self.quant_method.cache_derived_state(self)
838+
830839
def post_load_weights(self):
831-
pass
840+
self.transform_weights()
841+
self.cache_derived_state()
832842

833843
def process_weights_after_loading(self):
834844
"""

0 commit comments

Comments
 (0)