Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 96 additions & 10 deletions olive/passes/pytorch/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from olive.common.quant.hf_utils import (
OliveHfQuantizationConfig,
OliveHfQuantizationMethod,
OliveHfQuantizationOverrideConfig,
replace_matching_submodules,
tie_quant_word_embeddings,
)
Expand Down Expand Up @@ -78,6 +79,74 @@ def get_quantizer_config(allow_embeds: bool = False) -> dict[str, PassConfigPara
}


def get_qkv_quantization_groups(wrapper: ModelWrapper, module_names: set[str] | None = None) -> list[tuple[str, ...]]:
    """Collect, per layer, the attention input projections that must be quantized alike.

    Each attention input module is mapped back to its qualified name by object identity
    against ``wrapper.model.named_modules()``, so the result does not depend on how the
    layers are stored (``ModuleList``, ``ModuleDict``, custom containers) and also works
    for unpacked QKV submodules.

    Args:
        wrapper: Model wrapper exposing ``model.named_modules()`` and ``get_layer_wrappers()``.
        module_names: Optional allow-list of qualified names. When given, attention inputs
            whose name is not in the set are left out of their group.

    Returns:
        One tuple of qualified module names per layer, in the order the layer reports its
        attention inputs. Layers that end up with fewer than two members are omitted.
    """
    # Identity-keyed lookup: module object -> qualified name inside the model.
    name_by_module_id = {}
    for qualified_name, submodule in wrapper.model.named_modules():
        name_by_module_id[id(submodule)] = qualified_name

    groups = []
    for layer in wrapper.get_layer_wrappers():
        projections, _ = layer.get_attention_inputs()
        members = []
        for projection in projections:
            resolved = name_by_module_id.get(id(projection))
            if resolved is None:
                # Module is not reachable through named_modules(); nothing to group.
                continue
            if module_names is not None and resolved not in module_names:
                continue
            members.append(resolved)
        # A single remaining member has nothing to stay consistent with.
        if len(members) >= 2:
            groups.append(tuple(members))
    return groups


def _quant_config_rank(qargs: dict[str, int | bool]) -> tuple[int, int, int]:
"""Rank quantization configs by precision; higher rank means more precise.

Ordering: higher ``bits`` wins; among equal bits, smaller positive ``group_size`` wins;
per-channel (``-1``) wins over per-tensor (``0``) but loses to positive group sizes.
``symmetric`` is intentionally not part of the ordering since it is a representation
choice rather than a strict precision axis.
"""
bits = qargs["bits"].value if hasattr(qargs["bits"], "value") else qargs["bits"]
group_size = qargs["group_size"]
if group_size > 0:
group_size_rank = (2, -group_size)
elif group_size == -1:
group_size_rank = (1, 0)
else:
group_size_rank = (0, 0)
return bits, *group_size_rank


def normalize_qkv_quant_config(
    wrapper: ModelWrapper,
    qcfg: OliveHfQuantizationConfig,
    module_names: set[str] | None = None,
) -> OliveHfQuantizationConfig:
    """Make every split QKV projection group share a single quantization config.

    For each group returned by ``get_qkv_quantization_groups``, the effective settings of
    every member are compared. If they differ, the member config with the highest
    ``_quant_config_rank`` is promoted to the whole group by rewriting ``qcfg.overrides``.

    Args:
        wrapper: Model wrapper used to discover the QKV groups.
        qcfg: Quantization config to normalize. It is modified in place.
        module_names: Optional allow-list of module names. QKV members outside the set
            (e.g., excluded from quantization) are ignored both when forming the group
            and when choosing the promoted config.

    Returns:
        The same ``qcfg`` object that was passed in.
    """
    for qkv_names in get_qkv_quantization_groups(wrapper, module_names):
        per_module_qargs = {}
        for member in qkv_names:
            per_module_qargs[member] = qcfg.get_qlinear_init_args(member)

        distinct_settings = {tuple(settings.items()) for settings in per_module_qargs.values()}
        if len(distinct_settings) == 1:
            # Every member already resolves to the same settings.
            continue

        promoted_qargs = max(per_module_qargs.values(), key=_quant_config_rank)
        logger.debug("Promoting QKV group %s to shared quantization config %s", qkv_names, promoted_qargs)
        for member in qkv_names:
            # Keep only the fields that differ from the config-level attribute of the same name.
            delta = {}
            for field, wanted in promoted_qargs.items():
                if getattr(qcfg, field) != wanted:
                    delta[field] = wanted

            if not delta:
                # The promoted settings equal the config-level values, so no override is needed.
                qcfg.overrides.pop(member, None)
                continue
            qcfg.overrides[member] = OliveHfQuantizationOverrideConfig(**delta)

    return qcfg


def prepare_model(
model: HfModelHandler,
config: type[BasePassConfig],
Expand Down Expand Up @@ -107,16 +176,6 @@ def prepare_model(
wrapper = ModelWrapper.from_model(load_hf_base_model(model))
wrapper.model.eval()

qcfg = get_quant_config(model, config)

originally_tied_embeddings = wrapper.config.tie_word_embeddings
if qcfg.lm_head or qcfg.embeds:
wrapper.maybe_untie_word_embeddings()

lm_head_name = wrapper.get_lm_head()[1]
embeds_name = wrapper.get_embeds()[1][0]
new_qargs: dict[str, dict[str, int | bool]] = {}

excluded_attn_inputs: set[torch.nn.Module] = set()
if exclude_attn_inputs:
for layer_wrapper in wrapper.get_layer_wrappers():
Expand All @@ -126,6 +185,30 @@ def prepare_model(
else:
excluded_attn_inputs.update(attn_inputs[:2])

quantizable_attn_input_names: set[str] | None = None
if excluded_attn_inputs:
module_to_name = {id(module): name for name, module in wrapper.model.named_modules()}
excluded_ids = {id(module) for module in excluded_attn_inputs}
quantizable_attn_input_names = set()
for layer_wrapper in wrapper.get_layer_wrappers():
attn_inputs, _ = layer_wrapper.get_attention_inputs()
for module in attn_inputs:
if id(module) in excluded_ids:
continue
name = module_to_name.get(id(module))
if name is not None:
quantizable_attn_input_names.add(name)

qcfg = normalize_qkv_quant_config(wrapper, get_quant_config(model, config), quantizable_attn_input_names)
Comment thread
hanbitmyths marked this conversation as resolved.

originally_tied_embeddings = wrapper.config.tie_word_embeddings
if qcfg.lm_head or qcfg.embeds:
wrapper.maybe_untie_word_embeddings()

lm_head_name = wrapper.get_lm_head()[1]
embeds_name = wrapper.get_embeds()[1][0]
new_qargs: dict[str, dict[str, int | bool]] = {}

def should_quantize(module: torch.nn.Module, name: str) -> bool:
if module in excluded_attn_inputs:
return False
Expand Down Expand Up @@ -160,6 +243,9 @@ def add_quant_info(module: torch.nn.Module, name: str) -> torch.nn.Module:
merged_qcfg_dict["lm_head"] |= qcfg.lm_head
merged_qcfg_dict["embeds"] |= qcfg.embeds
qcfg = OliveHfQuantizationConfig(**merged_qcfg_dict)
# Re-normalize: the pre-existing overrides we just merged in may violate QKV
# consistency on their own, which can break ModelBuilder's GQA fusion.
qcfg = normalize_qkv_quant_config(wrapper, qcfg, quantizable_attn_input_names)

word_embeddings_eligible_for_tieing = (
originally_tied_embeddings
Expand Down
Loading
Loading