From c7bf5b8f15720d11d7723a50a3f6b7f50426470a Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Fri, 22 May 2026 16:44:51 -0700 Subject: [PATCH 1/3] selective_mixed_precision: QKV-aware overrides, AUTO memory mode, MULTI_GPU dispatch - Normalize per-layer quant config overrides so Q/K/V projections in the same attention block share precision, required by ModelBuilder for GQA fusion. - Add AUTO setting for kld_memory_mode that picks among FULL, MULTI_GPU, LOW_MEMORY, OFFLOAD based on available GPU memory and model size. - Add MULTI_GPU mode that uses Accelerate's dispatch_model with _no_split_modules honored, plus a coalescing pass that pins every model.layers.N.* entry to a single device and falls back to LOW_MEMORY if a decoder layer still spans devices. - Tests: 24 unit tests covering QKV grouping, AUTO selection thresholds, and the MULTI_GPU device-map coalescing path. --- olive/passes/pytorch/quant_utils.py | 103 ++- .../pytorch/selective_mixed_precision.py | 399 +++++++++-- .../pytorch/test_selective_mixed_precision.py | 642 +++++++++++++++++- 3 files changed, 1085 insertions(+), 59 deletions(-) diff --git a/olive/passes/pytorch/quant_utils.py b/olive/passes/pytorch/quant_utils.py index bdfd08e0da..f468d2a221 100644 --- a/olive/passes/pytorch/quant_utils.py +++ b/olive/passes/pytorch/quant_utils.py @@ -14,6 +14,7 @@ from olive.common.quant.hf_utils import ( OliveHfQuantizationConfig, OliveHfQuantizationMethod, + OliveHfQuantizationOverrideConfig, replace_matching_submodules, tie_quant_word_embeddings, ) @@ -78,6 +79,74 @@ def get_quantizer_config(allow_embeds: bool = False) -> dict[str, PassConfigPara } +def get_qkv_quantization_groups(wrapper: ModelWrapper, module_names: set[str] | None = None) -> list[tuple[str, ...]]: + """Get attention input projection groups that must share quantization settings. 
+ + Names are resolved from ``wrapper.model.named_modules()`` to stay correct for any layer + container (``ModuleList``, ``ModuleDict``, custom containers) and for unpacked QKV + submodules. When ``module_names`` is provided, attention inputs not in the set are + dropped from the group. Groups with fewer than two members are skipped. + """ + module_to_name = {id(module): name for name, module in wrapper.model.named_modules()} + qkv_groups = [] + for layer_wrapper in wrapper.get_layer_wrappers(): + attn_inputs, _ = layer_wrapper.get_attention_inputs() + group = tuple( + name + for name in (module_to_name.get(id(module)) for module in attn_inputs) + if name is not None and (module_names is None or name in module_names) + ) + if len(group) > 1: + qkv_groups.append(group) + return qkv_groups + + +def _quant_config_rank(qargs: dict[str, int | bool]) -> tuple[int, int, int]: + """Rank quantization configs by precision; higher rank means more precise. + + Ordering: higher ``bits`` wins; among equal bits, smaller positive ``group_size`` wins; + per-channel (``-1``) wins over per-tensor (``0``) but loses to positive group sizes. + ``symmetric`` is intentionally not part of the ordering since it is a representation + choice rather than a strict precision axis. + """ + bits = qargs["bits"].value if hasattr(qargs["bits"], "value") else qargs["bits"] + group_size = qargs["group_size"] + if group_size > 0: + group_size_rank = (2, -group_size) + elif group_size == -1: + group_size_rank = (1, 0) + else: + group_size_rank = (0, 0) + return bits, *group_size_rank + + +def normalize_qkv_quant_config( + wrapper: ModelWrapper, + qcfg: OliveHfQuantizationConfig, + module_names: set[str] | None = None, +) -> OliveHfQuantizationConfig: + """Promote split QKV projection overrides to one shared quantization config. + + When ``module_names`` is provided, QKV members not in the set (e.g., excluded from + quantization) are ignored when forming the group and choosing the promoted config. 
+ """ + for group in get_qkv_quantization_groups(wrapper, module_names): + group_qargs = {module_name: qcfg.get_qlinear_init_args(module_name) for module_name in group} + if len({tuple(qargs.items()) for qargs in group_qargs.values()}) == 1: + continue + + promoted_qargs = max(group_qargs.values(), key=_quant_config_rank) + logger.debug("Promoting QKV group %s to shared quantization config %s", group, promoted_qargs) + for module_name in group: + override = {key: value for key, value in promoted_qargs.items() if getattr(qcfg, key) != value} + if override: + qcfg.overrides[module_name] = OliveHfQuantizationOverrideConfig(**override) + else: + qcfg.overrides.pop(module_name, None) + + return qcfg + + def prepare_model( model: HfModelHandler, config: type[BasePassConfig], @@ -107,16 +176,6 @@ def prepare_model( wrapper = ModelWrapper.from_model(load_hf_base_model(model)) wrapper.model.eval() - qcfg = get_quant_config(model, config) - - originally_tied_embeddings = wrapper.config.tie_word_embeddings - if qcfg.lm_head or qcfg.embeds: - wrapper.maybe_untie_word_embeddings() - - lm_head_name = wrapper.get_lm_head()[1] - embeds_name = wrapper.get_embeds()[1][0] - new_qargs: dict[str, dict[str, int | bool]] = {} - excluded_attn_inputs: set[torch.nn.Module] = set() if exclude_attn_inputs: for layer_wrapper in wrapper.get_layer_wrappers(): @@ -126,6 +185,30 @@ def prepare_model( else: excluded_attn_inputs.update(attn_inputs[:2]) + quantizable_attn_input_names: set[str] | None = None + if excluded_attn_inputs: + module_to_name = {id(module): name for name, module in wrapper.model.named_modules()} + excluded_ids = {id(module) for module in excluded_attn_inputs} + quantizable_attn_input_names = set() + for layer_wrapper in wrapper.get_layer_wrappers(): + attn_inputs, _ = layer_wrapper.get_attention_inputs() + for module in attn_inputs: + if id(module) in excluded_ids: + continue + name = module_to_name.get(id(module)) + if name is not None: + 
quantizable_attn_input_names.add(name) + + qcfg = normalize_qkv_quant_config(wrapper, get_quant_config(model, config), quantizable_attn_input_names) + + originally_tied_embeddings = wrapper.config.tie_word_embeddings + if qcfg.lm_head or qcfg.embeds: + wrapper.maybe_untie_word_embeddings() + + lm_head_name = wrapper.get_lm_head()[1] + embeds_name = wrapper.get_embeds()[1][0] + new_qargs: dict[str, dict[str, int | bool]] = {} + def should_quantize(module: torch.nn.Module, name: str) -> bool: if module in excluded_attn_inputs: return False diff --git a/olive/passes/pytorch/selective_mixed_precision.py b/olive/passes/pytorch/selective_mixed_precision.py index fd5fe403d5..ce3f21b48d 100644 --- a/olive/passes/pytorch/selective_mixed_precision.py +++ b/olive/passes/pytorch/selective_mixed_precision.py @@ -19,6 +19,7 @@ from olive.model import HfModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam +from olive.passes.pytorch.quant_utils import get_qkv_quantization_groups from olive.passes.pytorch.train_utils import get_calibration_dataset, kl_div_loss, load_hf_base_model from olive.search.search_parameter import Categorical @@ -61,6 +62,25 @@ class Algorithm(StrEnumBase): SNR = "snr" SNR_RELATIVE = "snr_relative" + class KldMemoryMode(StrEnumBase): + """Memory mode for KL Divergence gradient based selection. + + - ``auto``: pick one of the modes below based on the model size and free device memory. + - ``full``: keep a per-layer fp32 gradient accumulator (legacy behaviour, highest peak memory). + - ``multi_gpu``: same algorithm as ``full`` but shard teacher and student across all visible + CUDA devices via ``accelerate``. Used when the model does not fit on a single GPU but fits + across multiple GPUs. Falls back to ``low_memory`` if ``accelerate`` is not installed. + - ``low_memory``: stream the alignment to a scalar accumulator per layer; teacher and student + stay on device. 
+ - ``offload``: also keep teacher and student off device when not in use; lowest peak memory. + """ + + AUTO = "auto" + FULL = "full" + MULTI_GPU = "multi_gpu" + LOW_MEMORY = "low_memory" + OFFLOAD = "offload" + @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: return { @@ -123,6 +143,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon " kld_gradient algorithms. Must be provided when using these algorithms." ), ), + "kld_memory_mode": PassConfigParam( + type_=SelectiveMixedPrecision.KldMemoryMode, + default_value=SelectiveMixedPrecision.KldMemoryMode.AUTO, + description=( + "Memory mode for kld_gradient. ``auto`` (default) picks among ``full``, ``multi_gpu``," + " ``low_memory`` and ``offload`` based on the model size and free device memory." + " ``full`` keeps a per-layer fp32 gradient accumulator (legacy behaviour)." + " ``multi_gpu`` runs the ``full`` algorithm with teacher and student sharded across all" + " visible CUDA devices via ``accelerate``." + " ``low_memory`` streams the alignment to a scalar per layer." + " ``offload`` also keeps teacher and student off device when not in use." 
+ ), + ), } @classmethod @@ -165,6 +198,7 @@ def _run_for_config( config.high_group_size if config.high_group_size is not None else config.group_size, config.high_sym if config.high_sym is not None else config.sym, config.ratio, + config.kld_memory_mode, ) lm_head_name = model_wrapper.get_lm_head()[1] @@ -217,6 +251,43 @@ def get_k_quant_config( return {"bits": bits}, overrides + @staticmethod + def get_overrides_from_scores( + model_wrapper: ModelWrapper, + module_numels: dict[str, int], + module_scores: dict[str, float], + high_override_config: dict, + ratio: float, + ) -> tuple[dict[str, dict], int]: + """Get high precision overrides from sensitivity scores.""" + qkv_groups = get_qkv_quantization_groups(model_wrapper, set(module_scores)) + grouped_modules = {module_name for group in qkv_groups for module_name in group} + + scored_items = [ + ( + group, + sum(module_numels[module_name] for module_name in group), + min(module_scores[name] for name in group), + ) + for group in qkv_groups + ] + scored_items.extend( + ((module_name,), module_numels[module_name], score) + for module_name, score in module_scores.items() + if module_name not in grouped_modules + ) + + threshold = sum(module_numels.values()) * (1 - ratio) + overrides = {} + high_precision_numels = 0 + for module_names, numels, _ in sorted(scored_items, key=lambda item: item[2]): + high_precision_numels += numels + overrides.update({module_name: high_override_config.copy() for module_name in module_names}) + if high_precision_numels >= threshold: + break + + return overrides, high_precision_numels + @staticmethod def get_scored_config( handler: HfModelHandler, @@ -229,6 +300,7 @@ def get_scored_config( high_group_size: int, high_symmetric: bool, ratio: float, + kld_memory_mode: KldMemoryMode = KldMemoryMode.AUTO, ): """Get mixed precision config based on sensitivity scores.""" quantizer = WeightQuantizer(bits=bits, group_size=group_size, symmetric=symmetric) @@ -239,36 +311,39 @@ def get_scored_config( ) 
device = "cuda" if torch.cuda.is_available() else "cpu" - algo_func = ( - SelectiveMixedPrecision.get_kld_scores - if algorithm == SelectiveMixedPrecision.Algorithm.KLD_GRADIENT - else SelectiveMixedPrecision.get_snr_iqe_scores - ) - module_numels, module_scores = algo_func( - handler, - model_wrapper.model, - algorithm, - quantizer, - high_quantizer, - device, - ) + if algorithm == SelectiveMixedPrecision.Algorithm.KLD_GRADIENT: + module_numels, module_scores = SelectiveMixedPrecision.get_kld_scores( + handler, + model_wrapper.model, + algorithm, + quantizer, + high_quantizer, + device, + kld_memory_mode, + ) + else: + module_numels, module_scores = SelectiveMixedPrecision.get_snr_iqe_scores( + handler, + model_wrapper.model, + algorithm, + quantizer, + high_quantizer, + device, + ) - threshold = sum(module_numels.values()) * (1 - ratio) - # ascending order, lower score means more sensitive and should be in higher precision - sorted_modules = sorted(module_scores, key=lambda item: module_scores[item], reverse=False) - overrides = {} high_override_config = {"bits": high_bits, "group_size": high_group_size, "symmetric": high_symmetric} - total = 0 - for module_name in sorted_modules: - total += module_numels[module_name] - overrides[module_name] = high_override_config - if total >= threshold: - break + overrides, high_precision_numels = SelectiveMixedPrecision.get_overrides_from_scores( + model_wrapper, + module_numels, + module_scores, + high_override_config, + ratio, + ) logger.info( "Selected %d modules for high precision out of %d modules. 
Ratio of low precision: %.4f", len(overrides), len(module_numels), - 1 - total / sum(module_numels.values()), + 1 - high_precision_numels / sum(module_numels.values()), ) return {"bits": bits, "group_size": group_size, "symmetric": symmetric}, overrides @@ -313,6 +388,104 @@ def process_module(module, module_name): replace_matching_submodules(model, should_include, process_module, description="Computing SNR/IQE scores") return module_numels, module_scores + @staticmethod + def _estimate_kld_memory_bytes(model: torch.nn.Module) -> tuple[int, int, int]: + """Estimate parameter bytes and peak KLD memory for FULL and LOW_MEMORY modes.""" + # Activations are bounded by gradient checkpointing; budget this fraction of model bytes as headroom. + activation_budget_ratio = 0.2 + # Multiplicative safety factor applied to absorb allocator fragmentation and short-lived temporaries. + memory_safety_factor = 1.2 + # Bytes per element for the fp32 gradient accumulator held by the FULL mode. + fp32_bytes_per_element = 4 + + param_bytes = sum(parameter.numel() * parameter.element_size() for parameter in model.parameters()) + linear_grad_bytes = sum( + module.weight.numel() * fp32_bytes_per_element + for module in model.modules() + if isinstance(module, torch.nn.Linear) + ) + activation_budget = int(activation_budget_ratio * param_bytes) + + full_estimate = int((2 * param_bytes + linear_grad_bytes + activation_budget) * memory_safety_factor) + low_estimate = int((2 * param_bytes + activation_budget) * memory_safety_factor) + return param_bytes, full_estimate, low_estimate + + @staticmethod + def _get_kld_memory_budget(free_bytes: int) -> int: + """Return the usable free-memory budget for KLD mode selection.""" + # Leave ~15% headroom for allocator fragmentation and underestimated activation peaks. 
+ free_memory_budget_ratio = 0.85 + return int(free_bytes * free_memory_budget_ratio) + + @staticmethod + def _get_kld_multi_gpu_max_memory(model: torch.nn.Module, free_per_gpu: list[int]) -> dict[int, int]: + """Return per-GPU model-copy limits that leave room for FULL-mode KLD memory.""" + param_bytes, full_estimate, _ = SelectiveMixedPrecision._estimate_kld_memory_bytes(model) + # Cap each GPU at the parameter share of the full estimate so the remainder of the budget + # stays free for the second model copy, the fp32 grad accumulator, and activations. + per_model_memory_fraction = param_bytes / full_estimate if full_estimate else 1.0 + return { + device_idx: int(SelectiveMixedPrecision._get_kld_memory_budget(free_bytes) * per_model_memory_fraction) + for device_idx, free_bytes in enumerate(free_per_gpu) + } + + @staticmethod + def resolve_kld_memory_mode( + model: torch.nn.Module, + device: str, + kld_memory_mode: KldMemoryMode, + ) -> KldMemoryMode: + """Resolve ``KldMemoryMode.AUTO`` to a concrete mode for ``model`` on ``device``. + + On CPU we always prefer the ``full`` legacy path since host memory is usually ample. + On CUDA we estimate the peak device memory for each mode and pick the most accurate + mode whose estimate fits in free device memory with safety headroom. 
+ """ + if kld_memory_mode != SelectiveMixedPrecision.KldMemoryMode.AUTO: + return kld_memory_mode + + if not device.startswith("cuda") or not torch.cuda.is_available(): + logger.info("KLD memory mode auto-selected: full (non-CUDA device %s).", device) + return SelectiveMixedPrecision.KldMemoryMode.FULL + + gpu_count = torch.cuda.device_count() + if gpu_count == 0: + logger.warning("CUDA reports available but no devices visible; defaulting to offload.") + return SelectiveMixedPrecision.KldMemoryMode.OFFLOAD + + try: + free_per_gpu = [torch.cuda.mem_get_info(i)[0] for i in range(gpu_count)] + except Exception as exc: # pragma: no cover - depends on driver/runtime + logger.warning("Failed to query free CUDA memory (%s); defaulting to offload.", exc) + return SelectiveMixedPrecision.KldMemoryMode.OFFLOAD + + _, full_estimate, low_estimate = SelectiveMixedPrecision._estimate_kld_memory_bytes(model) + single_gpu_budget = SelectiveMixedPrecision._get_kld_memory_budget(free_per_gpu[0]) + multi_gpu_budget = sum( + SelectiveMixedPrecision._get_kld_memory_budget(free_bytes) for free_bytes in free_per_gpu + ) + + if full_estimate <= single_gpu_budget: + chosen = SelectiveMixedPrecision.KldMemoryMode.FULL + elif gpu_count > 1 and full_estimate <= multi_gpu_budget: + chosen = SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU + elif low_estimate <= single_gpu_budget: + chosen = SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY + else: + chosen = SelectiveMixedPrecision.KldMemoryMode.OFFLOAD + + logger.info( + "KLD memory mode auto-selected: %s (gpus=%d, full=%.2f GB, low=%.2f GB," + " single_budget=%.2f GB, multi_budget=%.2f GB).", + chosen, + gpu_count, + full_estimate / 1e9, + low_estimate / 1e9, + single_gpu_budget / 1e9, + multi_gpu_budget / 1e9, + ) + return chosen + @staticmethod def get_kld_scores( handler: HfModelHandler, @@ -321,6 +494,7 @@ def get_kld_scores( quantizer: WeightQuantizer, high_quantizer: WeightQuantizer, device: str, + kld_memory_mode: KldMemoryMode = 
KldMemoryMode.AUTO, ) -> tuple[dict[str, int], dict[str, float]]: """Compute KL Divergence gradient based sensitivity scores. @@ -333,8 +507,90 @@ def get_kld_scores( # TODO(jambayk): make data_config configurable data = get_calibration_dataset(handler, max_seq_len=512, max_samples=256) - model.to(device).eval() - q_model = deepcopy(model).to(device).eval() + resolved_mode = SelectiveMixedPrecision.resolve_kld_memory_mode(model, device, kld_memory_mode) + if resolved_mode == SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU: + if not (device.startswith("cuda") and torch.cuda.is_available() and torch.cuda.device_count() > 1): + logger.warning( + "kld_memory_mode=multi_gpu requires at least two visible CUDA devices; falling back to low_memory." + ) + resolved_mode = SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY + else: + import importlib.util + + if importlib.util.find_spec("accelerate") is None: + logger.warning( + "kld_memory_mode=multi_gpu requires the 'accelerate' package; falling back to low_memory." + ) + resolved_mode = SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY + # Offloading between host and device is only meaningful on a non-CPU device; on CPU the + # transfers degenerate to no-ops, so we keep the low-memory path to avoid redundant work. + offload = resolved_mode == SelectiveMixedPrecision.KldMemoryMode.OFFLOAD and device != "cpu" + multi_gpu = resolved_mode == SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU + # MULTI_GPU runs the same per-layer fp32 grad accumulator algorithm as FULL, just sharded. + full_memory = resolved_mode == SelectiveMixedPrecision.KldMemoryMode.FULL or multi_gpu + if multi_gpu: + from accelerate import dispatch_model, infer_auto_device_map + + # Keep both copies on CPU before dispatching so deepcopy is safe and the device map + # can be inferred once on the un-dispatched model. 
+ model.to("cpu").eval() + q_model = deepcopy(model).eval() + no_split = getattr(model, "_no_split_modules", None) or [] + + free_per_gpu = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())] + max_memory = SelectiveMixedPrecision._get_kld_multi_gpu_max_memory(model, free_per_gpu) + + device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split) + # Coalesce any sub-decoder-layer placements onto a single device so accelerate hooks do + # not need to cross device boundaries inside a transformer block (which breaks pointwise + # ops like the MLP gate*up product). + coalesced_map: dict[str, object] = {} + layer_devices: dict[str, object] = {} + for module_name, mapped_device in device_map.items(): + parts = module_name.split(".") + if len(parts) >= 3 and parts[0] == "model" and parts[1] == "layers": + layer_key = ".".join(parts[:3]) + layer_devices.setdefault(layer_key, mapped_device) + coalesced_map[module_name] = layer_devices[layer_key] + else: + coalesced_map[module_name] = mapped_device + device_map = coalesced_map + if any(str(mapped_device) in {"cpu", "disk"} for mapped_device in device_map.values()): + logger.warning( + "Unable to place kld_memory_mode=multi_gpu fully on CUDA devices; falling back to low_memory." + ) + multi_gpu = False + full_memory = False + else: + # Verify no decoder layer's submodules are spread across devices. 
+ layer_groups: dict[str, set] = {} + for module_name, mapped_device in device_map.items(): + parts = module_name.split(".") + if len(parts) >= 3 and parts[0] == "model" and parts[1] == "layers": + layer_groups.setdefault(parts[2], set()).add(str(mapped_device)) + split_layers = [layer for layer, devices in layer_groups.items() if len(devices) > 1] + if split_layers: + logger.warning( + "kld_memory_mode=multi_gpu device_map split decoder layer(s) %s across " + "devices; falling back to low_memory.", + split_layers[:5], + ) + multi_gpu = False + full_memory = False + else: + device_counts: dict[str, int] = {} + for mapped_device in device_map.values(): + device_counts[str(mapped_device)] = device_counts.get(str(mapped_device), 0) + 1 + logger.info( + "kld_memory_mode=multi_gpu device_map: %d entries across %s.", + len(device_map), + device_counts, + ) + model = dispatch_model(model, device_map=device_map).eval() + q_model = dispatch_model(q_model, device_map=device_map).eval() + if not multi_gpu: + model.to("cpu" if offload else device).eval() + q_model = deepcopy(model).to(device).eval() # freeze all parameters for param in q_model.parameters(): @@ -345,6 +601,7 @@ def get_kld_scores( # replace the weights in qmodel with low-bit quantized weights module_numels = {} q_layers: dict[str, torch.nn.Module] = {} + sensitivity_sums: dict[str, float] = {} grad_accum: dict[str, torch.Tensor] = {} def should_include(module, _): @@ -357,44 +614,90 @@ def process_module(module, module_name): module.weight.data = low_w module.weight.requires_grad = True q_layers[module_name] = module - grad_accum[module_name] = torch.zeros_like(module.weight.data, dtype=torch.float32) + sensitivity_sums[module_name] = 0.0 + if full_memory: + grad_accum[module_name] = torch.zeros_like(module.weight.data, dtype=torch.float32) return module replace_matching_submodules( q_model, should_include, process_module, description="Preparing for sensitivity estimation" ) + def empty_device_cache(): + if not 
(device.startswith("cuda") and torch.cuda.is_available()): + return + if multi_gpu: + for i in range(torch.cuda.device_count()): + with torch.cuda.device(i): + torch.cuda.empty_cache() + else: + torch.cuda.empty_cache() + + if offload: + q_model.to("cpu") + empty_device_cache() + + @torch.no_grad() + def get_teacher_logits(inputs: dict[str, torch.Tensor]) -> torch.Tensor: + if offload: + model.to(device) + teacher_logits = model(**inputs).logits + if offload: + model.to("cpu") + empty_device_cache() + return teacher_logits + + @torch.no_grad() + def accumulate_full_grads(): + for module_name, layer in q_layers.items(): + if layer.weight.grad is None: + raise ValueError(f"Missing gradient for {module_name} while estimating KLD sensitivity.") + grad_accum[module_name] += layer.weight.grad.data.detach().float() + + @torch.no_grad() + def accumulate_streaming_sensitivities(): + for module_name, layer in q_layers.items(): + if layer.weight.grad is None: + raise ValueError(f"Missing gradient for {module_name} while estimating KLD sensitivity.") + + source_weight = get_attr(model, module_name).weight.data.to(layer.weight.device) + high_w = high_quantizer.fake_quantize(source_weight) + alignment = layer.weight.grad.data.detach().float() * (layer.weight.data - high_w).float() + sensitivity_sums[module_name] += alignment.sum().item() + del source_weight, high_w + for batch in tqdm(data, desc="Estimating sensitivities"): inputs = {k: v.to(device) for k, v in batch.items()} - with torch.no_grad(): - teacher_logits = model(**inputs).logits + teacher_logits = get_teacher_logits(inputs) + if offload: + q_model.to(device) student_logits = q_model(**inputs).logits loss = kl_div_loss(student_logits, teacher_logits).mean() loss.backward() - # accumulate gradients - for name, layer in q_layers.items(): - grad_accum[name] += layer.weight.grad.data.detach().float() - - # zero grads - q_model.zero_grad() - - @torch.no_grad() - def compute_sensitivity(module_name: str) -> torch.Tensor: - 
grad = grad_accum[module_name] / len(data) # average gradient - - # high-precision quantization baseline - high_w = high_quantizer.fake_quantize(get_attr(model, module_name).weight.data) - - # get sensitivity - param_size_m = module_numels[module_name] / 1e6 - alignment = (grad * (q_layers[module_name].weight.data - high_w)).sum().item() - return alignment / param_size_m + if full_memory: + accumulate_full_grads() + else: + accumulate_streaming_sensitivities() + q_model.zero_grad(set_to_none=True) + if offload: + q_model.to("cpu") + empty_device_cache() + del teacher_logits, student_logits, loss + + if full_memory: + with torch.no_grad(): + for module_name, layer in q_layers.items(): + avg_grad = grad_accum[module_name] / len(data) + high_w = high_quantizer.fake_quantize(get_attr(model, module_name).weight.data) + sensitivity_sums[module_name] = (avg_grad * (layer.weight.data - high_w)).sum().item() * len(data) # negative sensitivity because lower is more sensitive - return module_numels, {name: -compute_sensitivity(name) for name in q_layers} + return module_numels, { + name: -(sensitivity_sums[name] / len(data)) / (module_numels[name] / 1e6) for name in q_layers + } @staticmethod def compute_snr(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> float: diff --git a/test/passes/pytorch/test_selective_mixed_precision.py b/test/passes/pytorch/test_selective_mixed_precision.py index db2570a3cf..a9bb78db49 100644 --- a/test/passes/pytorch/test_selective_mixed_precision.py +++ b/test/passes/pytorch/test_selective_mixed_precision.py @@ -2,17 +2,150 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import importlib.util +import sys +from copy import deepcopy +from types import ModuleType, SimpleNamespace + import pytest import torch from transformers import LlamaConfig, LlamaForCausalLM +from olive.common.hf.wrapper import ModelWrapper +from olive.common.quant.utils import WeightQuantizer from olive.constants import PrecisionBits from olive.model import HfModelHandler from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.pytorch import selective_mixed_precision as smp_module +from olive.passes.pytorch.quant_utils import _quant_config_rank, get_qkv_quantization_groups, prepare_model from olive.passes.pytorch.selective_mixed_precision import SelectiveMixedPrecision +from olive.passes.pytorch.train_utils import kl_div_loss from test.utils import get_tiny_phi3 +class KldGradientTestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Linear(4, 4, bias=False) + self.out = torch.nn.Linear(4, 4, bias=False) + self.gradient_checkpointing_enabled = False + + with torch.no_grad(): + self.proj.weight.copy_( + torch.tensor( + [ + [0.10, -0.20, 0.30, -0.40], + [0.50, 0.60, -0.70, -0.80], + [-0.15, 0.25, -0.35, 0.45], + [0.55, -0.65, 0.75, -0.85], + ] + ) + ) + self.out.weight.copy_( + torch.tensor( + [ + [0.35, -0.45, 0.55, -0.65], + [-0.75, 0.85, -0.95, 1.05], + [0.12, 0.22, -0.32, -0.42], + [-0.52, 0.62, 0.72, -0.82], + ] + ) + ) + + def gradient_checkpointing_enable(self): + self.gradient_checkpointing_enabled = True + + def forward(self, tokens): + hidden_states = torch.tanh(self.proj(tokens)) + return SimpleNamespace(logits=self.out(hidden_states)) + + +def get_kld_gradient_test_data(): + return [ + { + "tokens": torch.tensor( + [ + [[0.20, -0.70, 1.10, 0.50], [0.30, 0.40, -0.90, 0.80]], + [[-0.60, 0.10, 0.70, -0.20], [0.90, -0.30, 0.20, -0.50]], + ] + ) + }, + { + "tokens": torch.tensor( + [ + [[-0.40, 0.60, -0.10, 0.30], [0.50, 
-0.80, 0.40, 0.20]], + [[0.70, 0.20, -0.60, 0.10], [-0.30, 0.90, -0.50, -0.70]], + ] + ) + }, + ] + + +def get_kld_gradient_quantizers(): + return WeightQuantizer(bits=4, group_size=4, symmetric=False), WeightQuantizer(bits=8, group_size=4, symmetric=True) + + +def get_legacy_kld_scores(model, data, quantizer, high_quantizer, device): + model.to(device).eval() + quantized_model = deepcopy(model).to(device).eval() + + for parameter in quantized_model.parameters(): + parameter.requires_grad = False + quantized_model.gradient_checkpointing_enable() + + module_numels = {} + quantized_layers = {} + grad_accum = {} + + with torch.no_grad(): + for module_name, module in quantized_model.named_modules(): + if not isinstance(module, torch.nn.Linear): + continue + + module_numels[module_name] = module.weight.numel() + module.weight.data = quantizer.fake_quantize(module.weight.data) + module.weight.requires_grad = True + quantized_layers[module_name] = module + grad_accum[module_name] = torch.zeros_like(module.weight.data, dtype=torch.float32) + + for batch in data: + inputs = {key: value.to(device) for key, value in batch.items()} + + with torch.no_grad(): + teacher_logits = model(**inputs).logits + + student_logits = quantized_model(**inputs).logits + loss = kl_div_loss(student_logits, teacher_logits).mean() + loss.backward() + + for module_name, layer in quantized_layers.items(): + grad_accum[module_name] += layer.weight.grad.data.detach().float() + quantized_model.zero_grad() + + scores = {} + with torch.no_grad(): + for module_name, layer in quantized_layers.items(): + grad = grad_accum[module_name] / len(data) + high_weight = high_quantizer.fake_quantize(model.get_submodule(module_name).weight.data) + param_size_m = module_numels[module_name] / 1e6 + scores[module_name] = -((grad * (layer.weight.data - high_weight)).sum().item() / param_size_m) + + return module_numels, scores + + +def patch_kld_calibration_data(monkeypatch, data): + monkeypatch.setattr( + 
"olive.passes.pytorch.selective_mixed_precision.get_calibration_dataset", + lambda *_args, **_kwargs: data, + ) + + +def assert_scores_close(actual_scores, expected_scores): + assert actual_scores.keys() == expected_scores.keys() + for module_name, expected_score in expected_scores.items(): + assert actual_scores[module_name] == pytest.approx(expected_score, rel=1e-6, abs=1e-6) + + @pytest.fixture(name="input_model", scope="module") def input_model_fixture(tmp_path_factory): save_path = tmp_path_factory.mktemp("selective-mixed-precision-test") @@ -39,7 +172,12 @@ def input_model_fixture(tmp_path_factory): ], ) def test_selective_mixed_precision_k_quant(algorithm, expected_layer_indices, include_qkv, input_model, tmp_path): - """Test SelectiveMixedPrecision pass with different algorithms.""" + """End-to-end: rule-based k_quant_* algorithms write the expected mixed_precision_info. + + Verifies that each k_quant variant promotes the correct subset of layers (first 1/8, + every 3rd, last 1/8) and that ``k_quant_mixed`` additionally promotes Q/K/V together, + while ``k_quant_last`` (lm_head only) leaves all transformer layers untouched. + """ config = {"algorithm": algorithm} p = create_pass_from_dict(SelectiveMixedPrecision, config, disable_search=True) @@ -71,8 +209,510 @@ def test_selective_mixed_precision_k_quant(algorithm, expected_layer_indices, in assert output_model.model_attributes["mixed_precision_info"] == expected_mp_info +def test_selective_mixed_precision_scored_keeps_qkv_same_precision(input_model): + """Score-based selection must promote Q/K/V as a single group (separate-projection model). + + Even though only k_proj has the low/sensitive score, q_proj/v_proj must also be promoted + so the three attention input projections share the same precision (required by fused QKV + kernels downstream). 
+ """ + model_wrapper = ModelWrapper.from_model(input_model.load_model()) + q_proj = "model.layers.0.self_attn.q_proj" + k_proj = "model.layers.0.self_attn.k_proj" + v_proj = "model.layers.0.self_attn.v_proj" + down_proj = "model.layers.0.mlp.down_proj" + module_numels = {q_proj: 1, k_proj: 1, v_proj: 1, down_proj: 100} + module_scores = {q_proj: 10.0, k_proj: 1.0, v_proj: 10.0, down_proj: 5.0} + + overrides, high_precision_numels = SelectiveMixedPrecision.get_overrides_from_scores( + model_wrapper, + module_numels, + module_scores, + ratio=0.98, + high_override_config={"bits": PrecisionBits.BITS8}, + ) + + assert overrides == { + q_proj: {"bits": PrecisionBits.BITS8}, + k_proj: {"bits": PrecisionBits.BITS8}, + v_proj: {"bits": PrecisionBits.BITS8}, + } + assert high_precision_numels == 3 + + +def test_selective_mixed_precision_scored_ignores_single_packed_qkv(): + """A packed single qkv_proj (phi3 before unpacking) is not treated as a QKV group. + + Grouping only applies when Q/K/V are distinct modules; a fused single projection has + nothing to co-promote and must be skipped. + """ + model_wrapper = ModelWrapper.from_model(get_tiny_phi3().load_model()) + + assert not get_qkv_quantization_groups(model_wrapper, {"model.layers.0.self_attn.qkv_proj"}) + + +def test_selective_mixed_precision_scored_groups_unpacked_qkv(): + """After ``maybe_unpack_qkv()``, the unpacked Q/K/V submodules form a single group. + + Confirms that phi3-style models still get correct QKV co-promotion once the packed + qkv_proj has been split into ``qkv_proj.{q,k,v}_proj`` children. 
+ """ + model_wrapper = ModelWrapper.from_model(get_tiny_phi3().load_model()) + model_wrapper.maybe_unpack_qkv() + q_proj = "model.layers.0.self_attn.qkv_proj.q_proj" + k_proj = "model.layers.0.self_attn.qkv_proj.k_proj" + v_proj = "model.layers.0.self_attn.qkv_proj.v_proj" + down_proj = "model.layers.0.mlp.down_proj" + module_numels = {q_proj: 1, k_proj: 1, v_proj: 1, down_proj: 100} + module_scores = {q_proj: 10.0, k_proj: 1.0, v_proj: 10.0, down_proj: 5.0} + + overrides, high_precision_numels = SelectiveMixedPrecision.get_overrides_from_scores( + model_wrapper, + module_numels, + module_scores, + ratio=0.98, + high_override_config={"bits": PrecisionBits.BITS8}, + ) + + assert overrides == { + q_proj: {"bits": PrecisionBits.BITS8}, + k_proj: {"bits": PrecisionBits.BITS8}, + v_proj: {"bits": PrecisionBits.BITS8}, + } + assert high_precision_numels == 3 + + +def test_quant_config_promotes_user_override_conflicts_for_qkv(input_model): + """``normalize_qkv_quant_config`` promotes the most-precise config across the QKV group. + + When k_proj/v_proj are int8/sym/g16 but q_proj is left at int4, all three must end up + int8/sym/g16 (highest bits wins) so the fused kernel sees a single shared config. 
+ """ + model = HfModelHandler( + input_model.model_path, + model_attributes={ + "mixed_precision_info": { + "default": {"bits": PrecisionBits.BITS4, "group_size": 16, "symmetric": False}, + "overrides": { + "model.layers.0.self_attn.k_proj": { + "bits": PrecisionBits.BITS8, + "group_size": 16, + "symmetric": True, + }, + "model.layers.0.self_attn.v_proj": { + "bits": PrecisionBits.BITS8, + "group_size": 16, + "symmetric": True, + }, + }, + } + }, + ) + config = SimpleNamespace( + bits=PrecisionBits.BITS4, + sym=False, + group_size=16, + lm_head=False, + overrides={"model.layers.0.self_attn.q_proj": {"bits": PrecisionBits.BITS4}}, + ) + + wrapper, qcfg, _ = prepare_model(model, config) + + qkv_qargs = [qcfg.get_qlinear_init_args(f"model.layers.0.self_attn.{name}_proj") for name in ["q", "k", "v"]] + assert qkv_qargs == [ + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + ] + assert [ + getattr(wrapper.model.model.layers[0].self_attn, f"{name}_proj").quant_info.quantizer.bits + for name in ["q", "k", "v"] + ] == [PrecisionBits.BITS8, PrecisionBits.BITS8, PrecisionBits.BITS8] + + +def test_quant_config_rank_prefers_bits_then_smaller_positive_group_size(): + """Unit test for ``_quant_config_rank`` ordering used to promote QKV groups. + + Higher ``bits`` wins; among equal bits, smaller positive ``group_size`` wins; per-channel + wins over per-tensor; ``symmetric`` does not affect the rank. 
+ """ + symmetric_qargs = {"bits": PrecisionBits.BITS4, "group_size": 16, "symmetric": True} + asymmetric_qargs = {"bits": PrecisionBits.BITS4, "group_size": 16, "symmetric": False} + group_size_qargs = [ + {"bits": PrecisionBits.BITS4, "group_size": 128, "symmetric": True}, + {"bits": PrecisionBits.BITS4, "group_size": 32, "symmetric": True}, + {"bits": PrecisionBits.BITS4, "group_size": -1, "symmetric": True}, + {"bits": PrecisionBits.BITS4, "group_size": 0, "symmetric": True}, + ] + higher_bit_qargs = {"bits": PrecisionBits.BITS8, "group_size": 128, "symmetric": True} + + assert _quant_config_rank(symmetric_qargs) == _quant_config_rank(asymmetric_qargs) + assert max(group_size_qargs, key=_quant_config_rank) == group_size_qargs[1] + assert max(group_size_qargs[2:], key=_quant_config_rank) == group_size_qargs[2] + assert max([*group_size_qargs, higher_bit_qargs], key=_quant_config_rank) == higher_bit_qargs + + +def test_quant_config_ignores_excluded_qkv_overrides_when_normalizing(input_model): + """With ``exclude_attn_inputs=True``, excluded Q/K do not pull V into a higher precision. + + Q and K are excluded from quantization, so their high-bit overrides must be dropped and + V must keep the default int4 config (no group promotion across excluded members). 
+ """ + model = HfModelHandler( + input_model.model_path, + model_attributes={ + "mixed_precision_info": { + "default": {"bits": PrecisionBits.BITS4, "group_size": 16, "symmetric": False}, + "overrides": { + "model.layers.0.self_attn.q_proj": { + "bits": PrecisionBits.BITS8, + "group_size": 16, + "symmetric": True, + }, + "model.layers.0.self_attn.k_proj": { + "bits": PrecisionBits.BITS8, + "group_size": 16, + "symmetric": True, + }, + }, + } + }, + ) + config = SimpleNamespace( + bits=PrecisionBits.BITS4, + sym=False, + group_size=16, + lm_head=False, + overrides=None, + ) + + wrapper, qcfg, _ = prepare_model(model, config, exclude_attn_inputs=True) + attention = wrapper.model.model.layers[0].self_attn + + assert not hasattr(attention.q_proj, "quant_info") + assert not hasattr(attention.k_proj, "quant_info") + assert qcfg.get_qlinear_init_args("model.layers.0.self_attn.v_proj") == { + "bits": PrecisionBits.BITS4, + "symmetric": False, + "group_size": 16, + } + assert attention.v_proj.quant_info.quantizer.bits == PrecisionBits.BITS4 + assert "model.layers.0.self_attn.q_proj" not in qcfg.overrides + assert "model.layers.0.self_attn.k_proj" not in qcfg.overrides + + +def test_selective_mixed_precision_kld_low_memory_matches_legacy_grad_accum(monkeypatch): + """LOW_MEMORY KLD scoring is numerically equivalent to the legacy gradient-accumulation path. + + Guards the chunked/low-memory implementation against drift from the reference scores. 
+ """ + data = get_kld_gradient_test_data() + patch_kld_calibration_data(monkeypatch, data) + quantizer, high_quantizer = get_kld_gradient_quantizers() + model = KldGradientTestModel() + + expected_numels, expected_scores = get_legacy_kld_scores( + deepcopy(model), data, quantizer, high_quantizer, device="cpu" + ) + actual_numels, actual_scores = SelectiveMixedPrecision.get_kld_scores( + None, + deepcopy(model), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cpu", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY, + ) + + assert actual_numels == expected_numels + assert_scores_close(actual_scores, expected_scores) + + +def test_selective_mixed_precision_kld_full_matches_legacy_grad_accum(monkeypatch): + """FULL KLD scoring is numerically equivalent to the legacy gradient-accumulation path. + + Pins the fast/full path to the same scores as the reference implementation. + """ + data = get_kld_gradient_test_data() + patch_kld_calibration_data(monkeypatch, data) + quantizer, high_quantizer = get_kld_gradient_quantizers() + model = KldGradientTestModel() + + expected_numels, expected_scores = get_legacy_kld_scores( + deepcopy(model), data, quantizer, high_quantizer, device="cpu" + ) + actual_numels, actual_scores = SelectiveMixedPrecision.get_kld_scores( + None, + deepcopy(model), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cpu", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.FULL, + ) + + assert actual_numels == expected_numels + assert_scores_close(actual_scores, expected_scores) + + +def test_selective_mixed_precision_kld_auto_resolves_to_full_on_cpu(): + """AUTO mode on CPU falls back to FULL (no GPU memory budget to worry about).""" + model = KldGradientTestModel() + + resolved = SelectiveMixedPrecision.resolve_kld_memory_mode( + model, device="cpu", kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.AUTO + ) + + assert resolved == 
SelectiveMixedPrecision.KldMemoryMode.FULL + + +def test_selective_mixed_precision_kld_auto_passthrough_when_not_auto(): + """Explicit modes (FULL/MULTI_GPU/LOW_MEMORY/OFFLOAD) are returned unchanged by the resolver. + + Only AUTO triggers heuristic selection; user-pinned modes must pass through. + """ + model = KldGradientTestModel() + + for mode in ( + SelectiveMixedPrecision.KldMemoryMode.FULL, + SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU, + SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY, + SelectiveMixedPrecision.KldMemoryMode.OFFLOAD, + ): + assert SelectiveMixedPrecision.resolve_kld_memory_mode(model, device="cuda", kld_memory_mode=mode) == mode + + +def test_selective_mixed_precision_kld_auto_falls_back_to_offload_when_cuda_memory_query_fails(monkeypatch): + """AUTO mode chooses OFFLOAD when CUDA free memory cannot be queried safely.""" + model = KldGradientTestModel() + + def raise_cuda_error(_device): + raise RuntimeError("CUDA memory query failed") + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + monkeypatch.setattr(torch.cuda, "mem_get_info", raise_cuda_error) + + resolved = SelectiveMixedPrecision.resolve_kld_memory_mode( + model, device="cuda", kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.AUTO + ) + + assert resolved == SelectiveMixedPrecision.KldMemoryMode.OFFLOAD + + +def test_selective_mixed_precision_kld_auto_picks_multi_gpu_when_full_fits_across_gpus(monkeypatch): + """AUTO picks MULTI_GPU when FULL does not fit on one GPU but fits across all visible GPUs. + + Simulates two GPUs each with a small per-device budget so that FULL fails the single-GPU + check but the combined budget across both GPUs is sufficient. + """ + model = KldGradientTestModel() + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 2) + # The tiny test model needs ~491 bytes for FULL and ~338 bytes for LOW_MEMORY.
400 bytes per GPU + # (single_budget=340) fails the FULL single-GPU check but the combined two-GPU budget (680) + # accommodates FULL, so MULTI_GPU should win. + monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (400, 400)) + + resolved = SelectiveMixedPrecision.resolve_kld_memory_mode( + model, device="cuda", kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.AUTO + ) + + assert resolved == SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU + + +def test_selective_mixed_precision_kld_auto_does_not_pick_multi_gpu_on_single_gpu(monkeypatch): + """AUTO never picks MULTI_GPU when only one CUDA device is visible, even if FULL doesn't fit. + + With a single GPU the ladder must skip MULTI_GPU and pick LOW_MEMORY or OFFLOAD instead. + """ + model = KldGradientTestModel() + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + # 400 bytes single-GPU budget (340 after safety factor): FULL won't fit, MULTI_GPU is gated + # out by gpu_count==1, so LOW_MEMORY/OFFLOAD must win. 
+ monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (400, 400)) + + resolved = SelectiveMixedPrecision.resolve_kld_memory_mode( + model, device="cuda", kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.AUTO + ) + + assert resolved != SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU + + +def test_selective_mixed_precision_kld_explicit_multi_gpu_falls_back_without_multiple_cuda_devices(monkeypatch): + """Explicit MULTI_GPU falls back before trying CUDA when fewer than two CUDA devices are visible.""" + data = get_kld_gradient_test_data() + patch_kld_calibration_data(monkeypatch, data) + quantizer, high_quantizer = get_kld_gradient_quantizers() + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + warnings = [] + monkeypatch.setattr( + smp_module.logger, + "warning", + lambda message, *args: warnings.append(message % args if args else message), + ) + + module_numels, module_scores = SelectiveMixedPrecision.get_kld_scores( + None, + KldGradientTestModel(), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cpu", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU, + ) + + assert module_numels + assert module_scores + assert any("requires at least two visible CUDA devices" in warning for warning in warnings) + + +def test_selective_mixed_precision_kld_multi_gpu_uses_constrained_device_map(monkeypatch): + """MULTI_GPU passes constrained per-GPU max_memory to Accelerate device-map inference. + + This prevents Accelerate from placing one full model copy on a GPU without leaving room for + the second copy and the fp32 gradient accumulator used by the FULL algorithm. 
+ """ + captured = {} + fake_accelerate = ModuleType("accelerate") + + def fake_infer_auto_device_map(_model, max_memory, no_split_module_classes): + captured["max_memory"] = max_memory + captured["no_split_module_classes"] = no_split_module_classes + return {"": 0} + + def fake_dispatch_model(model, device_map): + captured.setdefault("device_maps", []).append(device_map) + return model + + fake_accelerate.infer_auto_device_map = fake_infer_auto_device_map + fake_accelerate.dispatch_model = fake_dispatch_model + monkeypatch.setitem(sys.modules, "accelerate", fake_accelerate) + original_find_spec = importlib.util.find_spec + monkeypatch.setattr( + importlib.util, + "find_spec", + lambda name: object() if name == "accelerate" else original_find_spec(name), + ) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (1000, 1000)) + patch_kld_calibration_data(monkeypatch, []) + monkeypatch.setattr(smp_module, "replace_matching_submodules", lambda *_args, **_kwargs: None) + quantizer, high_quantizer = get_kld_gradient_quantizers() + + SelectiveMixedPrecision.get_kld_scores( + None, + KldGradientTestModel(), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cuda", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU, + ) + + assert captured["max_memory"] == {0: 222, 1: 222} + assert all(memory < 850 for memory in captured["max_memory"].values()) + assert captured["device_maps"] == [{"": 0}, {"": 0}] + + +def test_selective_mixed_precision_kld_offload_matches_low_memory(monkeypatch): + """OFFLOAD mode produces the same scores as LOW_MEMORY (only differs in where tensors live). + + Confirms CPU-offloading does not perturb the numerics relative to the in-GPU low-memory path. 
+ """ + data = get_kld_gradient_test_data() + patch_kld_calibration_data(monkeypatch, data) + quantizer, high_quantizer = get_kld_gradient_quantizers() + model = KldGradientTestModel() + + low_memory_numels, low_memory_scores = SelectiveMixedPrecision.get_kld_scores( + None, + deepcopy(model), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cpu", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY, + ) + offload_numels, offload_scores = SelectiveMixedPrecision.get_kld_scores( + None, + deepcopy(model), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cpu", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.OFFLOAD, + ) + + assert offload_numels == low_memory_numels + assert_scores_close(offload_scores, low_memory_scores) + + +def test_selective_mixed_precision_kld_memory_modes_keep_same_mixed_precision_config(input_model, monkeypatch): + """FULL/LOW_MEMORY/OFFLOAD all yield identical ``mixed_precision_info`` given the same scores. + + Memory mode is a runtime-only knob; once scoring is mocked to a fixed result, the produced + config must not depend on which path was taken. Also asserts each mode is actually passed + through to ``get_kld_scores``. 
+ """ + module_numels = { + "model.layers.0.self_attn.q_proj": 1, + "model.layers.0.self_attn.k_proj": 1, + "model.layers.0.self_attn.v_proj": 1, + "model.layers.0.mlp.down_proj": 100, + } + module_scores = { + "model.layers.0.self_attn.q_proj": 10.0, + "model.layers.0.self_attn.k_proj": 1.0, + "model.layers.0.self_attn.v_proj": 10.0, + "model.layers.0.mlp.down_proj": 5.0, + } + seen_memory_modes = [] + + def get_kld_scores(*args): + seen_memory_modes.append(args[-1]) + return module_numels, module_scores + + monkeypatch.setattr(SelectiveMixedPrecision, "get_kld_scores", staticmethod(get_kld_scores)) + model_wrapper = ModelWrapper.from_model(input_model.load_model()) + + modes = [ + SelectiveMixedPrecision.KldMemoryMode.FULL, + SelectiveMixedPrecision.KldMemoryMode.LOW_MEMORY, + SelectiveMixedPrecision.KldMemoryMode.OFFLOAD, + ] + configs = [ + SelectiveMixedPrecision.get_scored_config( + input_model, + model_wrapper, + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + PrecisionBits.BITS4, + 16, + False, + PrecisionBits.BITS8, + 16, + True, + 0.98, + mode, + ) + for mode in modes + ] + + assert configs[0] == configs[1] == configs[2] + assert seen_memory_modes == modes + + @pytest.mark.parametrize("algorithm", ["snr", "snr_relative", "iqe", "iqe_relative", "kld_gradient"]) def test_selective_mixed_precision_scored(algorithm, tmp_path): + """End-to-end: every score-based algorithm produces a valid ``mixed_precision_info``. + + Runs the pass on tiny-phi3 for each scoring algorithm and checks the default/lm_head + sections are populated correctly. ``kld_gradient`` is skipped on CPU (too slow). 
+ """ if algorithm == "kld_gradient" and not torch.cuda.is_available(): pytest.skip("Skipping kld_gradient test as it runs slow on CPU.") From cc52a7e884be9794a69a0504600925201b3b0560 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Fri, 22 May 2026 16:50:23 -0700 Subject: [PATCH 2/3] docs: surface KLD memory modes and QKV grouping in pass docstring --- olive/passes/pytorch/selective_mixed_precision.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/olive/passes/pytorch/selective_mixed_precision.py b/olive/passes/pytorch/selective_mixed_precision.py index ce3f21b48d..437ea2b6bc 100644 --- a/olive/passes/pytorch/selective_mixed_precision.py +++ b/olive/passes/pytorch/selective_mixed_precision.py @@ -48,6 +48,14 @@ class SelectiveMixedPrecision(Pass): - iqe: Inverse of Integer Quantization Error based selection. - iqe_relative: Relative IQE (between low and high precision) based selection. - kld_gradient: KL Divergence gradient based selection. + + For ``kld_gradient`` the peak memory required for KL Divergence scoring can be tuned via + ``kld_memory_mode``, which supports ``auto`` (default; picks based on the model size and free + device memory), ``full``, ``multi_gpu`` (shards the scoring forward across all visible CUDA + devices with ``accelerate``), ``low_memory``, and ``offload``. + + The override map produced by this pass groups Q/K/V projections in the same attention block so + they always share precision, which is required for ModelBuilder's GQA fusion. 
""" class Algorithm(StrEnumBase): From 8e98a928afeff5bfbe09796f81c21ad7d5e21c9e Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Fri, 22 May 2026 18:04:46 -0700 Subject: [PATCH 3/3] Address SMP review feedback --- olive/passes/pytorch/quant_utils.py | 3 + .../pytorch/selective_mixed_precision.py | 13 +- .../pytorch/test_selective_mixed_precision.py | 125 ++++++++++++++++++ 3 files changed, 136 insertions(+), 5 deletions(-) diff --git a/olive/passes/pytorch/quant_utils.py b/olive/passes/pytorch/quant_utils.py index f468d2a221..7063cba2da 100644 --- a/olive/passes/pytorch/quant_utils.py +++ b/olive/passes/pytorch/quant_utils.py @@ -243,6 +243,9 @@ def add_quant_info(module: torch.nn.Module, name: str) -> torch.nn.Module: merged_qcfg_dict["lm_head"] |= qcfg.lm_head merged_qcfg_dict["embeds"] |= qcfg.embeds qcfg = OliveHfQuantizationConfig(**merged_qcfg_dict) + # Re-normalize: the pre-existing overrides we just merged in may violate QKV + # consistency on their own, which can break ModelBuilder's GQA fusion. + qcfg = normalize_qkv_quant_config(wrapper, qcfg, quantizable_attn_input_names) word_embeddings_eligible_for_tieing = ( originally_tied_embeddings diff --git a/olive/passes/pytorch/selective_mixed_precision.py b/olive/passes/pytorch/selective_mixed_precision.py index 437ea2b6bc..d732280183 100644 --- a/olive/passes/pytorch/selective_mixed_precision.py +++ b/olive/passes/pytorch/selective_mixed_precision.py @@ -586,13 +586,16 @@ def get_kld_scores( multi_gpu = False full_memory = False else: - device_counts: dict[str, int] = {} - for mapped_device in device_map.values(): - device_counts[str(mapped_device)] = device_counts.get(str(mapped_device), 0) + 1 + layer_device_counts: dict[str, int] = {} + for devices in layer_groups.values(): + # layer_groups[layer] is a singleton set after coalescing succeeded above. 
+ (dev,) = devices + layer_device_counts[dev] = layer_device_counts.get(dev, 0) + 1 logger.info( - "kld_memory_mode=multi_gpu device_map: %d entries across %s.", + "kld_memory_mode=multi_gpu device_map: %d decoder layers across %s (total %d module entries).", + len(layer_groups), + layer_device_counts, len(device_map), - device_counts, ) model = dispatch_model(model, device_map=device_map).eval() q_model = dispatch_model(q_model, device_map=device_map).eval() diff --git a/test/passes/pytorch/test_selective_mixed_precision.py b/test/passes/pytorch/test_selective_mixed_precision.py index a9bb78db49..270df118ba 100644 --- a/test/passes/pytorch/test_selective_mixed_precision.py +++ b/test/passes/pytorch/test_selective_mixed_precision.py @@ -330,6 +330,56 @@ def test_quant_config_promotes_user_override_conflicts_for_qkv(input_model): ] == [PrecisionBits.BITS8, PrecisionBits.BITS8, PrecisionBits.BITS8] +def test_prepare_model_renormalizes_qkv_after_merging_existing_quant_config(input_model, monkeypatch): + """``prepare_model`` renormalizes QKV after merging a pre-existing ``quantization_config``. + + With ``allow_quantized=True`` (e.g., the RTN path), a model loaded with a + ``quantization_config`` whose ``overrides`` already violate QKV consistency must still + end up with Q/K/V sharing the most-precise config so ModelBuilder's GQA fusion works. + """ + # Pre-existing quant config has only q_proj at 8-bit; k_proj and v_proj are at the default 4-bit. 
+ existing_quantization_config = { + "quant_method": "olive", + "bits": PrecisionBits.BITS4, + "symmetric": False, + "group_size": 16, + "lm_head": False, + "embeds": False, + "overrides": { + "model.layers.0.self_attn.q_proj": { + "bits": PrecisionBits.BITS8, + "symmetric": True, + "group_size": 16, + }, + }, + } + real_get_hf_model_config = HfModelHandler.get_hf_model_config + + def fake_get_hf_model_config(self, exclude_load_keys=None): + cfg = real_get_hf_model_config(self, exclude_load_keys=exclude_load_keys) + cfg.quantization_config = dict(existing_quantization_config) + return cfg + + monkeypatch.setattr(HfModelHandler, "get_hf_model_config", fake_get_hf_model_config) + + config = SimpleNamespace( + bits=PrecisionBits.BITS4, + sym=False, + group_size=16, + lm_head=False, + overrides={}, + ) + + _, qcfg, _ = prepare_model(input_model, config, allow_quantized=True) + + qkv_qargs = [qcfg.get_qlinear_init_args(f"model.layers.0.self_attn.{name}_proj") for name in ["q", "k", "v"]] + assert qkv_qargs == [ + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + {"bits": PrecisionBits.BITS8, "symmetric": True, "group_size": 16}, + ] + + def test_quant_config_rank_prefers_bits_then_smaller_positive_group_size(): """Unit test for ``_quant_config_rank`` ordering used to promote QKV groups. @@ -619,6 +669,81 @@ def fake_dispatch_model(model, device_map): assert captured["device_maps"] == [{"": 0}, {"": 0}] +def test_selective_mixed_precision_kld_multi_gpu_logs_per_layer_device_counts(monkeypatch): + """MULTI_GPU diagnostic log reports decoder-layer counts per device, not raw map entries. + + After coalescing sub-decoder-layer placements, the info log must reflect how many distinct + ``model.layers.N`` decoder layers ended up on each device, plus the total module-entry + count for context. 
+ """ + fake_accelerate = ModuleType("accelerate") + + def fake_infer_auto_device_map(_model, max_memory, no_split_module_classes): + # Three decoder layers split across two GPUs, with a sub-module placement that the + # coalescing pass must pull back onto the layer's primary device. + return { + "model.layers.0": 0, + "model.layers.0.mlp.down_proj": 1, # to be coalesced back to device 0 + "model.layers.1": 0, + "model.layers.2": 1, + "model.embed_tokens": 0, + "lm_head": 1, + } + + def fake_dispatch_model(model, device_map): + return model + + fake_accelerate.infer_auto_device_map = fake_infer_auto_device_map + fake_accelerate.dispatch_model = fake_dispatch_model + monkeypatch.setitem(sys.modules, "accelerate", fake_accelerate) + original_find_spec = importlib.util.find_spec + monkeypatch.setattr( + importlib.util, + "find_spec", + lambda name: object() if name == "accelerate" else original_find_spec(name), + ) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(torch.cuda, "mem_get_info", lambda _i: (1000, 1000)) + patch_kld_calibration_data(monkeypatch, []) + monkeypatch.setattr(smp_module, "replace_matching_submodules", lambda *_args, **_kwargs: None) + quantizer, high_quantizer = get_kld_gradient_quantizers() + + import logging as _logging + + captured_logs: list[str] = [] + + class _ListHandler(_logging.Handler): + def emit(self, record): + captured_logs.append(record.getMessage()) + + handler = _ListHandler(level=_logging.INFO) + smp_module.logger.addHandler(handler) + previous_level = smp_module.logger.level + smp_module.logger.setLevel(_logging.INFO) + try: + SelectiveMixedPrecision.get_kld_scores( + None, + KldGradientTestModel(), + SelectiveMixedPrecision.Algorithm.KLD_GRADIENT, + quantizer, + high_quantizer, + device="cuda", + kld_memory_mode=SelectiveMixedPrecision.KldMemoryMode.MULTI_GPU, + ) + finally: + smp_module.logger.removeHandler(handler) + 
smp_module.logger.setLevel(previous_level) + + device_map_logs = [msg for msg in captured_logs if "kld_memory_mode=multi_gpu device_map" in msg] + assert device_map_logs, f"expected an info log describing the multi_gpu device map; got {captured_logs!r}" + log = device_map_logs[-1] + # 3 decoder layers total: layers 0 and 1 on device 0, layer 2 on device 1. + assert "3 decoder layers" in log + assert "'0': 2" in log or "0: 2" in log + assert "'1': 1" in log or "1: 1" in log + + def test_selective_mixed_precision_kld_offload_matches_low_memory(monkeypatch): """OFFLOAD mode produces the same scores as LOW_MEMORY (only differs in where tensors live).