Skip to content

Commit fc87519

Browse files
committed
cleanup
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
1 parent 2e47843 commit fc87519

File tree

4 files changed

+15
-21
lines changed

4 files changed

+15
-21
lines changed

tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from torch import nn
2-
31
from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
42
from tensorrt_llm._torch.models.modeling_utils import register_mapper
53

@@ -92,20 +90,14 @@ def permute(w, n_heads: int, head_dim: int, hidden_size: int):
9290
# If using quantized model in mistral format,
9391
# quantization scales (qscale_weight) also need to be sliced
9492
for name in weights.keys():
95-
# TODO: add scales if dequant is necessary
93+
# TODO: add scales if dequant is necessary
9694
if ".wq.weight" in name:
9795
weights[name] = permute(
98-
weights[name],
99-
config.num_attention_heads,
100-
config.head_dim,
101-
config.hidden_size
96+
weights[name], config.num_attention_heads, config.head_dim, config.hidden_size
10297
)
10398
elif ".wk.weight" in name:
10499
weights[name] = permute(
105-
weights[name],
106-
config.num_key_value_heads,
107-
config.head_dim,
108-
config.hidden_size
100+
weights[name], config.num_key_value_heads, config.head_dim, config.hidden_size
109101
)
110102
return weights
111103

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -623,13 +623,14 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs):
623623

624624
llm_weights = filter_weights(weights=weights, prefix="language_model")
625625
logger.debug(f"Loading weights for {type(self.llm)}")
626-
if weight_mapper:
627-
weight_mapper.permute_qk(weights=llm_weights, config=self.llm.config)
628-
self.llm.load_weights(llm_weights,
629-
weight_mapper=weight_mapper,
630-
params_map=weight_mapper.mistral_llm_mapping)
631-
else:
632-
self.llm.load_weights(llm_weights)
626+
if weight_mapper:
627+
weight_mapper.permute_qk(weights=llm_weights,
628+
config=self.llm.config)
629+
self.llm.load_weights(llm_weights,
630+
weight_mapper=weight_mapper,
631+
params_map=weight_mapper.mistral_llm_mapping)
632+
else:
633+
self.llm.load_weights(llm_weights)
633634
logger.debug(f"Successfully loaded weights for {type(self.llm)}")
634635

635636
vit_weights = filter_weights(weights=weights, prefix="vision_tower")
@@ -638,7 +639,8 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs):
638639
if vit_params_map is not None:
639640
# Pixtral uses num_attention_heads = num_key_value_heads
640641
self._vision_tower.config.num_key_value_heads = self._vision_tower.config.num_attention_heads
641-
weight_mapper.permute_qk(weights=vit_weights, config=self._vision_tower.config)
642+
weight_mapper.permute_qk(weights=vit_weights,
643+
config=self._vision_tower.config)
642644
vit_weights = weight_mapper.rename_by_params_map(
643645
weights=vit_weights, params_map=vit_params_map)
644646

tensorrt_llm/_torch/pyexecutor/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def load_pretrained_config(model_name_or_path: str,
283283

284284
elif model_type == "mistral3" and "layer_types" in config_dict:
285285
# TODO: update this for transformers v5.0
286-
config_class = "MinistralConfig"
286+
config_class = "MinistralConfig"
287287
model_config = config_class.from_pretrained(model_name_or_path,
288288
**kwargs)
289289

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def _update_from_hf_quant_config(self) -> bool:
430430

431431
if hf_quant_config is not None:
432432
# DeepSeek V3 FP8 ckpt
433-
if hf_quant_config.get("quant_method") == "fp8":
433+
if hf_quant_config.get("quant_method") == "fp8":
434434
if hf_quant_config.get("weight_block_size"):
435435
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
436436
quant_config.exclude_modules = ["*eh_proj"]

0 commit comments

Comments (0)