Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ class HfWeightLoader(BaseWeightLoader):
Loads weights from SafeTensors/bin/pth files.
"""

def load_weights(self, checkpoint_dir: str,
mapping: Mapping) -> dict[str, Any]:
def load_weights(self,
checkpoint_dir: str,
mapping: Mapping,
use_consolidated: bool = False) -> dict[str, Any]:
weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
# Some model checkpoint directories contain not only the sharded safetensors, but one
# consolidated tensor. In the presence of both, we favor the former, as there really is no need
# consolidated tensor. In the presence of both, we favor the former unless specified explicitly, as there really is no need
# to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case.
filtered_weight_files = [
x for x in weight_files if "consolidated" not in os.path.split(x)[1]
x for x in weight_files
if ("consolidated" in os.path.split(x)[1]) == use_consolidated
]
if len(filtered_weight_files) > 0:
weight_files = filtered_weight_files
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def inverse_nvfp4_global_scales(self, weights):
weights[key] = 1.0 / weights[key]

def load_weights(self, checkpoint_dir: str, **kwargs):
weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs)
# Mistral native weight mapping is different from HF and stored in the .consolidated tensor
weights = super().weight_loader.load_weights(
checkpoint_dir, use_consolidated=True, **kwargs
)
weights = self.preprocess_weights(weights)
self.broadcast_per_tensor_scales(weights)
# The definition of global_scale is different in Mistral, need to inverse the scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,5 +348,6 @@ def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:

model_config.pretrained_config.gate_cls = Mistral3Gate
model_config.pretrained_config.input_processor_type = "mistral_large_3"
model_config.pretrained_config.model_type = "mistral_large_3"
model_config._frozen = True
return model_config
45 changes: 17 additions & 28 deletions tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from torch import nn

from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper

Expand All @@ -10,8 +8,6 @@ class MistralWeightMapper(HfWeightMapper):
def __init__(self):
super().__init__()

self._callbacks.append(self._permute_qk)

self.pixtral_mapping = {
"wq": "q_proj",
"wk": "k_proj",
Expand All @@ -31,8 +27,8 @@ def __init__(self):
"qscale_weight": "weight_scale_inv",
"kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"k_fake_quantizer.qscale_act": "attn.k_scale",
"v_fake_quantizer.qscale_act": "attn.v_scale",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"ffn_norm": "post_attention_layernorm",
Expand Down Expand Up @@ -78,38 +74,31 @@ def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dic
return ConsumableWeightsDict(renamed_weights)
return renamed_weights

def _permute_qk(self, module: nn.Module, new_name: str, weights: dict):
def permute_qk(self, weights: dict, config: dict):
# Adapted from:
# https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657

processed_weights = {}
config = self.config.pretrained_config

def permute(w, n_heads: int, attn_out: int):
attn_in = config.head_dim * n_heads

def permute(w, n_heads: int, head_dim: int, hidden_size: int):
attn_in = head_dim * n_heads
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
w.view(n_heads, attn_in // n_heads // 2, 2, hidden_size)
.transpose(1, 2)
.reshape(attn_in, attn_out)
.reshape(attn_in, hidden_size)
)

# rotary embeds should be sliced
# If using quantized model in mistral format,
# quantization scales (qscale_weight) also need to be sliced

if new_name in ["k_proj", "q_proj"]:
n_heads = (
config.num_key_value_heads if new_name == "k_proj" else config.num_attention_heads
)

processed_weights["weight"] = permute(weights["weight"], n_heads, config.hidden_size)

if "qscale_weight" in weights and weights["qscale_weight"].numel() > 1:
processed_weights["qscale_weight"] = permute(weights["qscale_weight"], n_heads, 1)

return processed_weights

for name in weights.keys():
# TODO: add scales if dequant is necessary
if ".wq.weight" in name:
weights[name] = permute(
weights[name], config.num_attention_heads, config.head_dim, config.hidden_size
)
elif ".wk.weight" in name:
weights[name] = permute(
weights[name], config.num_key_value_heads, config.head_dim, config.hidden_size
)
return weights


Expand Down
23 changes: 15 additions & 8 deletions tensorrt_llm/_torch/models/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,7 @@ def __init__(
# When the input only contains text, we use the text processor to process the input.
self._processor = MistralCommonImageProcessor(
tokenizer=self._tokenizer, dtype=self.dtype)
self.text_processor = AutoProcessor.from_pretrained(
model_path,
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)
self.text_processor = self._processor
else:
# For other mistral models, we use the AutoProcessor to process the input.
self._processor = AutoProcessor.from_pretrained(
Expand Down Expand Up @@ -622,19 +619,29 @@ def _post_config(self):

def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs):
vit_params_map = None
if weight_mapper:
if isinstance(weight_mapper, MistralWeightMapper):
vit_params_map = weight_mapper.pixtral_mapping
if weight_mapper and isinstance(weight_mapper, MistralWeightMapper):
vit_params_map = weight_mapper.pixtral_mapping

llm_weights = filter_weights(weights=weights, prefix="language_model")
logger.debug(f"Loading weights for {type(self.llm)}")
self.llm.load_weights(llm_weights)
if weight_mapper and type(weight_mapper) is MistralWeightMapper:
weight_mapper.permute_qk(weights=llm_weights,
config=self.llm.config)
self.llm.load_weights(llm_weights,
weight_mapper=weight_mapper,
params_map=weight_mapper.mistral_llm_mapping)
else:
self.llm.load_weights(llm_weights)
logger.debug(f"Successfully loaded weights for {type(self.llm)}")

vit_weights = filter_weights(weights=weights, prefix="vision_tower")
logger.debug(f"Loading weights for {type(self._vision_tower)}")

if vit_params_map is not None:
# Pixtral uses num_attention_heads = num_key_value_heads
self._vision_tower.config.num_key_value_heads = self._vision_tower.config.num_attention_heads
weight_mapper.permute_qk(weights=vit_weights,
config=self._vision_tower.config)
vit_weights = weight_mapper.rename_by_params_map(
weights=vit_weights, params_map=vit_params_map)

Expand Down
13 changes: 6 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,12 @@ def load_pretrained_config(model_name_or_path: str,
model_type = config_dict.get("model_type")
architectures = config_dict.get("architectures") or []

if model_type in _CONFIG_REGISTRY:
if checkpoint_format in ("mistral", "mistral_large_3"):
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
MistralConfigLoader
model_config = MistralConfigLoader().load(
model_name_or_path).pretrained_config
elif model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
model_config = config_class.from_pretrained(model_name_or_path,
**kwargs)
Expand All @@ -274,12 +279,6 @@ def load_pretrained_config(model_name_or_path: str,
)):
model_config = transformers.Qwen3NextConfig.from_dict(
_Qwen35ConfigCompat.normalize(config_dict))
elif checkpoint_format in ("mistral", "mistral_large_3"):
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
MistralConfigLoader
model_config = getattr(
MistralConfigLoader().load(model_name_or_path).pretrained_config,
"text_config")
else:
model_config = transformers.AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code)
Expand Down
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
num_key_value_heads)

# get head dim
mla = hasattr(config, "kv_lora_rank")
mla = hasattr(config,
"kv_lora_rank") and config.kv_lora_rank is not None
if mla:
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
Expand Down Expand Up @@ -2553,7 +2554,8 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
num_key_value_heads)

# get head dim
mla = hasattr(config, "kv_lora_rank")
mla = hasattr(config,
"kv_lora_rank") and config.kv_lora_rank is not None
if mla:
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
Expand Down
12 changes: 7 additions & 5 deletions tensorrt_llm/llmapi/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,13 @@ def _update_from_hf_quant_config(self) -> bool:

if hf_quant_config is not None:
# DeepSeek V3 FP8 ckpt
if hf_quant_config.get(
"quant_method") == "fp8" and hf_quant_config.get(
"weight_block_size"):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
if hf_quant_config.get("quant_method") == "fp8":
if hf_quant_config.get("weight_block_size") is not None:
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
else:
# Ministral 3 static quant
quant_config.quant_algo = QuantAlgo.FP8
elif hf_quant_config.get("quant_method") == "mxfp4":
from .._torch.model_config import ModelConfig
quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
Expand Down
11 changes: 9 additions & 2 deletions tensorrt_llm/serve/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,11 @@ def _init_visual_gen(self):

def _init_llm(self, chat_template: Optional[str] = None):
self.tokenizer = self.generator.tokenizer
hf_tokenizer_path = self.generator._hf_model_dir or self.tokenizer.tokenizer.name_or_path
hf_tokenizer_path = self.generator._hf_model_dir
if not hf_tokenizer_path:
hf_tokenizer_path = getattr(
self.tokenizer.tokenizer, "name_or_path", None) or getattr(
self.tokenizer, "name_or_path", None)
trust_remote_code = self.generator.args.trust_remote_code
try:
self.processor = AutoProcessor.from_pretrained(
Expand Down Expand Up @@ -1042,8 +1046,11 @@ async def chat_stream_generator(
]
# Pass the tokenizer vocabulary size so ``logit_bias`` can be
# expanded into an embedding bias tensor in the sampler.
vocab_size = getattr(self.tokenizer.tokenizer,
"vocab_size", None) or getattr(
self.tokenizer, "vocab_size", None)
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size,
vocab_size=vocab_size,
gather_generation_logits=self.generator.args.
gather_generation_logits,
reasoning_parser=self.generator.args.reasoning_parser,
Expand Down
22 changes: 20 additions & 2 deletions tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MyError(Exception):


@pytest.mark.parametrize(
"dir_name, safetensor_filenames, expected_safetensor_filenames",
"dir_name, safetensor_filenames, expected_safetensor_filenames, use_consolidated",
[
(
"foo",
Expand All @@ -21,6 +21,18 @@ class MyError(Exception):
"consolidated.safetensors",
],
["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
False,
),
# If use_consolidated specified explicitly.
(
"foo",
[
"model-00001-of-00002.safetensors",
"model-000002-of-00002.safetensors",
"consolidated.safetensors",
],
["consolidated.safetensors"],
True,
),
(
"foo",
Expand All @@ -29,12 +41,14 @@ class MyError(Exception):
"foo-consolidated.safetensors",
],
[f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)],
False,
),
# If there is only a consolidated safetensor, that one should still be used.
(
"foo",
["consolidated.safetensors"],
["consolidated.safetensors"],
False,
),
# If the directory contains "consolidated" in its name, but its contents are sharded tensors.
(
Expand All @@ -45,6 +59,7 @@ class MyError(Exception):
"consolidated.safetensors",
],
["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
False,
),
],
)
Expand All @@ -53,6 +68,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
dir_name: str,
safetensor_filenames: list[str],
expected_safetensor_filenames: list[str],
use_consolidated: bool,
):
checkpoint_dir = tmp_path / dir_name
checkpoint_dir.mkdir()
Expand All @@ -70,7 +86,9 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
mock.patch.object(loader, "prefetch_files") as prefetch_files,
pytest.raises(MyError),
):
loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping())
loader.load_weights(
checkpoint_dir=str(checkpoint_dir), mapping=Mapping(), use_consolidated=use_consolidated
)

prefetch_files.assert_called_once()
prefetched_files = prefetch_files.call_args[0][0]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralWeightMapper
Comment thread
evezhier marked this conversation as resolved.


@pytest.fixture
def expected_renames():
    """Map Mistral-native checkpoint weight names to their expected HF-style names.

    Covers every rename category handled by ``MistralWeightMapper``:
    embeddings/head, attention projections, MLP weights, layernorms,
    quantization scales, and unknown-key passthrough.
    """
    embeddings_and_head = {
        # Top-level embeddings and output projections
        "tok_embeddings.weight": "model.embed_tokens.weight",
        "output.weight": "lm_head.weight",
        "norm.weight": "model.norm.weight",
    }
    attention_projections = {
        # Per-layer attention projection weights (pixtral_mapping + mistral_llm_mapping)
        "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
        "layers.0.attention.wk.weight": "model.layers.0.self_attn.k_proj.weight",
        "layers.0.attention.wv.weight": "model.layers.0.self_attn.v_proj.weight",
        "layers.0.attention.wo.weight": "model.layers.0.self_attn.o_proj.weight",
    }
    mlp_weights = {
        # Per-layer MLP weights
        "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight",
        "layers.0.feed_forward.w2.weight": "model.layers.0.mlp.down_proj.weight",
        "layers.0.feed_forward.w3.weight": "model.layers.0.mlp.up_proj.weight",
    }
    layernorms = {
        # Layernorms
        "layers.0.attention_norm.weight": "model.layers.0.input_layernorm.weight",
        "layers.0.ffn_norm.weight": "model.layers.0.post_attention_layernorm.weight",
    }
    quant_scales = {
        # Quantization scales: compound key must win over individual token
        "layers.0.attention.kv_fake_quantizer.qscale_act": "model.layers.0.self_attn.kv_scale",
        "layers.0.attention.qscale_act": "model.layers.0.self_attn.input_scale",
    }
    passthrough = {
        # Unknown keys must pass through unchanged
        "some.unknown.tensor": "some.unknown.tensor",
    }
    return {
        **embeddings_and_head,
        **attention_projections,
        **mlp_weights,
        **layernorms,
        **quant_scales,
        **passthrough,
    }


def test_rename_by_params_map(expected_renames):
    """Every Mistral-native key should be renamed to its HF-style counterpart.

    Builds a weights dict with placeholder tensors for each expected input
    name, runs it through ``rename_by_params_map`` with the LLM mapping, and
    asserts that each expected target name is present in the result and that
    a plain ``dict`` (not a subclass) is returned.
    """
    mapper = MistralWeightMapper()
    placeholder = torch.tensor(0.0)
    input_weights = {source_name: placeholder for source_name in expected_renames}

    renamed = mapper.rename_by_params_map(mapper.mistral_llm_mapping, input_weights)

    missing = {
        source: target
        for source, target in expected_renames.items()
        if target not in renamed
    }
    failure_details = "\n".join(
        f"  {source!r} -> {target!r}" for source, target in missing.items()
    )
    assert not missing, (
        "Keys not renamed as expected (input -> expected):\n"
        + failure_details
        + f"\nActual keys: {sorted(renamed.keys())}"
    )
    assert type(renamed) is dict
Loading