Skip to content

Commit c66b9a2

Browse files
committed
tests
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
1 parent dad332e commit c66b9a2

2 files changed

Lines changed: 67 additions & 2 deletions

File tree

tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class MyError(Exception):
 
 
 @pytest.mark.parametrize(
-    "dir_name, safetensor_filenames, expected_safetensor_filenames",
+    "dir_name, safetensor_filenames, expected_safetensor_filenames, use_consolidated",
     [
         (
             "foo",
@@ -21,6 +21,18 @@ class MyError(Exception):
                 "consolidated.safetensors",
             ],
             ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
+            False,
+        ),
+        # If use_consolidated specified explicitly.
+        (
+            "foo",
+            [
+                "model-00001-of-00002.safetensors",
+                "model-000002-of-00002.safetensors",
+                "consolidated.safetensors",
+            ],
+            ["consolidated.safetensors"],
+            True,
         ),
         (
             "foo",
@@ -29,12 +41,14 @@ class MyError(Exception):
                 "foo-consolidated.safetensors",
             ],
             [f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)],
+            False,
         ),
         # If there is only a consolidated safetensor, that one should still be used.
         (
             "foo",
             ["consolidated.safetensors"],
             ["consolidated.safetensors"],
+            False,
         ),
         # If the directory contains "consolidated" in its name, but its contents are sharded tensors.
         (
@@ -45,6 +59,7 @@ class MyError(Exception):
                 "consolidated.safetensors",
             ],
             ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
+            False,
         ),
     ],
 )
@@ -53,6 +68,7 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
     dir_name: str,
     safetensor_filenames: list[str],
    expected_safetensor_filenames: list[str],
+    use_consolidated: bool,
 ):
     checkpoint_dir = tmp_path / dir_name
     checkpoint_dir.mkdir()
@@ -70,7 +86,9 @@ def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
         mock.patch.object(loader, "prefetch_files") as prefetch_files,
         pytest.raises(MyError),
     ):
-        loader.load_weights(checkpoint_dir=str(checkpoint_dir), mapping=Mapping())
+        loader.load_weights(
+            checkpoint_dir=str(checkpoint_dir), mapping=Mapping(), use_consolidated=use_consolidated
+        )
 
     prefetch_files.assert_called_once()
     prefetched_files = prefetch_files.call_args[0][0]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
+import pytest
+import torch
+
+from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralWeightMapper
+
+
+@pytest.fixture
+def expected_renames():
+    return {
+        # Top-level embeddings and output projections
+        "tok_embeddings.weight": "model.embed_tokens.weight",
+        "output.weight": "lm_head.weight",
+        "norm.weight": "model.norm.weight",
+        # Per-layer attention projection weights (pixtral_mapping + mistral_llm_mapping)
+        "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
+        "layers.0.attention.wk.weight": "model.layers.0.self_attn.k_proj.weight",
+        "layers.0.attention.wv.weight": "model.layers.0.self_attn.v_proj.weight",
+        "layers.0.attention.wo.weight": "model.layers.0.self_attn.o_proj.weight",
+        # Per-layer MLP weights
+        "layers.0.feed_forward.w1.weight": "model.layers.0.mlp.gate_proj.weight",
+        "layers.0.feed_forward.w2.weight": "model.layers.0.mlp.down_proj.weight",
+        "layers.0.feed_forward.w3.weight": "model.layers.0.mlp.up_proj.weight",
+        # Layernorms
+        "layers.0.attention_norm.weight": "model.layers.0.input_layernorm.weight",
+        "layers.0.ffn_norm.weight": "model.layers.0.post_attention_layernorm.weight",
+        # Quantization scales: compound key must win over individual token
+        "layers.0.attention.kv_fake_quantizer.qscale_act": "model.layers.0.self_attn.kv_scale",
+        "layers.0.attention.qscale_act": "model.layers.0.self_attn.input_scale",
+        # Unknown keys must pass through unchanged
+        "some.unknown.tensor": "some.unknown.tensor",
+    }
+
+
+def test_rename_by_params_map(expected_renames):
+    mapper = MistralWeightMapper()
+    dummy = torch.tensor(0.0)
+    input_weights = {k: dummy for k in expected_renames}
+
+    result = mapper.rename_by_params_map(mapper.mistral_llm_mapping, input_weights)
+
+    mismatches = {k: v for k, v in expected_renames.items() if v not in result}
+    assert not mismatches, (
+        f"Keys not renamed as expected (input -> expected):\n"
+        + "\n".join(f"  {k!r} -> {v!r}" for k, v in mismatches.items())
+        + f"\nActual keys: {sorted(result.keys())}"
+    )
+    assert type(result) is dict

0 commit comments

Comments
 (0)