Skip to content

Commit 2ca7dbe

Browse files
committed
[TRTLLM-13250][fix] Address Wave 5 review findings
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent ba9ea82 commit 2ca7dbe

23 files changed

Lines changed: 128 additions & 48 deletions

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1518,14 +1518,14 @@ def __init__(self,
15181518
# attribute queries do not end up frozen into a captured graph.
15191519
warmup_heuristic_topk_decode(top_k=self.index_topk)
15201520

1521-
def cache_derived_state(self):
1521+
def cache_derived_state(self) -> None:
15221522
"""Fuse wk + weights_proj into single FP32 weight for F.linear GEMM under allow_tf32 (TF32 tensor cores on Ampere+)."""
15231523
# wk: [head_dim, hidden_size] + weights_proj: [n_heads, hidden_size]
15241524
# → fused: [head_dim + n_heads, hidden_size]
15251525
self._fused_wk_wp_weight = torch.cat(
15261526
[self.wk.weight.data, self.weights_proj.weight.data], dim=0)
15271527

1528-
def post_load_weights(self):
1528+
def post_load_weights(self) -> None:
15291529
self.cache_derived_state()
15301530

15311531
@staticmethod

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19211921
weight_loader = DeepseekV3WeightLoader(self)
19221922
weight_loader.load_weights(weights)
19231923

1924-
def setup_aliases(self):
1924+
def setup_aliases(self) -> None:
19251925
for idx, layer in enumerate(
19261926
self.model.layers[:self.config.num_hidden_layers]):
19271927
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -725,7 +725,7 @@ def load_weights(
725725
allow_partial_loading=allow_partial_loading,
726726
)
727727

728-
def setup_aliases(self):
728+
def setup_aliases(self) -> None:
729729
# For the cross-layer residual+LN fusion.
730730
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
731731
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def setup_aliases(self):
1077+
def setup_aliases(self) -> None:
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def setup_aliases(self):
634+
def setup_aliases(self) -> None:
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def setup_aliases(self):
1143+
def setup_aliases(self) -> None:
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def setup_aliases(self):
1567+
def setup_aliases(self) -> None:
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_llama_min_latency.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -323,7 +323,7 @@ def __init__(self,
323323

324324
# After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
325325
# the trtllm-gen gemm+swiglu kernel.
326-
def cache_derived_state(self):
326+
def cache_derived_state(self) -> None:
327327
if self.gate_up_proj.has_fp8_qdq:
328328
# For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
329329
# of down_proj's combined input scale.
@@ -332,7 +332,7 @@ def cache_derived_state(self):
332332
# combined input scale times inv_output_scale.
333333
self.gate_up_proj.trtllm_gen_global_scale = self.gate_up_proj.combined_scale * self.gate_up_proj.inv_output_scale
334334

335-
def post_load_weights(self):
335+
def post_load_weights(self) -> None:
336336
self.cache_derived_state()
337337

338338
def forward(
@@ -584,7 +584,7 @@ def __init__(
584584
dtype=model_config.pretrained_config.torch_dtype,
585585
quant_config=None)
586586

587-
def cache_derived_state(self):
587+
def cache_derived_state(self) -> None:
588588
# Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
589589
# This is because the routed experts' input scale is after the score multiplication, so we must use the
590590
# pre-score scaling input scale, which happens to be shared expert's input scale.
@@ -600,7 +600,7 @@ def cache_derived_state(self):
600600
fc1_input_dequant=pre_score_scaling_input_scale,
601601
)
602602

603-
def post_load_weights(self):
603+
def post_load_weights(self) -> None:
604604
self.cache_derived_state()
605605

606606
def compute_routed_output(

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def setup_aliases(self):
420+
def setup_aliases(self) -> None:
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
980980
new_weights = weight_mapper.preprocess_weights(weights)
981981
super().load_weights(new_weights, weight_mapper)
982982

983-
def setup_aliases(self):
983+
def setup_aliases(self) -> None:
984984
for idx, layer in enumerate(
985985
self.model.layers[:self.config.num_hidden_layers]):
986986
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -650,9 +650,10 @@ def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False)
650650
assert hasattr(self.backend, "load_weights"), (
651651
f"Backend {self.backend.__class__.__name__} must implement load_weights()"
652652
)
653+
self._weights_transformed = False
653654
return self.backend.load_weights(weights, allow_partial_loading)
654655

655-
def transform_weights(self):
656+
def transform_weights(self) -> None:
656657
"""
657658
Transform weights - delegated to backend
658659
@@ -665,17 +666,17 @@ def transform_weights(self):
665666
self.backend.transform_weights()
666667
self._weights_transformed = True
667668

668-
def cache_derived_state(self):
669+
def cache_derived_state(self) -> None:
669670
"""
670671
Cache derived state - delegated to backend
671672
672673
"""
673674
assert hasattr(self.backend, "cache_derived_state"), (
674675
f"Backend {self.backend.__class__.__name__} must implement cache_derived_state()"
675676
)
676-
return self.backend.cache_derived_state()
677+
self.backend.cache_derived_state()
677678

678-
def post_load_weights(self):
679+
def post_load_weights(self) -> None:
679680
"""
680681
Backward-compatible staged post-load processing - delegated to backend
681682

0 commit comments

Comments
 (0)