|
26 | 26 | from ..mapping import Mapping |
27 | 27 | from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM |
28 | 28 | from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig |
| 29 | +from ..models.quant_config_utils import \ |
| 30 | + update_quant_config_from_compressed_tensors |
29 | 31 | from ..module import Module |
30 | 32 | from ..quantization.modelopt_config import (is_modelopt_quant_config, |
31 | 33 | read_modelopt_quant_config, |
@@ -470,90 +472,8 @@ def _update_from_hf_quant_config(self) -> bool: |
470 | 472 | ] |
471 | 473 | # NOTE: This is for llm-compressor's quantized checkpoints. |
472 | 474 | elif hf_quant_config.get("quant_method") == "compressed-tensors": |
473 | | - config_groups = hf_quant_config.get("config_groups") |
474 | | - if config_groups is None: |
475 | | - raise ValueError( |
476 | | - f"config_groups is not set in {hf_quant_config}.") |
477 | | - |
478 | | - weights_quant_config = config_groups["group_0"]["weights"] |
479 | | - inputs_quant_config = config_groups["group_0"][ |
480 | | - "input_activations"] |
481 | | - weights_quant_strategy = weights_quant_config["strategy"] |
482 | | - inputs_quant_strategy = inputs_quant_config["strategy"] |
483 | | - |
484 | | - if weights_quant_config["num_bits"] == 8: |
485 | | - if weights_quant_strategy == "channel": |
486 | | - if inputs_quant_strategy != "token": |
487 | | - raise ValueError( |
488 | | - f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}." |
489 | | - ) |
490 | | - quant_config.quant_algo = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN |
491 | | - elif weights_quant_strategy == "block": |
492 | | - if inputs_quant_strategy != "group": |
493 | | - raise ValueError( |
494 | | - f"Unsupported inputs_quant_strategy: {inputs_quant_strategy}." |
495 | | - ) |
496 | | - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES |
497 | | - group_size = inputs_quant_config["group_size"] |
498 | | - |
499 | | - # NOTE: TRT-LLM only supports group_size=128 for FP8_BLOCK_SCALES. |
500 | | - if group_size != 128: |
501 | | - raise ValueError( |
502 | | - f"Unsupported group_size: {group_size}. Supported: 128." |
503 | | - ) |
504 | | - quant_config.group_size = group_size |
505 | | - |
506 | | - else: |
507 | | - raise ValueError( |
508 | | - f"Unsupported weights_quant_strategy: {weights_quant_strategy}. " |
509 | | - "Supported strategies: 'channel', 'block'.") |
510 | | - elif (weights_quant_config["num_bits"] == 4 |
511 | | - and weights_quant_config.get("type") == "float" |
512 | | - and weights_quant_strategy == "tensor_group"): |
513 | | - # llm-compressor NVFP4: weights FP4 with FP8 per-group |
514 | | - # scales (group_size=16), scaled by an FP32 global scale. |
515 | | - if inputs_quant_strategy != "tensor_group": |
516 | | - raise ValueError( |
517 | | - f"Unsupported inputs_quant_strategy for NVFP4: {inputs_quant_strategy}." |
518 | | - ) |
519 | | - group_size = weights_quant_config["group_size"] |
520 | | - if group_size != 16: |
521 | | - raise ValueError( |
522 | | - f"Unsupported group_size: {group_size}. Supported: 16 for NVFP4." |
523 | | - ) |
524 | | - quant_config.quant_algo = QuantAlgo.NVFP4 |
525 | | - quant_config.group_size = group_size |
526 | | - else: |
527 | | - raise ValueError( |
528 | | - f"Unsupported quant_bits: {weights_quant_config['num_bits']}. " |
529 | | - "Supported: 8 (FP8) or 4 (NVFP4).") |
530 | | - |
531 | | - # kv_cache_scheme (llm-compressor): FP8 per-tensor KV cache. |
532 | | - kv_cache_scheme = hf_quant_config.get("kv_cache_scheme") |
533 | | - if kv_cache_scheme is not None: |
534 | | - if (kv_cache_scheme.get("num_bits") == 8 |
535 | | - and kv_cache_scheme.get("type") == "float"): |
536 | | - if quant_config.kv_cache_quant_algo in (None, |
537 | | - QuantAlgo.FP8): |
538 | | - quant_config.kv_cache_quant_algo = QuantAlgo.FP8 |
539 | | - else: |
540 | | - raise ValueError( |
541 | | - f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, " |
542 | | - f"conflicting with FP8 KV cache from HF quant config." |
543 | | - ) |
544 | | - else: |
545 | | - raise ValueError( |
546 | | - f"Unsupported kv_cache_scheme: {kv_cache_scheme}.") |
547 | | - |
548 | | - hf_exclude_modules = hf_quant_config.get( |
549 | | - "modules_to_not_convert", None) |
550 | | - if hf_exclude_modules is not None: |
551 | | - quant_config.exclude_modules = list( |
552 | | - set(hf_exclude_modules + |
553 | | - hf_quant_config.get("ignore", []))) |
554 | | - else: |
555 | | - quant_config.exclude_modules = hf_quant_config.get( |
556 | | - "ignore", []) |
| 475 | + update_quant_config_from_compressed_tensors( |
| 476 | + quant_config, hf_quant_config) |
557 | 477 | elif hf_quant_config.get("quant_method") == "nvfp4": |
558 | 478 | quant_config.quant_algo = QuantAlgo.NVFP4 |
559 | 479 | group_size = hf_quant_config.get("group_size", 16) |
|
0 commit comments