diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index f47e77a81661..3d584fcb0cff 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -26,14 +26,17 @@ class HfWeightLoader(BaseWeightLoader): Loads weights from SafeTensors/bin/pth files. """ - def load_weights(self, checkpoint_dir: str, - mapping: Mapping) -> dict[str, Any]: + def load_weights(self, + checkpoint_dir: str, + mapping: Mapping, + use_consolidated: bool = False) -> dict[str, Any]: weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors") # Some model checkpoint directories contain not only the sharded safetensors, but one - # consolidated tensor. In the presence of both, we favor the former, as there really is no need + # consolidated tensor. In the presence of both, we favor the former unless specified explicitly, as there really is no need # to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case. filtered_weight_files = [ - x for x in weight_files if "consolidated" not in os.path.split(x)[1] + x for x in weight_files + if ("consolidated" in os.path.split(x)[1]) == use_consolidated ] if len(filtered_weight_files) > 0: weight_files = filtered_weight_files diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py index 116fce261ac9..e43e1142abb0 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py @@ -64,7 +64,10 @@ def inverse_nvfp4_global_scales(self, weights): weights[key] = 1.0 / weights[key] def load_weights(self, checkpoint_dir: str, **kwargs): - weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs) + # Mistral native weight mapping is different from HF and stored in the .consolidated tensor + weights = super().weight_loader.load_weights( + checkpoint_dir, use_consolidated=True, **kwargs + ) weights = self.preprocess_weights(weights) self.broadcast_per_tensor_scales(weights) # The definition of global_scale is different in Mistral, need to inverse the scale diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py index dbc7abf73020..1b566a812ae0 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py @@ -348,5 +348,6 @@ def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: model_config.pretrained_config.gate_cls = Mistral3Gate model_config.pretrained_config.input_processor_type = "mistral_large_3" + model_config.pretrained_config.model_type = "mistral_large_3" model_config._frozen = True return model_config diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py index 035b539639fc..dd6e0332b849 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py @@ -1,5 +1,3 @@ -from torch import nn - from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper from tensorrt_llm._torch.models.modeling_utils import register_mapper @@ -10,8 +8,6 @@ class MistralWeightMapper(HfWeightMapper): def __init__(self): super().__init__() - self._callbacks.append(self._permute_qk) - self.pixtral_mapping = { "wq": "q_proj", "wk": "k_proj", @@ -31,8 +27,8 @@ def __init__(self): "qscale_weight": "weight_scale_inv", "kv_fake_quantizer.qscale_act": "kv_scale", "q_fake_quantizer.qscale_act": "attn.q_scale", - "k_fake_quantizer.qscale_act": "k_scale", - "v_fake_quantizer.qscale_act": "v_scale", + "k_fake_quantizer.qscale_act": "attn.k_scale", + "v_fake_quantizer.qscale_act": "attn.v_scale", "attention_norm": "input_layernorm", "feed_forward": "mlp", "ffn_norm": "post_attention_layernorm", @@ -78,38 +74,31 @@ def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dic return ConsumableWeightsDict(renamed_weights) return renamed_weights - def _permute_qk(self, module: nn.Module, new_name: str, weights: dict): + def permute_qk(self, weights: dict, config: dict): # Adapted from: # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657 - processed_weights = {} - config = self.config.pretrained_config - - def permute(w, n_heads: int, attn_out: int): - attn_in = config.head_dim * n_heads - + def permute(w, n_heads: int, head_dim: int, hidden_size: int): + attn_in = head_dim * n_heads return ( - w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + w.view(n_heads, attn_in // n_heads // 2, 2, hidden_size) .transpose(1, 2) - .reshape(attn_in, attn_out) + .reshape(attn_in, hidden_size) ) # rotary embeds should be sliced # If using quantized model in mistral format, # quantization scales (qscale_weight) also need to be sliced - - if new_name in ["k_proj", "q_proj"]: - n_heads = ( - config.num_key_value_heads if new_name == "k_proj" else config.num_attention_heads - ) - - processed_weights["weight"] = permute(weights["weight"], n_heads, config.hidden_size) - - if "qscale_weight" in weights and weights["qscale_weight"].numel() > 1: - processed_weights["qscale_weight"] = permute(weights["qscale_weight"], n_heads, 1) - - return processed_weights - + for name in weights.keys(): + # TODO: add scales if dequant is necessary + if ".wq.weight" in name: + weights[name] = permute( + weights[name], config.num_attention_heads, config.head_dim, config.hidden_size + ) + elif ".wk.weight" in name: + weights[name] = permute( + weights[name], config.num_key_value_heads, config.head_dim, config.hidden_size + ) return weights diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 573318b41130..53eb1095749e 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -376,10 +376,7 @@ def __init__( # When the input only contains text, we use the text processor to process the input. self._processor = MistralCommonImageProcessor( tokenizer=self._tokenizer, dtype=self.dtype) - self.text_processor = AutoProcessor.from_pretrained( - model_path, - use_fast=self.use_fast, - trust_remote_code=trust_remote_code) + self.text_processor = self._processor else: # For other mistral models, we use the AutoProcessor to process the input. self._processor = AutoProcessor.from_pretrained( @@ -622,19 +619,29 @@ def _post_config(self): def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): vit_params_map = None - if weight_mapper: - if isinstance(weight_mapper, MistralWeightMapper): - vit_params_map = weight_mapper.pixtral_mapping + if weight_mapper and isinstance(weight_mapper, MistralWeightMapper): + vit_params_map = weight_mapper.pixtral_mapping llm_weights = filter_weights(weights=weights, prefix="language_model") logger.debug(f"Loading weights for {type(self.llm)}") - self.llm.load_weights(llm_weights) + if weight_mapper and type(weight_mapper) is MistralWeightMapper: + weight_mapper.permute_qk(weights=llm_weights, + config=self.llm.config) + self.llm.load_weights(llm_weights, + weight_mapper=weight_mapper, + params_map=weight_mapper.mistral_llm_mapping) + else: + self.llm.load_weights(llm_weights) logger.debug(f"Successfully loaded weights for {type(self.llm)}") vit_weights = filter_weights(weights=weights, prefix="vision_tower") logger.debug(f"Loading weights for {type(self._vision_tower)}") if vit_params_map is not None: + # Pixtral uses num_attention_heads = num_key_value_heads + self._vision_tower.config.num_key_value_heads = self._vision_tower.config.num_attention_heads + weight_mapper.permute_qk(weights=vit_weights, + config=self._vision_tower.config) vit_weights = weight_mapper.rename_by_params_map( weights=vit_weights, params_map=vit_params_map) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 9c3b4c37560f..a29dc41c174f 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -260,7 +260,12 @@ def load_pretrained_config(model_name_or_path: str, model_type = config_dict.get("model_type") architectures = config_dict.get("architectures") or [] - if model_type in _CONFIG_REGISTRY: + if checkpoint_format in ("mistral", "mistral_large_3"): + from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ + MistralConfigLoader + model_config = MistralConfigLoader().load( + model_name_or_path).pretrained_config + elif model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] model_config = config_class.from_pretrained(model_name_or_path, **kwargs) @@ -274,12 +279,6 @@ def load_pretrained_config(model_name_or_path: str, )): model_config = transformers.Qwen3NextConfig.from_dict( _Qwen35ConfigCompat.normalize(config_dict)) - elif checkpoint_format in ("mistral", "mistral_large_3"): - from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ - MistralConfigLoader - model_config = getattr( - MistralConfigLoader().load(model_name_or_path).pretrained_config, - "text_config") else: model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 1b97f72f2a38..a3b484bdc9c2 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -879,7 +879,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython, num_key_value_heads) # get head dim - mla = hasattr(config, "kv_lora_rank") + mla = hasattr(config, + "kv_lora_rank") and config.kv_lora_rank is not None if mla: head_dim = config.kv_lora_rank + config.qk_rope_head_dim kv_factor = 1 @@ -2553,7 +2554,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython, num_key_value_heads) # get head dim - mla = hasattr(config, "kv_lora_rank") + mla = hasattr(config, + "kv_lora_rank") and config.kv_lora_rank is not None if mla: head_dim = config.kv_lora_rank + config.qk_rope_head_dim kv_factor = 1 diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 5a6cb517caca..2ea28568f166 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -430,11 +430,13 @@ def _update_from_hf_quant_config(self) -> bool: if hf_quant_config is not None: # DeepSeek V3 FP8 ckpt - if hf_quant_config.get( - "quant_method") == "fp8" and hf_quant_config.get( - "weight_block_size"): - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES - quant_config.exclude_modules = ["*eh_proj"] + if hf_quant_config.get("quant_method") == "fp8": + if hf_quant_config.get("weight_block_size") is not None: + quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + quant_config.exclude_modules = ["*eh_proj"] + else: + # Ministral 3 static quant + quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "mxfp4": from .._torch.model_config import ModelConfig quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 65d58df6d8cd..8654c60ff5bd 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -332,7 +332,11 @@ def _init_visual_gen(self): def _init_llm(self, chat_template: Optional[str] = None): self.tokenizer = self.generator.tokenizer - hf_tokenizer_path = self.generator._hf_model_dir or self.tokenizer.tokenizer.name_or_path + hf_tokenizer_path = self.generator._hf_model_dir + if not hf_tokenizer_path: + hf_tokenizer_path = getattr( + self.tokenizer.tokenizer, "name_or_path", None) or getattr( + self.tokenizer, "name_or_path", None) trust_remote_code = self.generator.args.trust_remote_code try: self.processor = AutoProcessor.from_pretrained( @@ -1042,8 +1046,11 @@ async def chat_stream_generator( ] # Pass the tokenizer vocabulary size so ``logit_bias`` can be # expanded into an embedding bias tensor in the sampler. + vocab_size = getattr(self.tokenizer.tokenizer, + "vocab_size", None) or getattr( + self.tokenizer, "vocab_size", None) sampling_params = request.to_sampling_params( - vocab_size=self.tokenizer.tokenizer.vocab_size, + vocab_size=vocab_size, gather_generation_logits=self.generator.args. gather_generation_logits, reasoning_parser=self.generator.args.reasoning_parser, diff --git a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py index 9989e2821acb..fbe8a3dbf481 100644 --- a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py +++ b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py @@ -11,7 +11,7 @@ class MyError(Exception): @pytest.mark.parametrize( - "dir_name, safetensor_filenames, expected_safetensor_filenames", + "dir_name, safetensor_filenames, expected_safetensor_filenames, use_consolidated", [ ( "foo", @@ -21,6 +21,18 @@ class MyError(Exception): "consolidated.safetensors", ], ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"], + False, + ), + # If use_consolidated specified explicitly. + ( + "foo", + [ + "model-00001-of-00002.safetensors", + "model-000002-of-00002.safetensors", + "consolidated.safetensors", + ], + ["consolidated.safetensors"], + True, ), ( "foo", @@ -29,12 +41,14 @@ class MyError(Exception): "foo-consolidated.safetensors", ], [f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)], + False, ), # If there is only a consolidated safetensor, that one should still be used. ( "foo", ["consolidated.safetensors"], ["consolidated.safetensors"], + False, ), # If the directory contains "consolidated" in its name, but its contents are sharded tensors. ( @@ -45,6 +59,7 @@ class MyError(Exception): "consolidated.safetensors", ], ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"], + False, ), ], ) @@ -53,6 +68,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists( dir_name: str, safetensor_filenames: list[str], expected_safetensor_filenames: list[str], + use_consolidated: bool, ): checkpoint_dir = tmp_path / dir_name checkpoint_dir.mkdir() @@ -70,7 +86,9 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists( mock.patch.object(loader, "prefetch_files") as prefetch_files, pytest.raises(MyError), ): - loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping()) + loader.load_weights( + checkpoint_dir=str(checkpoint_dir), mapping=Mapping(), use_consolidated=use_consolidated + ) prefetch_files.assert_called_once() prefetched_files = prefetch_files.call_args[0][0] diff --git a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py new file mode 100644 index 000000000000..e9627ca8e5ce --- /dev/null +++ b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralWeightMapper + + +@pytest.fixture +def expected_renames(): + return { + # Top-level embeddings and output projections + "tok_embeddings.weight": "model.embed_tokens.weight", + "output.weight": "lm_head.weight", + "norm.weight": "model.norm.weight", + # Per-layer attention projection weights (pixtral_mapping + mistral_llm_mapping) + "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight", + "layers.0.attention.wk.weight": "model.layers.0.self_attn.k_proj.weight", + "layers.0.attention.wv.weight": "model.layers.0.self_attn.v_proj.weight", + "layers.0.attention.wo.weight": "model.layers.0.self_attn.o_proj.weight", + # Per-layer MLP weights + "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight", + "layers.0.feed_forward.w2.weight": "model.layers.0.mlp.down_proj.weight", + "layers.0.feed_forward.w3.weight": "model.layers.0.mlp.up_proj.weight", + # Layernorms + "layers.0.attention_norm.weight": "model.layers.0.input_layernorm.weight", + "layers.0.ffn_norm.weight": "model.layers.0.post_attention_layernorm.weight", + # Quantization scales: compound key must win over individual token + "layers.0.attention.kv_fake_quantizer.qscale_act": "model.layers.0.self_attn.kv_scale", + "layers.0.attention.qscale_act": "model.layers.0.self_attn.input_scale", + # Unknown keys must pass through unchanged + "some.unknown.tensor": "some.unknown.tensor", + } + + +def test_rename_by_params_map(expected_renames): + mapper = MistralWeightMapper() + dummy = torch.tensor(0.0) + input_weights = {k: dummy for k in expected_renames} + + result = mapper.rename_by_params_map(mapper.mistral_llm_mapping, input_weights) + + mismatches = {k: v for k, v in expected_renames.items() if v not in result} + assert not mismatches, ( + "Keys not renamed as expected (input -> expected):\n" + + "\n".join(f" {k!r} -> {v!r}" for k, v in mismatches.items()) + + f"\nActual keys: {sorted(result.keys())}" + ) + assert type(result) is dict