|
52 | 52 | promote_nvfp4_static_quantizers, |
53 | 53 | quantizer_attr_names, |
54 | 54 | reduce_amax, |
55 | | - weight_attr_names, |
56 | 55 | ) |
57 | 56 | from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper |
58 | 57 |
|
|
66 | 65 | "svdquant", |
67 | 66 | ] |
68 | 67 |
|
| 68 | + |
| 69 | +def _is_calibrated_nvfp4_static(q) -> bool: |
| 70 | + """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" |
| 71 | + return ( |
| 72 | + isinstance(q, TensorQuantizer) |
| 73 | + and not q._disabled |
| 74 | + and q.is_nvfp4_static |
| 75 | + and getattr(q, "_amax", None) is not None |
| 76 | + ) |
| 77 | + |
| 78 | + |
| 79 | +def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: |
| 80 | + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" |
| 81 | + # Inline: layer_utils → quant_utils → model_calib cycle. |
| 82 | + from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS |
| 83 | + |
| 84 | + # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant |
| 85 | + # in export). Single source for the gate/up half avoids parallel lists. |
| 86 | + patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) |
| 87 | + groups: list[list[nn.Module]] = [] |
| 88 | + wq_attr = quantizer_attr_names("weight").weight_quantizer |
| 89 | + for parent in model.modules(): |
| 90 | + for sibling_names in patterns: |
| 91 | + members = [ |
| 92 | + child |
| 93 | + for child in (getattr(parent, n, None) for n in sibling_names) |
| 94 | + if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) |
| 95 | + ] |
| 96 | + if len(members) >= 2: |
| 97 | + groups.append(members) |
| 98 | + return groups |
| 99 | + |
| 100 | + |
| 101 | +@torch.no_grad() |
| 102 | +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: |
| 103 | + """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. |
| 104 | +
|
| 105 | + Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE |
| 106 | + doesn't drop them and break the gate==up ``weight_scale_2`` export invariant. |
| 107 | + Activation quantizers on those modules remain uncalibrated; emits a warning. |
| 108 | + """ |
| 109 | + name_to_module = dict(model.named_modules()) |
| 110 | + n = 0 |
| 111 | + for module in name_to_module.values(): |
| 112 | + if not isinstance(module, QuantModule): |
| 113 | + continue |
| 114 | + with enable_weight_access_and_writeback(module, model, name_to_module): |
| 115 | + for weight, q in module.iter_weights_for_calibration(): |
| 116 | + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: |
| 117 | + continue |
| 118 | + if q._calibrator is None: |
| 119 | + continue |
| 120 | + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): |
| 121 | + continue |
| 122 | + q.disable_quant() |
| 123 | + q.enable_calib() |
| 124 | + q(weight) |
| 125 | + if q._calibrator.compute_amax() is not None: |
| 126 | + q.load_calib_amax() |
| 127 | + q.enable_quant() |
| 128 | + q.disable_calib() |
| 129 | + if hasattr(q._calibrator, "reset"): |
| 130 | + q._calibrator.reset() |
| 131 | + n += 1 |
| 132 | + if n > 0: |
| 133 | + warnings.warn( |
| 134 | + f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; " |
| 135 | + f"their activation quantizers (if any) remain uncalibrated. " |
| 136 | + f"Increase calib size/seq len to activate all experts.", |
| 137 | + stacklevel=2, |
| 138 | + ) |
| 139 | + return n |
| 140 | + |
| 141 | + |
| 142 | +@torch.no_grad() |
| 143 | +def _sync_grouped_weight_global_amax(model: nn.Module) -> int: |
| 144 | + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. |
| 145 | +
|
| 146 | + Run after ``max_calibrate``. Sibling discovery is name-based via |
| 147 | + ``_collect_grouped_linears``; non-matching architectures (wqkv, fused |
| 148 | + qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently |
| 149 | + fall back to per-module global_amax. Fused-experts containers already |
| 150 | + share a single quantizer across gate/up halves and need no sync. |
| 151 | + """ |
| 152 | + # quant_utils imports back from this module; top-level would cycle. |
| 153 | + from modelopt.torch.export.quant_utils import preprocess_linear_fusion |
| 154 | + |
| 155 | + wq_attr = quantizer_attr_names("weight").weight_quantizer |
| 156 | + n_groups = 0 |
| 157 | + for group in _collect_grouped_linears(model): |
| 158 | + for child in group: |
| 159 | + wq = getattr(child, wq_attr) |
| 160 | + if not isinstance(wq, NVFP4StaticQuantizer): |
| 161 | + NVFP4StaticQuantizer.from_tensor_quantizer( |
| 162 | + wq, global_amax=reduce_amax(wq._amax, axis=None) |
| 163 | + ) |
| 164 | + preprocess_linear_fusion(group) |
| 165 | + n_groups += 1 |
| 166 | + return n_groups |
| 167 | + |
| 168 | + |
69 | 169 | CalibratorFactory: TypeAlias = Callable[ |
70 | 170 | [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator |
71 | 171 | ] |
@@ -346,32 +446,23 @@ def mse_calibrate( |
346 | 446 | See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for |
347 | 447 | details on the remaining arguments. |
348 | 448 | """ |
349 | | - # Step 1: First get initial amax using max calibration |
| 449 | + # Step 1: max calibrate, bootstrap dead-expert weight quantizers, |
| 450 | + # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid. |
350 | 451 | max_calibrate(model, forward_loop, distributed_sync) |
| 452 | + _bootstrap_uncalibrated_weight_quantizers(model) |
| 453 | + _sync_grouped_weight_global_amax(model) |
351 | 454 |
|
352 | | - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers |
353 | | - # and identify weight quantizers |
354 | | - weight_quantizers = [] |
355 | | - seen_modules = set() |
356 | | - |
| 455 | + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. |
357 | 456 | for name, module in list(model.named_modules()): |
358 | 457 | if isinstance(module, TensorQuantizer) and not module._disabled: |
359 | 458 | if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): |
360 | | - # Get the initial amax from max calibration |
361 | 459 | initial_amax = module._amax.clone().detach() |
| 460 | + is_nvfp4_static = module.is_nvfp4_static |
362 | 461 |
|
363 | | - is_nvfp4_static = ( |
364 | | - module.is_static_block_quant |
365 | | - and module._num_bits == (2, 1) |
366 | | - and module._block_sizes is not None |
367 | | - and module._block_sizes.get("scale_bits") == (4, 3) |
368 | | - ) |
369 | | - |
370 | | - if is_nvfp4_static: |
371 | | - # Compute and set global_amax |
| 462 | + # Promote standalone NVFP4-static quantizers; grouped siblings |
| 463 | + # already promoted by _sync_grouped_weight_global_amax above. |
| 464 | + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): |
372 | 465 | global_amax = reduce_amax(initial_amax, axis=None) |
373 | | - |
374 | | - # Convert to NVFP4StaticQuantizer in-place |
375 | 466 | NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) |
376 | 467 |
|
377 | 468 | if fp8_scale_sweep: |
@@ -412,52 +503,48 @@ def mse_calibrate( |
412 | 503 | quant_func=partial(_mse_quant_func, quantizer=module), |
413 | 504 | ) |
414 | 505 |
|
415 | | - # Identify weight quantizers by checking if they have corresponding weight parameters |
| 506 | + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. |
416 | 507 | name_to_module = dict(model.named_modules()) |
| 508 | + seen_modules: set[int] = set() |
| 509 | + pbar = tqdm(desc="MSE weight calibration") |
| 510 | + n_calibrated = 0 |
417 | 511 | for parent_module in name_to_module.values(): |
418 | | - if parent_module in seen_modules: |
| 512 | + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): |
419 | 513 | continue |
420 | | - for weight_name in weight_attr_names(parent_module): |
421 | | - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer |
422 | | - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) |
423 | | - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: |
424 | | - if getattr(weight_quantizer, "_calibrator", None) is not None: |
425 | | - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) |
426 | | - seen_modules.add(parent_module) |
427 | | - |
428 | | - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation |
429 | | - # This prevents massive memory accumulation seen in large models |
430 | | - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( |
431 | | - tqdm(weight_quantizers, desc="MSE weight calibration") |
432 | | - ): |
433 | | - # Enable calibration mode for the weight quantizer |
434 | | - weight_quantizer.disable_quant() |
435 | | - weight_quantizer.enable_calib() |
| 514 | + seen_modules.add(id(parent_module)) |
436 | 515 | with enable_weight_access_and_writeback(parent_module, model, name_to_module): |
437 | | - weight = getattr(parent_module, weight_name) |
438 | | - weight_quantizer(weight) |
| 516 | + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): |
| 517 | + if not ( |
| 518 | + isinstance(weight_quantizer, TensorQuantizer) |
| 519 | + and weight_quantizer.is_enabled |
| 520 | + and getattr(weight_quantizer, "_calibrator", None) is not None |
| 521 | + ): |
| 522 | + continue |
| 523 | + weight_quantizer.disable_quant() |
| 524 | + weight_quantizer.enable_calib() |
| 525 | + weight_quantizer(weight) |
439 | 526 |
|
440 | | - # IMMEDIATELY compute amax and reset calibrator to free memory |
441 | | - cal = getattr(weight_quantizer, "_calibrator", None) |
442 | | - if cal is not None and cal.compute_amax() is not None: |
443 | | - weight_quantizer.load_calib_amax() |
| 527 | + cal = weight_quantizer._calibrator |
| 528 | + if cal.compute_amax() is not None: |
| 529 | + weight_quantizer.load_calib_amax() |
444 | 530 |
|
445 | | - weight_quantizer.enable_quant() |
446 | | - weight_quantizer.disable_calib() |
| 531 | + weight_quantizer.enable_quant() |
| 532 | + weight_quantizer.disable_calib() |
447 | 533 |
|
448 | | - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete |
449 | | - # This is critical for multi-GPU setups where tensors may be on different devices |
450 | | - if torch.cuda.is_available(): |
451 | | - for dev_id in range(torch.cuda.device_count()): |
452 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 534 | + if torch.cuda.is_available(): |
| 535 | + for dev_id in range(torch.cuda.device_count()): |
| 536 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
453 | 537 |
|
454 | | - if cal is not None and hasattr(cal, "reset"): |
455 | | - cal.reset() |
| 538 | + if hasattr(cal, "reset"): |
| 539 | + cal.reset() |
456 | 540 |
|
457 | | - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): |
458 | | - for dev_id in range(torch.cuda.device_count()): |
459 | | - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
460 | | - torch.cuda.empty_cache() |
| 541 | + pbar.update(1) |
| 542 | + n_calibrated += 1 |
| 543 | + if n_calibrated % 10 == 0 and torch.cuda.is_available(): |
| 544 | + for dev_id in range(torch.cuda.device_count()): |
| 545 | + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) |
| 546 | + torch.cuda.empty_cache() |
| 547 | + pbar.close() |
461 | 548 |
|
462 | 549 | if torch.cuda.is_available(): |
463 | 550 | for dev_id in range(torch.cuda.device_count()): |
@@ -612,6 +699,8 @@ def forward(self, input, *args, **kwargs): |
612 | 699 | print_rank_0("local_hessian: Running max calibration for all quantizers...") |
613 | 700 | max_calibrate(model, forward_loop, distributed_sync) |
614 | 701 |
|
| 702 | + _sync_grouped_weight_global_amax(model) |
| 703 | + |
615 | 704 | # Setup helpers for all quantized linear modules |
616 | 705 | name_to_module = dict(model.named_modules()) |
617 | 706 | weight_quantizers_info = [] |
@@ -666,14 +755,9 @@ def quant_func(x, amax, quantizer=weight_quantizer): |
666 | 755 |
|
667 | 756 | return xq |
668 | 757 |
|
669 | | - is_nvfp4_static = ( |
670 | | - weight_quantizer.is_static_block_quant |
671 | | - and weight_quantizer._num_bits == (2, 1) |
672 | | - and weight_quantizer._block_sizes is not None |
673 | | - and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) |
674 | | - ) |
| 758 | + is_nvfp4_static = weight_quantizer.is_nvfp4_static |
675 | 759 |
|
676 | | - if is_nvfp4_static: |
| 760 | + if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): |
677 | 761 | global_amax = reduce_amax(initial_amax, axis=None) |
678 | 762 | NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) |
679 | 763 |
|
|
0 commit comments