Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
"watersic_kv": "WATERSIC_KV_CFG",
}

# Formats that use use_constant_amax (no calibration needed).
Expand Down Expand Up @@ -384,7 +385,7 @@ def forward_step(model, batch):
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
Expand All @@ -403,6 +404,16 @@ def forward_step(model, batch):
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
):
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)

# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
if args.kv_cache_qformat == "watersic_kv":
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
if args.watersic_target_rate is not None:
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
if args.watersic_kl_aware:
watersic_cfg["algorithm"]["kl_aware"] = True
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)

return language_model


Expand All @@ -423,7 +434,7 @@ def load_model(args: argparse.Namespace):
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.kv_cache_qformat != "none":
if args.kv_cache_qformat not in {"none", "watersic_kv"}:
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
Expand Down Expand Up @@ -652,6 +663,15 @@ def mono_quantize(
else:
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)

# WaterSIC KV-cache needs a separate quantization pass with its own algorithm
if args.kv_cache_qformat == "watersic_kv":
watersic_cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES["watersic_kv"]))
if args.watersic_target_rate is not None:
watersic_cfg["algorithm"]["target_rate"] = args.watersic_target_rate
if args.watersic_kl_aware:
watersic_cfg["algorithm"]["kl_aware"] = True
language_model = mtq.quantize(language_model, watersic_cfg, forward_loop=calibrate_loop)

# For VL models, update full_model to use the quantized language model
if is_nemotron_vl_model:
language_model_lineage = get_language_model_from_vl(full_model)
Expand Down Expand Up @@ -1083,7 +1103,8 @@ def quantize_main(
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
# WaterSIC KV-cache uses a separate quantization pass, so skip merging here.
if enable_quant_kv_cache and args.kv_cache_qformat != "watersic_kv":
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
Expand Down Expand Up @@ -1242,6 +1263,17 @@ def parse_args() -> argparse.Namespace:
"Other formats (fp8, nvfp4, etc.) use data-driven calibration."
),
)
parser.add_argument(
"--watersic_target_rate",
type=float,
default=None,
help="Target bits per element for WaterSIC KV-cache quantization (default: 2.0)",
)
parser.add_argument(
"--watersic_kl_aware",
action="store_true",
help="Enable KL-aware importance weighting for WaterSIC KV-cache quantization",
)
parser.add_argument(
"--export_fmt",
required=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master

from . import config as mtq_config
from . import model_calib
from .config import QuantizeConfig, QuantizerAttributeConfig
from .conversion import set_quantizer_by_cfg
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import is_quantized_linear
from .. import config as mtq_config
from .. import model_calib
from ..config import QuantizeConfig, QuantizerAttributeConfig
from ..conversion import set_quantizer_by_cfg
from ..nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from ..utils import is_quantized_linear


def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
Expand Down Expand Up @@ -615,8 +615,8 @@ def before_search(self):
# Import here to avoid circular import
from modelopt.torch.quantization.model_quant import calibrate

from .conversion import restore_quantizer_state, update_quantize_metadata
from .utils import get_quantizer_state_dict, set_quantizer_state_dict
from ..conversion import restore_quantizer_state, update_quantize_metadata
from ..utils import get_quantizer_state_dict, set_quantizer_state_dict

super().before_search()
restored_method = getattr(self, "method", None)
Expand Down
23 changes: 23 additions & 0 deletions modelopt/torch/quantization/algorithms/watersic_kv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WaterSIC KV-cache quantization algorithm."""

from __future__ import annotations

from .config import WaterSICKVCalibConfig
from .helper import WaterSICKVHelper, WaterSICKVState

__all__ = ["WaterSICKVCalibConfig", "WaterSICKVHelper", "WaterSICKVState"]
115 changes: 115 additions & 0 deletions modelopt/torch/quantization/algorithms/watersic_kv/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for the WaterSIC KV-cache quantization algorithm."""

from __future__ import annotations

from typing import Literal

from modelopt.torch.opt.config import ModeloptField
from modelopt.torch.quantization.config import QuantizeAlgorithmConfig


class WaterSICKVCalibConfig(QuantizeAlgorithmConfig):
    """Configuration for WaterSIC KV-cache quantization.

    WaterSIC (Water-filling Successive Interference Cancellation) is a
    rate-adaptive quantization method for KV-cache compression. It
    applies the ZSIC algorithm with optional KL-aware importance
    weighting and LMMSE shrinkage correction to minimize attention-output
    distortion at a target bits-per-element budget.

    Reference: "WaterSIC: Water-filling Successive Interference
    Cancellation for KV-Cache Quantization" (2024).
    """

    # Fixed discriminator so the calibration dispatcher can route this config.
    method: Literal["watersic_kv"] = ModeloptField(
        "watersic_kv",
        title="Calibration algorithm identifier.",
        description="Fixed identifier for the WaterSIC KV-cache calibration method.",
    )

    target_rate: float = ModeloptField(
        default=2.0,
        gt=0.0,
        title="Target bits per element.",
        description=(
            "Average number of bits per quantized KV-cache element. The binary "
            "search over the ZSIC damping parameter c is driven to hit this rate."
        ),
    )

    kl_aware: bool = ModeloptField(
        default=False,
        title="Enable KL-aware importance weighting.",
        description=(
            "When True, per-token importance weights derived from the attention "
            "distribution are folded into the Hessian so that tokens with higher "
            "attention mass receive tighter quantization."
        ),
    )

    importance_clip: float = ModeloptField(
        default=50.0,
        gt=0.0,
        title="Importance weight clipping ratio.",
        description=(
            "Maximum ratio by which a single token's importance weight may exceed "
            "the mean weight. Clips extreme outlier tokens to prevent them from "
            "dominating the Hessian estimate."
        ),
    )

    use_lmmse: bool = ModeloptField(
        default=True,
        title="Apply LMMSE shrinkage correction.",
        description=(
            "When True, the LMMSE (Linear Minimum Mean-Squared Error) shrinkage "
            "correction is applied after ZSIC quantization to partially undo "
            "quantization bias and reduce reconstruction NMSE."
        ),
    )

    n_rescaler_iters: int = ModeloptField(
        default=0,
        ge=0,
        title="Diagonal rescaler optimization iterations.",
        description=(
            "Number of coordinate-descent iterations for the diagonal rescaler "
            "that adjusts per-column scale factors after LMMSE. Set to 0 to "
            "disable the rescaler (faster but slightly higher distortion)."
        ),
    )

    # Constrained to (0, 1]: the field is a fraction of rows, so values <= 0
    # or > 1 are meaningless. `None` (the default) disables subsampling and
    # is unaffected by the bounds. Mirrors the validation style of the
    # numeric fields above (target_rate, importance_clip, n_rescaler_iters).
    sample_frac: float | None = ModeloptField(
        default=None,
        gt=0.0,
        le=1.0,
        title="Row subsampling fraction for binary search.",
        description=(
            "If set, only this fraction of rows (KV heads) are used during the "
            "binary search for c. Full rows are then quantized with the found c. "
            "Speeds up calibration on large KV caches at a small accuracy cost."
        ),
    )

    use_sequential: bool = ModeloptField(
        default=False,
        title="Enable sequential layer-by-layer calibration.",
        description=(
            "Must be False for WaterSIC. Unlike weight quantization, KV-cache "
            "quantization does not have progressive error accumulation between "
            "layers, so sequential calibration is not needed."
        ),
    )
Loading
Loading