Skip to content

Commit c7e7fc5

Browse files
authored
[None] [refactor] Unify compressed-tensors quant config parsing (#14468)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
1 parent 28718e5 commit c7e7fc5

6 files changed

Lines changed: 406 additions & 156 deletions

File tree

tensorrt_llm/_torch/model_config.py

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tensorrt_llm.logger import logger
2424
from tensorrt_llm.mapping import Mapping
2525
from tensorrt_llm.models.modeling_utils import QuantConfig
26+
from tensorrt_llm.models.quant_config_utils import \
27+
update_quant_config_from_compressed_tensors
2628
from tensorrt_llm.quantization.mode import QuantAlgo
2729
from tensorrt_llm.quantization.modelopt_config import (
2830
is_modelopt_quant_config, read_modelopt_quant_config,
@@ -477,78 +479,8 @@ def load_hf_quant_config(hf_quant_config, moe_backend, checkpoint_dir=None):
477479

478480
# NOTE: This is for llm-compressor's quantized checkpoints.
479481
elif hf_quant_config.get("quant_method") == "compressed-tensors":
480-
config_groups = hf_quant_config.get("config_groups")
481-
if config_groups is None:
482-
raise ValueError(
483-
f"config_groups is not set in {hf_quant_config}.")
484-
485-
weights_quant_config = config_groups["group_0"]["weights"]
486-
inputs_quant_config = config_groups["group_0"]["input_activations"]
487-
weights_quant_strategy = weights_quant_config["strategy"]
488-
inputs_quant_strategy = inputs_quant_config["strategy"]
489-
490-
if weights_quant_config["num_bits"] == 8:
491-
if weights_quant_strategy == "channel":
492-
if inputs_quant_strategy != "token":
493-
raise ValueError(
494-
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
495-
)
496-
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
497-
elif weights_quant_strategy == "block":
498-
if inputs_quant_strategy != "group":
499-
raise ValueError(
500-
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
501-
)
502-
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
503-
group_size = inputs_quant_config["group_size"]
504-
505-
# NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
506-
if group_size != 128:
507-
raise ValueError(
508-
f"Unsupported group_size: {group_size}. Supported: 128."
509-
)
510-
quant_config.group_size = group_size
511-
512-
else:
513-
raise ValueError(
514-
f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
515-
"Supported strategies: 'channel', 'block'.")
516-
elif (weights_quant_config["num_bits"] == 4
517-
and weights_quant_config.get("type") == "float"
518-
and weights_quant_strategy == "tensor_group"):
519-
# llm-compressor NVFP4: weights FP4 with FP8 per-group scales
520-
# (group_size=16), scaled by an FP32 global scale.
521-
if inputs_quant_strategy != "tensor_group":
522-
raise ValueError(
523-
f"Unsupported inputs_quant_strategy for NVFP4: {inputs_quant_strategy}."
524-
)
525-
group_size = weights_quant_config["group_size"]
526-
if group_size != 16:
527-
raise ValueError(
528-
f"Unsupported group_size: {group_size}. Supported: 16 for NVFP4."
529-
)
530-
quant_config.quant_algo = QuantAlgo.NVFP4
531-
quant_config.group_size = group_size
532-
else:
533-
raise ValueError(
534-
f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
535-
"Supported: 8 (FP8) or 4 (NVFP4).")
536-
537-
# kv_cache_scheme (llm-compressor): FP8 per-tensor KV cache.
538-
kv_cache_scheme = hf_quant_config.get("kv_cache_scheme")
539-
if kv_cache_scheme is not None:
540-
if (kv_cache_scheme.get("num_bits") == 8
541-
and kv_cache_scheme.get("type") == "float"):
542-
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
543-
else:
544-
raise ValueError(
545-
f"Unsupported kv_cache_scheme: {kv_cache_scheme}.")
546-
547-
if hf_exclude_modules is not None:
548-
quant_config.exclude_modules = list(
549-
set(hf_exclude_modules + hf_quant_config.get("ignore", [])))
550-
else:
551-
quant_config.exclude_modules = hf_quant_config.get("ignore", [])
482+
update_quant_config_from_compressed_tensors(quant_config,
483+
hf_quant_config)
552484
elif hf_quant_config.get("quant_method") == "nvfp4":
553485
quant_config.quant_algo = QuantAlgo.NVFP4
554486
group_size = hf_quant_config.get("group_size", 16)

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 4 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from ..mapping import Mapping
2727
from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
2828
from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
29+
from ..models.quant_config_utils import \
30+
update_quant_config_from_compressed_tensors
2931
from ..module import Module
3032
from ..quantization.modelopt_config import (is_modelopt_quant_config,
3133
read_modelopt_quant_config,
@@ -470,90 +472,8 @@ def _update_from_hf_quant_config(self) -> bool:
470472
]
471473
# NOTE: This is for llm-compressor's quantized checkpoints.
472474
elif hf_quant_config.get("quant_method") == "compressed-tensors":
473-
config_groups = hf_quant_config.get("config_groups")
474-
if config_groups is None:
475-
raise ValueError(
476-
f"config_groups is not set in {hf_quant_config}.")
477-
478-
weights_quant_config = config_groups["group_0"]["weights"]
479-
inputs_quant_config = config_groups["group_0"][
480-
"input_activations"]
481-
weights_quant_strategy = weights_quant_config["strategy"]
482-
inputs_quant_strategy = inputs_quant_config["strategy"]
483-
484-
if weights_quant_config["num_bits"] == 8:
485-
if weights_quant_strategy == "channel":
486-
if inputs_quant_strategy != "token":
487-
raise ValueError(
488-
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
489-
)
490-
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
491-
elif weights_quant_strategy == "block":
492-
if inputs_quant_strategy != "group":
493-
raise ValueError(
494-
f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}."
495-
)
496-
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
497-
group_size = inputs_quant_config["group_size"]
498-
499-
# NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
500-
if group_size != 128:
501-
raise ValueError(
502-
f"Unsupported group_size: {group_size}. Supported: 128."
503-
)
504-
quant_config.group_size = group_size
505-
506-
else:
507-
raise ValueError(
508-
f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
509-
"Supported strategies: 'channel', 'block'.")
510-
elif (weights_quant_config["num_bits"] == 4
511-
and weights_quant_config.get("type") == "float"
512-
and weights_quant_strategy == "tensor_group"):
513-
# llm-compressor NVFP4: weights FP4 with FP8 per-group
514-
# scales (group_size=16), scaled by an FP32 global scale.
515-
if inputs_quant_strategy != "tensor_group":
516-
raise ValueError(
517-
f"Unsupported inputs_quant_strategy for NVFP4: {inputs_quant_strategy}."
518-
)
519-
group_size = weights_quant_config["group_size"]
520-
if group_size != 16:
521-
raise ValueError(
522-
f"Unsupported group_size: {group_size}. Supported: 16 for NVFP4."
523-
)
524-
quant_config.quant_algo = QuantAlgo.NVFP4
525-
quant_config.group_size = group_size
526-
else:
527-
raise ValueError(
528-
f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
529-
"Supported: 8 (FP8) or 4 (NVFP4).")
530-
531-
# kv_cache_scheme (llm-compressor): FP8 per-tensor KV cache.
532-
kv_cache_scheme = hf_quant_config.get("kv_cache_scheme")
533-
if kv_cache_scheme is not None:
534-
if (kv_cache_scheme.get("num_bits") == 8
535-
and kv_cache_scheme.get("type") == "float"):
536-
if quant_config.kv_cache_quant_algo in (None,
537-
QuantAlgo.FP8):
538-
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
539-
else:
540-
raise ValueError(
541-
f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, "
542-
f"conflicting with FP8 KV cache from HF quant config."
543-
)
544-
else:
545-
raise ValueError(
546-
f"Unsupported kv_cache_scheme: {kv_cache_scheme}.")
547-
548-
hf_exclude_modules = hf_quant_config.get(
549-
"modules_to_not_convert", None)
550-
if hf_exclude_modules is not None:
551-
quant_config.exclude_modules = list(
552-
set(hf_exclude_modules +
553-
hf_quant_config.get("ignore", [])))
554-
else:
555-
quant_config.exclude_modules = hf_quant_config.get(
556-
"ignore", [])
475+
update_quant_config_from_compressed_tensors(
476+
quant_config, hf_quant_config)
557477
elif hf_quant_config.get("quant_method") == "nvfp4":
558478
quant_config.quant_algo = QuantAlgo.NVFP4
559479
group_size = hf_quant_config.get("group_size", 16)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any, Mapping
17+
18+
from tensorrt_llm.models.modeling_utils import QuantConfig
19+
from tensorrt_llm.quantization.mode import QuantAlgo
20+
21+
22+
def update_quant_config_from_compressed_tensors(
23+
quant_config: QuantConfig, hf_quant_config: Mapping[str, Any]
24+
) -> None:
25+
"""Mutate QuantConfig from an llm-compressor compressed-tensors config."""
26+
config_groups = hf_quant_config.get("config_groups")
27+
if config_groups is None:
28+
raise ValueError(f"config_groups is not set in {hf_quant_config}.")
29+
30+
weights_quant_config = config_groups["group_0"]["weights"]
31+
inputs_quant_config = config_groups["group_0"]["input_activations"]
32+
weights_quant_strategy = weights_quant_config["strategy"]
33+
inputs_quant_strategy = inputs_quant_config["strategy"]
34+
35+
if weights_quant_config["num_bits"] == 8:
36+
if weights_quant_strategy == "channel":
37+
if inputs_quant_strategy != "token":
38+
raise ValueError(f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}.")
39+
quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
40+
elif weights_quant_strategy == "block":
41+
if inputs_quant_strategy != "group":
42+
raise ValueError(f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}.")
43+
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
44+
group_size = inputs_quant_config["group_size"]
45+
46+
# TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES.
47+
if group_size != 128:
48+
raise ValueError(f"Unsupported group_size: {group_size}. Supported: 128.")
49+
quant_config.group_size = group_size
50+
51+
else:
52+
raise ValueError(
53+
f"Unsupported weights_quant_strategy: {weights_quant_strategy}. "
54+
"Supported strategies: 'channel', 'block'."
55+
)
56+
elif (
57+
weights_quant_config["num_bits"] == 4
58+
and weights_quant_config.get("type") == "float"
59+
and weights_quant_strategy == "tensor_group"
60+
):
61+
# llm-compressor NVFP4: weights FP4 with FP8 per-group scales
62+
# (group_size=16), scaled by an FP32 global scale.
63+
if inputs_quant_strategy != "tensor_group":
64+
raise ValueError(
65+
f"Unsupported inputs_quant_strategy for NVFP4: {inputs_quant_strategy}."
66+
)
67+
group_size = weights_quant_config["group_size"]
68+
if group_size != 16:
69+
raise ValueError(f"Unsupported group_size: {group_size}. Supported: 16 for NVFP4.")
70+
quant_config.quant_algo = QuantAlgo.NVFP4
71+
quant_config.group_size = group_size
72+
else:
73+
raise ValueError(
74+
f"Unsupported quant_bits: {weights_quant_config['num_bits']}. "
75+
"Supported: 8 (FP8) or 4 (NVFP4)."
76+
)
77+
78+
# kv_cache_scheme (llm-compressor): FP8 per-tensor KV cache.
79+
kv_cache_scheme = hf_quant_config.get("kv_cache_scheme")
80+
if kv_cache_scheme is not None:
81+
if kv_cache_scheme.get("num_bits") == 8 and kv_cache_scheme.get("type") == "float":
82+
if quant_config.kv_cache_quant_algo in (None, QuantAlgo.FP8):
83+
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
84+
else:
85+
raise ValueError(
86+
f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, "
87+
"conflicting with FP8 KV cache from HF quant config."
88+
)
89+
else:
90+
raise ValueError(f"Unsupported kv_cache_scheme: {kv_cache_scheme}.")
91+
92+
hf_exclude_modules = hf_quant_config.get("modules_to_not_convert", None)
93+
if hf_exclude_modules is not None:
94+
quant_config.exclude_modules = list(
95+
set(hf_exclude_modules + hf_quant_config.get("ignore", []))
96+
)
97+
else:
98+
quant_config.exclude_modules = hf_quant_config.get("ignore", [])

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ l0_a10:
1919
- unittest/utils/test_util.py
2020
- unittest/utils/test_logger.py
2121
- unittest/_torch/test_model_config.py
22+
- unittest/models/test_quant_config_utils.py
2223
- unittest/_torch/modeling/test_modeling_mistral.py
2324
- unittest/_torch/modeling/test_modeling_pixtral.py
2425
- unittest/_torch/modeling/test_modeling_cohere2.py

tests/unittest/llmapi/test_kv_cache_dtype_override.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,27 @@ def _write_hf_quant_config(model_dir, kv_cache_quant_algo: str = "FP8"):
2222
)
2323

2424

25+
def _compressed_tensors_nvfp4_config(**overrides):
26+
config = {
27+
"quant_method": "compressed-tensors",
28+
"config_groups": {
29+
"group_0": {
30+
"weights": {
31+
"num_bits": 4,
32+
"type": "float",
33+
"strategy": "tensor_group",
34+
"group_size": 16,
35+
},
36+
"input_activations": {
37+
"strategy": "tensor_group",
38+
},
39+
},
40+
},
41+
}
42+
config.update(overrides)
43+
return config
44+
45+
2546
def test_get_llm_args_plumbs_kv_cache_dtype():
2647
llm_args, _ = get_llm_args(model="dummy", kv_cache_dtype="nvfp4")
2748
assert llm_args["kv_cache_config"].dtype == "nvfp4"
@@ -65,3 +86,42 @@ def test_update_from_hf_quant_config_explicit_dtype_overrides(tmp_path):
6586

6687
assert model_loader._update_from_hf_quant_config() is True
6788
assert llm_args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4
89+
90+
91+
def test_update_from_hf_quant_config_parses_compressed_tensors_model_kwargs(tmp_path):
92+
llm_args = TorchLlmArgs(
93+
model=str(tmp_path),
94+
model_kwargs={
95+
"quantization_config": _compressed_tensors_nvfp4_config(
96+
kv_cache_scheme={
97+
"num_bits": 8,
98+
"type": "float",
99+
}
100+
),
101+
},
102+
)
103+
model_loader = ModelLoader(llm_args)
104+
105+
assert model_loader._update_from_hf_quant_config() is True
106+
assert llm_args.quant_config.quant_algo == QuantAlgo.NVFP4
107+
assert llm_args.quant_config.group_size == 16
108+
assert llm_args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
109+
110+
111+
def test_update_from_hf_quant_config_rejects_compressed_tensors_kv_conflict(tmp_path):
112+
llm_args = TorchLlmArgs(
113+
model=str(tmp_path),
114+
model_kwargs={
115+
"quantization_config": _compressed_tensors_nvfp4_config(
116+
kv_cache_scheme={
117+
"num_bits": 8,
118+
"type": "float",
119+
}
120+
),
121+
},
122+
)
123+
llm_args.quant_config = QuantConfig(kv_cache_quant_algo=QuantAlgo.NVFP4)
124+
model_loader = ModelLoader(llm_args)
125+
126+
with pytest.raises(ValueError, match="conflicting with FP8 KV cache"):
127+
model_loader._update_from_hf_quant_config()

0 commit comments

Comments
 (0)