From a311bef42051797b5652434fbe7718393b080dfa Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 13 Mar 2026 01:10:57 -0700 Subject: [PATCH 01/11] config load for HF format Signed-off-by: Olya Kozlova --- tensorrt_llm/_torch/pyexecutor/config_utils.py | 11 +++++++++++ tensorrt_llm/llmapi/llm_utils.py | 12 +++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 9c3b4c37560f..54fadbdbc995 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -280,6 +280,17 @@ def load_pretrained_config(model_name_or_path: str, model_config = getattr( MistralConfigLoader().load(model_name_or_path).pretrained_config, "text_config") + + elif model_type == "mistral3" and "layer_types" in config_dict: + # TODO: update this for transformers v5.0 + config_class = "MinistralConfig" + model_config = config_class.from_pretrained(model_name_or_path, + **kwargs) + + elif model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + model_config = config_class.from_pretrained(model_name_or_path, + **kwargs) else: model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 5a6cb517caca..50fffbdc5e80 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -430,11 +430,13 @@ def _update_from_hf_quant_config(self) -> bool: if hf_quant_config is not None: # DeepSeek V3 FP8 ckpt - if hf_quant_config.get( - "quant_method") == "fp8" and hf_quant_config.get( - "weight_block_size"): - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES - quant_config.exclude_modules = ["*eh_proj"] + if hf_quant_config.get("quant_method") == "fp8": + if hf_quant_config.get("weight_block_size"): + quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + quant_config.exclude_modules = ["*eh_proj"] + else: + # Ministral 3 static quant + quant_config.quant_algo = QuantAlgo.FP8 elif hf_quant_config.get("quant_method") == "mxfp4": from .._torch.model_config import ModelConfig quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( From a523f0eab45e7b91cffa4a438aff4f151e574eb3 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 3 Apr 2026 07:14:04 -0700 Subject: [PATCH 02/11] weight loading fix Signed-off-by: Olya Kozlova --- .../_torch/models/checkpoints/hf/weight_loader.py | 11 +++++++---- .../models/checkpoints/mistral/checkpoint_loader.py | 5 ++++- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 6 ++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index f47e77a81661..3d584fcb0cff 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -26,14 +26,17 @@ class HfWeightLoader(BaseWeightLoader): Loads weights from SafeTensors/bin/pth files. """ - def load_weights(self, checkpoint_dir: str, - mapping: Mapping) -> dict[str, Any]: + def load_weights(self, + checkpoint_dir: str, + mapping: Mapping, + use_consolidated: bool = False) -> dict[str, Any]: weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors") # Some model checkpoint directories contain not only the sharded safetensors, but one - # consolidated tensor. In the presence of both, we favor the former, as there really is no need + # consolidated tensor. In the presence of both, we favor the former unless specified explicitly, as there really is no need # to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case. filtered_weight_files = [ - x for x in weight_files if "consolidated" not in os.path.split(x)[1] + x for x in weight_files + if ("consolidated" in os.path.split(x)[1]) == use_consolidated ] if len(filtered_weight_files) > 0: weight_files = filtered_weight_files diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py index 116fce261ac9..e43e1142abb0 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py @@ -64,7 +64,10 @@ def inverse_nvfp4_global_scales(self, weights): weights[key] = 1.0 / weights[key] def load_weights(self, checkpoint_dir: str, **kwargs): - weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs) + # Mistral native weight mapping is different from HF and stored in the .consolidated tensor + weights = super().weight_loader.load_weights( + checkpoint_dir, use_consolidated=True, **kwargs + ) weights = self.preprocess_weights(weights) self.broadcast_per_tensor_scales(weights) # The definition of global_scale is different in Mistral, need to inverse the scale diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 1b97f72f2a38..a3b484bdc9c2 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -879,7 +879,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython, num_key_value_heads) # get head dim - mla = hasattr(config, "kv_lora_rank") + mla = hasattr(config, + "kv_lora_rank") and config.kv_lora_rank is not None if mla: head_dim = config.kv_lora_rank + config.qk_rope_head_dim kv_factor = 1 @@ -2553,7 +2554,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython, num_key_value_heads) # get head dim - mla = hasattr(config, "kv_lora_rank") + mla = hasattr(config, + "kv_lora_rank") and config.kv_lora_rank is not None if mla: head_dim = config.kv_lora_rank + config.qk_rope_head_dim kv_factor = 1 From 236c966f2ae3ebfdedc608f7c6cd9e6f4e5a219a Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 3 Apr 2026 07:33:28 -0700 Subject: [PATCH 03/11] move permute from callbacks to weight loading Signed-off-by: Olya Kozlova --- .../checkpoints/mistral/weight_mapper.py | 49 +++++++++---------- .../_torch/models/modeling_mistral.py | 19 ++++--- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py index 035b539639fc..edee31fce2a6 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py @@ -10,8 +10,6 @@ class MistralWeightMapper(HfWeightMapper): def __init__(self): super().__init__() - self._callbacks.append(self._permute_qk) - self.pixtral_mapping = { "wq": "q_proj", "wk": "k_proj", @@ -31,8 +29,8 @@ def __init__(self): "qscale_weight": "weight_scale_inv", "kv_fake_quantizer.qscale_act": "kv_scale", "q_fake_quantizer.qscale_act": "attn.q_scale", - "k_fake_quantizer.qscale_act": "k_scale", - "v_fake_quantizer.qscale_act": "v_scale", + "k_fake_quantizer.qscale_act": "attn.k_scale", + "v_fake_quantizer.qscale_act": "attn.v_scale", "attention_norm": "input_layernorm", "feed_forward": "mlp", "ffn_norm": "post_attention_layernorm", @@ -78,38 +76,37 @@ def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dic return ConsumableWeightsDict(renamed_weights) return renamed_weights - def _permute_qk(self, module: nn.Module, new_name: str, weights: dict): + def permute_qk(self, weights: dict, config: dict): # Adapted from: # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657 - processed_weights = {} - config = self.config.pretrained_config - - def permute(w, n_heads: int, attn_out: int): - attn_in = config.head_dim * n_heads - + def permute(w, n_heads: int, head_dim: int, hidden_size: int): + attn_in = head_dim * n_heads return ( - w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + w.view(n_heads, attn_in // n_heads // 2, 2, hidden_size) .transpose(1, 2) - .reshape(attn_in, attn_out) + .reshape(attn_in, hidden_size) ) # rotary embeds should be sliced # If using quantized model in mistral format, # quantization scales (qscale_weight) also need to be sliced - - if new_name in ["k_proj", "q_proj"]: - n_heads = ( - config.num_key_value_heads if new_name == "k_proj" else config.num_attention_heads - ) - - processed_weights["weight"] = permute(weights["weight"], n_heads, config.hidden_size) - - if "qscale_weight" in weights and weights["qscale_weight"].numel() > 1: - processed_weights["qscale_weight"] = permute(weights["qscale_weight"], n_heads, 1) - - return processed_weights - + for name in weights.keys(): + # TODO: add scales if dequant is necessary + if ".wq.weight" in name: + weights[name] = permute( + weights[name], + config.num_attention_heads, + config.head_dim, + config.hidden_size + ) + elif ".wk.weight" in name: + weights[name] = permute( + weights[name], + config.num_key_value_heads, + config.head_dim, + config.hidden_size + ) return weights diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 573318b41130..9f4f9b7b5bce 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -2,6 +2,7 @@ import dataclasses from typing import Any, Dict, List, Tuple +import math import torch import torchvision from mistral_common.tokens.tokenizers.multimodal import ImageEncoder @@ -370,16 +371,13 @@ def __init__( use_fast=self.use_fast, trust_remote_code=trust_remote_code) self._model_path = model_path - if model_type == "mistral_large_3": + if model_type in ("mistral_large_3", "mistral3"): # For mistral large 3, we add chat template in the model forward, and the # MistralCommonImageProcessor is used to process the input when both text and images are provided. # When the input only contains text, we use the text processor to process the input. self._processor = MistralCommonImageProcessor( tokenizer=self._tokenizer, dtype=self.dtype) - self.text_processor = AutoProcessor.from_pretrained( - model_path, - use_fast=self.use_fast, - trust_remote_code=trust_remote_code) + self.text_processor = self._processor else: # For other mistral models, we use the AutoProcessor to process the input. self._processor = AutoProcessor.from_pretrained( @@ -628,13 +626,22 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): llm_weights = filter_weights(weights=weights, prefix="language_model") logger.debug(f"Loading weights for {type(self.llm)}") - self.llm.load_weights(llm_weights) + if weight_mapper: + weight_mapper.permute_qk(weights=llm_weights, config=self.llm.config) + self.llm.load_weights(llm_weights, + weight_mapper=weight_mapper, + params_map=weight_mapper.mistral_llm_mapping) + else: + self.llm.load_weights(llm_weights) logger.debug(f"Successfully loaded weights for {type(self.llm)}") vit_weights = filter_weights(weights=weights, prefix="vision_tower") logger.debug(f"Loading weights for {type(self._vision_tower)}") if vit_params_map is not None: + # Pixtral uses num_attention_heads = num_key_value_heads + self._vision_tower.config.num_key_value_heads = self._vision_tower.config.num_attention_heads + weight_mapper.permute_qk(weights=vit_weights, config=self._vision_tower.config) vit_weights = weight_mapper.rename_by_params_map( weights=vit_weights, params_map=vit_params_map) From 6d49bb78f75fecebf5fff666b67bb595a3f51167 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 3 Apr 2026 07:33:28 -0700 Subject: [PATCH 04/11] move permute from callbacks to weight loading Signed-off-by: Olya Kozlova --- tensorrt_llm/_torch/models/modeling_mistral.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 9f4f9b7b5bce..76e047edcb53 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -2,7 +2,6 @@ import dataclasses from typing import Any, Dict, List, Tuple -import math import torch import torchvision from mistral_common.tokens.tokenizers.multimodal import ImageEncoder @@ -371,7 +370,7 @@ def __init__( use_fast=self.use_fast, trust_remote_code=trust_remote_code) self._model_path = model_path - if model_type in ("mistral_large_3", "mistral3"): + if model_type == "mistral_large_3": # For mistral large 3, we add chat template in the model forward, and the # MistralCommonImageProcessor is used to process the input when both text and images are provided. # When the input only contains text, we use the text processor to process the input. @@ -507,7 +506,7 @@ def __init__( def load_tokenizer(model_path: str, config: PretrainedConfig, tokenizer: AutoTokenizer | None = None): - if getattr(config, "input_processor_type", None) == "mistral_large_3": + if getattr(config, "input_processor_type", None) in ("mistral_large_3"): try: return MistralTokenizer.from_pretrained(model_path) From 3c7f5e53a3338080e4d82f607382c03b3531a294 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 3 Apr 2026 10:03:09 -0700 Subject: [PATCH 05/11] cleanup Signed-off-by: Olya Kozlova --- .../checkpoints/mistral/weight_mapper.py | 14 +++----------- tensorrt_llm/_torch/models/modeling_mistral.py | 18 ++++++++++-------- tensorrt_llm/_torch/pyexecutor/config_utils.py | 2 +- tensorrt_llm/llmapi/llm_utils.py | 2 +- 4 files changed, 15 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py index edee31fce2a6..dd6e0332b849 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py @@ -1,5 +1,3 @@ -from torch import nn - from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper from tensorrt_llm._torch.models.modeling_utils import register_mapper @@ -92,20 +90,14 @@ def permute(w, n_heads: int, head_dim: int, hidden_size: int): # If using quantized model in mistral format, # quantization scales (qscale_weight) also need to be sliced for name in weights.keys(): - # TODO: add scales if dequant is necessary + # TODO: add scales if dequant is necessary if ".wq.weight" in name: weights[name] = permute( - weights[name], - config.num_attention_heads, - config.head_dim, - config.hidden_size + weights[name], config.num_attention_heads, config.head_dim, config.hidden_size ) elif ".wk.weight" in name: weights[name] = permute( - weights[name], - config.num_key_value_heads, - config.head_dim, - config.hidden_size + weights[name], config.num_key_value_heads, config.head_dim, config.hidden_size ) return weights diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 76e047edcb53..d4df2dd317ca 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -625,13 +625,14 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): llm_weights = filter_weights(weights=weights, prefix="language_model") logger.debug(f"Loading weights for {type(self.llm)}") - if weight_mapper: - weight_mapper.permute_qk(weights=llm_weights, config=self.llm.config) - self.llm.load_weights(llm_weights, - weight_mapper=weight_mapper, - params_map=weight_mapper.mistral_llm_mapping) - else: - self.llm.load_weights(llm_weights) + if weight_mapper: + weight_mapper.permute_qk(weights=llm_weights, + config=self.llm.config) + self.llm.load_weights(llm_weights, + weight_mapper=weight_mapper, + params_map=weight_mapper.mistral_llm_mapping) + else: + self.llm.load_weights(llm_weights) logger.debug(f"Successfully loaded weights for {type(self.llm)}") vit_weights = filter_weights(weights=weights, prefix="vision_tower") @@ -640,7 +641,8 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): if vit_params_map is not None: # Pixtral uses num_attention_heads = num_key_value_heads self._vision_tower.config.num_key_value_heads = self._vision_tower.config.num_attention_heads - weight_mapper.permute_qk(weights=vit_weights, config=self._vision_tower.config) + weight_mapper.permute_qk(weights=vit_weights, + config=self._vision_tower.config) vit_weights = weight_mapper.rename_by_params_map( weights=vit_weights, params_map=vit_params_map) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 54fadbdbc995..20991585bd82 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -283,7 +283,7 @@ def load_pretrained_config(model_name_or_path: str, elif model_type == "mistral3" and "layer_types" in config_dict: # TODO: update this for transformers v5.0 - config_class = "MinistralConfig" + config_class = "MinistralConfig" model_config = config_class.from_pretrained(model_name_or_path, **kwargs) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 50fffbdc5e80..4dbc5fc8c9c0 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -430,7 +430,7 @@ def _update_from_hf_quant_config(self) -> bool: if hf_quant_config is not None: # DeepSeek V3 FP8 ckpt - if hf_quant_config.get("quant_method") == "fp8": + if hf_quant_config.get("quant_method") == "fp8": if hf_quant_config.get("weight_block_size"): quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES quant_config.exclude_modules = ["*eh_proj"] From c18be323da7da00d77c1cfb4cdafd5daff151163 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Thu, 9 Apr 2026 11:01:39 -0700 Subject: [PATCH 06/11] tests Signed-off-by: Olya Kozlova --- .../checkpoints/hf/test_weight_loader.py | 22 ++++++++- .../checkpoints/mistral/test_weight_mapper.py | 47 +++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py diff --git a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py index 9989e2821acb..fbe8a3dbf481 100644 --- a/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py +++ b/tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py @@ -11,7 +11,7 @@ class MyError(Exception): @pytest.mark.parametrize( - "dir_name, safetensor_filenames, expected_safetensor_filenames", + "dir_name, safetensor_filenames, expected_safetensor_filenames, use_consolidated", [ ( "foo", @@ -21,6 +21,18 @@ class MyError(Exception): "consolidated.safetensors", ], ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"], + False, + ), + # If use_consolidated specified explicitly. + ( + "foo", + [ + "model-00001-of-00002.safetensors", + "model-000002-of-00002.safetensors", + "consolidated.safetensors", + ], + ["consolidated.safetensors"], + True, ), ( "foo", @@ -29,12 +41,14 @@ class MyError(Exception): "foo-consolidated.safetensors", ], [f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)], + False, ), # If there is only a consolidated safetensor, that one should still be used. ( "foo", ["consolidated.safetensors"], ["consolidated.safetensors"], + False, ), # If the directory contains "consolidated" in its name, but its contents are sharded tensors. ( @@ -45,6 +59,7 @@ class MyError(Exception): "consolidated.safetensors", ], ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"], + False, ), ], ) @@ -53,6 +68,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists( dir_name: str, safetensor_filenames: list[str], expected_safetensor_filenames: list[str], + use_consolidated: bool, ): checkpoint_dir = tmp_path / dir_name checkpoint_dir.mkdir() @@ -70,7 +86,9 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists( mock.patch.object(loader, "prefetch_files") as prefetch_files, pytest.raises(MyError), ): - loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping()) + loader.load_weights( + checkpoint_dir=str(checkpoint_dir), mapping=Mapping(), use_consolidated=use_consolidated + ) prefetch_files.assert_called_once() prefetched_files = prefetch_files.call_args[0][0] diff --git a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py new file mode 100644 index 000000000000..37cd4cc4b981 --- /dev/null +++ b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralWeightMapper + + +@pytest.fixture +def expected_renames(): + return { + # Top-level embeddings and output projections + "tok_embeddings.weight": "model.embed_tokens.weight", + "output.weight": "lm_head.weight", + "norm.weight": "model.norm.weight", + # Per-layer attention projection weights (pixtral_mapping + mistral_llm_mapping) + "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight", + "layers.0.attention.wk.weight": "model.layers.0.self_attn.k_proj.weight", + "layers.0.attention.wv.weight": "model.layers.0.self_attn.v_proj.weight", + "layers.0.attention.wo.weight": "model.layers.0.self_attn.o_proj.weight", + # Per-layer MLP weights + "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight", + "layers.0.feed_forward.w2.weight": "model.layers.0.mlp.down_proj.weight", + "layers.0.feed_forward.w3.weight": "model.layers.0.mlp.up_proj.weight", + # Layernorms + "layers.0.attention_norm.weight": "model.layers.0.input_layernorm.weight", + "layers.0.ffn_norm.weight": "model.layers.0.post_attention_layernorm.weight", + # Quantization scales: compound key must win over individual token + "layers.0.attention.kv_fake_quantizer.qscale_act": "model.layers.0.self_attn.kv_scale", + "layers.0.attention.qscale_act": "model.layers.0.self_attn.input_scale", + # Unknown keys must pass through unchanged + "some.unknown.tensor": "some.unknown.tensor", + } + + +def test_rename_by_params_map(expected_renames): + mapper = MistralWeightMapper() + dummy = torch.tensor(0.0) + input_weights = {k: dummy for k in expected_renames} + + result = mapper.rename_by_params_map(mapper.mistral_llm_mapping, input_weights) + + mismatches = {k: v for k, v in expected_renames.items() if v not in result} + assert not mismatches, ( + f"Keys not renamed as expected (input -> expected):\n" + + "\n".join(f" {k!r} -> {v!r}" for k, v in mismatches.items()) + + f"\nActual keys: {sorted(result.keys())}" + ) + assert type(result) is dict From 5ddd81136de8066aee690ff209a6e10449cd2e00 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Thu, 9 Apr 2026 12:49:51 -0700 Subject: [PATCH 07/11] precommit fix Signed-off-by: Olya Kozlova --- .../_torch/models/checkpoints/mistral/test_weight_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py index 37cd4cc4b981..c634724ee69d 100644 --- a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py +++ b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py @@ -40,7 +40,7 @@ def test_rename_by_params_map(expected_renames): mismatches = {k: v for k, v in expected_renames.items() if v not in result} assert not mismatches, ( - f"Keys not renamed as expected (input -> expected):\n" + "Keys not renamed as expected (input -> expected):\n" + "\n".join(f" {k!r} -> {v!r}" for k, v in mismatches.items()) + f"\nActual keys: {sorted(result.keys())}" ) From 97803c8109608bc6df5afc6ffb8c0f148170e2c3 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 10 Apr 2026 15:28:58 -0700 Subject: [PATCH 08/11] openai server fixes Signed-off-by: Olya Kozlova --- .../checkpoints/mistral/config_loader.py | 1 + .../_torch/models/modeling_mistral.py | 2 +- .../_torch/pyexecutor/config_utils.py | 24 +++++-------------- tensorrt_llm/llmapi/llm_utils.py | 2 +- tensorrt_llm/serve/openai_server.py | 11 +++++++-- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py index dbc7abf73020..1b566a812ae0 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py @@ -348,5 +348,6 @@ def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: model_config.pretrained_config.gate_cls = Mistral3Gate model_config.pretrained_config.input_processor_type = "mistral_large_3" + model_config.pretrained_config.model_type = "mistral_large_3" model_config._frozen = True return model_config diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index d4df2dd317ca..142ee94e4a3b 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -506,7 +506,7 @@ def __init__( def load_tokenizer(model_path: str, config: PretrainedConfig, tokenizer: AutoTokenizer | None = None): - if getattr(config, "input_processor_type", None) in ("mistral_large_3"): + if getattr(config, "input_processor_type", None) == "mistral_large_3": try: return MistralTokenizer.from_pretrained(model_path) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 20991585bd82..a29dc41c174f 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -260,7 +260,12 @@ def load_pretrained_config(model_name_or_path: str, model_type = config_dict.get("model_type") architectures = config_dict.get("architectures") or [] - if model_type in _CONFIG_REGISTRY: + if checkpoint_format in ("mistral", "mistral_large_3"): + from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ + MistralConfigLoader + model_config = MistralConfigLoader().load( + model_name_or_path).pretrained_config + elif model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] model_config = config_class.from_pretrained(model_name_or_path, **kwargs) @@ -274,23 +279,6 @@ def load_pretrained_config(model_name_or_path: str, )): model_config = transformers.Qwen3NextConfig.from_dict( _Qwen35ConfigCompat.normalize(config_dict)) - elif checkpoint_format in ("mistral", "mistral_large_3"): - from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ - MistralConfigLoader - model_config = getattr( - MistralConfigLoader().load(model_name_or_path).pretrained_config, - "text_config") - - elif model_type == "mistral3" and "layer_types" in config_dict: - # TODO: update this for transformers v5.0 - config_class = "MinistralConfig" - model_config = config_class.from_pretrained(model_name_or_path, - **kwargs) - - elif model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - model_config = config_class.from_pretrained(model_name_or_path, - **kwargs) else: model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 4dbc5fc8c9c0..2ea28568f166 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -431,7 +431,7 @@ def _update_from_hf_quant_config(self) -> bool: if hf_quant_config is not None: # DeepSeek V3 FP8 ckpt if hf_quant_config.get("quant_method") == "fp8": - if hf_quant_config.get("weight_block_size"): + if hf_quant_config.get("weight_block_size") is not None: quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES quant_config.exclude_modules = ["*eh_proj"] else: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 65d58df6d8cd..8654c60ff5bd 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -332,7 +332,11 @@ def _init_visual_gen(self): def _init_llm(self, chat_template: Optional[str] = None): self.tokenizer = self.generator.tokenizer - hf_tokenizer_path = self.generator._hf_model_dir or self.tokenizer.tokenizer.name_or_path + hf_tokenizer_path = self.generator._hf_model_dir + if not hf_tokenizer_path: + hf_tokenizer_path = getattr( + self.tokenizer.tokenizer, "name_or_path", None) or getattr( + self.tokenizer, "name_or_path", None) trust_remote_code = self.generator.args.trust_remote_code try: self.processor = AutoProcessor.from_pretrained( @@ -1042,8 +1046,11 @@ async def chat_stream_generator( ] # Pass the tokenizer vocabulary size so ``logit_bias`` can be # expanded into an embedding bias tensor in the sampler. + vocab_size = getattr(self.tokenizer.tokenizer, + "vocab_size", None) or getattr( + self.tokenizer, "vocab_size", None) sampling_params = request.to_sampling_params( - vocab_size=self.tokenizer.tokenizer.vocab_size, + vocab_size=vocab_size, gather_generation_logits=self.generator.args. gather_generation_logits, reasoning_parser=self.generator.args.reasoning_parser, From 045aae17550464e0f9df04aedde894ce5dc75161 Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Fri, 10 Apr 2026 15:35:10 -0700 Subject: [PATCH 09/11] license Signed-off-by: Olya Kozlova --- .../checkpoints/mistral/test_weight_mapper.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py index c634724ee69d..e9627ca8e5ce 100644 --- a/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py +++ b/tests/unittest/_torch/models/checkpoints/mistral/test_weight_mapper.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest import torch From d89a1042f4409168baa9c8e50480e7c3653538be Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Tue, 14 Apr 2026 08:43:50 -0700 Subject: [PATCH 10/11] ignore hf as weight mapper Signed-off-by: Olya Kozlova --- tensorrt_llm/_torch/models/modeling_mistral.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 142ee94e4a3b..047c851976b6 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -619,13 +619,12 @@ def _post_config(self): def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): vit_params_map = None - if weight_mapper: - if isinstance(weight_mapper, MistralWeightMapper): - vit_params_map = weight_mapper.pixtral_mapping + if weight_mapper and isinstance(weight_mapper, MistralWeightMapper): + vit_params_map = weight_mapper.pixtral_mapping llm_weights = filter_weights(weights=weights, prefix="language_model") logger.debug(f"Loading weights for {type(self.llm)}") - if weight_mapper: + if weight_mapper and isinstance(weight_mapper, MistralWeightMapper): weight_mapper.permute_qk(weights=llm_weights, config=self.llm.config) self.llm.load_weights(llm_weights, From 3c26c2e21c3db57d8ddef429e1f670766fef2efe Mon Sep 17 00:00:00 2001 From: Olya Kozlova Date: Thu, 16 Apr 2026 12:08:06 -0700 Subject: [PATCH 11/11] ML3 should use ML3 weight mapper Signed-off-by: Olya Kozlova --- tensorrt_llm/_torch/models/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 047c851976b6..53eb1095749e 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -624,7 +624,7 @@ def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): llm_weights = filter_weights(weights=weights, prefix="language_model") logger.debug(f"Loading weights for {type(self.llm)}") - if weight_mapper and isinstance(weight_mapper, MistralWeightMapper): + if weight_mapper and type(weight_mapper) is MistralWeightMapper: weight_mapper.permute_qk(weights=llm_weights, config=self.llm.config) self.llm.load_weights(llm_weights,