|
15 | 15 |
|
16 | 16 | """Unit tests for layerwise_calibrate and LayerActivationCollector.""" |
17 | 17 |
|
| 18 | +import copy |
18 | 19 | from collections import deque |
19 | 20 |
|
20 | 21 | import pytest |
21 | 22 | import torch |
22 | 23 | import torch.nn as nn |
23 | 24 |
|
| 25 | +import modelopt.torch.quantization as mtq |
24 | 26 | from modelopt.torch.quantization.model_calib import layerwise_calibrate |
| 27 | +from modelopt.torch.quantization.nn import TensorQuantizer |
25 | 28 | from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer |
26 | 29 |
|
27 | 30 |
|
@@ -593,3 +596,131 @@ def forward_loop(m): |
593 | 596 | for i, orig in enumerate(originals): |
594 | 597 | assert model.layers[i] is orig, f"Layer {i} not restored to original after cleanup" |
595 | 598 | assert not hasattr(orig, "_layerwise_calib"), f"Layer {i} still has _layerwise_calib" |
| 599 | + |
| 600 | + |
| 601 | +# --------------------------------------------------------------------------- |
| 602 | +# End-to-end mtq.quantize(..., algorithm={"layerwise": True}) per PTQ algorithm |
| 603 | +# --------------------------------------------------------------------------- |
| 604 | + |
| 605 | + |
def _int8_layerwise_config(algorithm: dict) -> dict:
    """Return the shipped INT8 SmoothQuant config with *algorithm* swapped in.

    Building on a real shipped config guarantees the same include/exclude rules
    production PTQ relies on, so algorithm dispatch matches real usage.
    """
    # Deep-copy first so mutating the result never touches the shared module
    # constant, then overlay the caller's algorithm block.
    return {**copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG), "algorithm": algorithm}
| 615 | + |
| 616 | + |
def _awq_layerwise_config() -> dict:
    """INT4 weight-only AWQ config sized for the _DecoderBlock test model.

    Returns a deep copy of ``mtq.INT4_AWQ_CFG`` with the weight quantizer's
    AWQ block size shrunk to 8 (so it divides the test model's dim=16 hidden
    size) and the algorithm block set to layerwise awq_lite.
    """
    cfg = copy.deepcopy(mtq.INT4_AWQ_CFG)
    # Resize AWQ block to fit dim=16 hidden. ``quant_cfg`` is a dict keyed by
    # wildcard quantizer-name patterns, so index the pattern directly —
    # iterating the dict yields plain string keys (no ``.get``), and the old
    # ``entry.setdefault("cfg", {})`` buried the override under a nonexistent
    # "cfg" sub-key instead of setting ``block_sizes``.
    # NOTE(review): assumes modelopt's flat quant_cfg schema — confirm against
    # the installed mtq.config version.
    cfg["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, "type": "static"}
    cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 0.5, "layerwise": True}
    return cfg
| 626 | + |
| 627 | + |
def _svdquant_layerwise_config() -> dict:
    """SVDQuant config sized for the _DecoderBlock test model.

    Returns a deep copy of ``mtq.INT4_AWQ_CFG`` with the weight quantizer's
    block size shrunk to 8 (divides the dim=16 hidden size) and the algorithm
    block set to layerwise svdquant.
    """
    cfg = copy.deepcopy(mtq.INT4_AWQ_CFG)
    # ``quant_cfg`` is a dict keyed by wildcard quantizer-name patterns, so
    # index the pattern directly — iterating the dict yields string keys, and
    # the old ``entry.setdefault("cfg", {})`` wrote the override to a
    # nonexistent "cfg" sub-key instead of setting ``block_sizes``.
    # NOTE(review): assumes modelopt's flat quant_cfg schema — confirm against
    # the installed mtq.config version.
    cfg["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, "type": "static"}
    cfg["algorithm"] = {"method": "svdquant", "lowrank": 4, "layerwise": True}
    return cfg
| 636 | + |
| 637 | + |
def test_mtq_quantize_layerwise_e2e_max(monkeypatch):
    """End-to-end: mtq.quantize with layerwise=True produces populated amax values.

    ``max`` is the representative algorithm for the layerwise happy path because
    every other algorithm seeds amax via max_calibrate first — if max works, the
    shared skip/run/capture machinery is sound. Other algorithms are covered by
    the dispatch-only test below to avoid hardware requirements (e.g. gptq needs
    CUDA) or unnecessary duplication.
    """
    _register_test_discoverer(monkeypatch)
    config = _int8_layerwise_config({"method": "max", "layerwise": True})

    torch.manual_seed(0)
    model = _SimpleTransformerModel(n_layers=3, dim=16)
    batches = [torch.randint(0, 32, (2, 8)) for _ in range(2)]

    def run_calibration(net):
        for sample in batches:
            net(sample)

    model = mtq.quantize(model, config, forward_loop=run_calibration)

    # Layerwise machinery must fully unwind: no placeholder layers, no markers.
    for i, layer in enumerate(model.layers):
        assert not isinstance(layer, _SkipLayer), f"layer {i} left as _SkipLayer"
        assert not hasattr(layer, "_layerwise_calib"), f"layer {i} leaked _layerwise_calib"

    # Count enabled TensorQuantizers inside the decoder layers that actually
    # recorded an amax during calibration.
    populated = 0
    for layer in model.layers:
        for module in layer.modules():
            if not isinstance(module, TensorQuantizer):
                continue
            if module.is_enabled and getattr(module, "_amax", None) is not None:
                populated += 1
    assert populated > 0, "no TensorQuantizer in decoder layers had _amax populated"

    # Quantized model must still run a plain forward pass.
    with torch.no_grad():
        model(batches[0])
| 678 | + |
| 679 | + |
@pytest.mark.parametrize(
    "algorithm",
    ["gptq", "awq_lite", "smoothquant", "mse"],
)
def test_mtq_quantize_layerwise_dispatches_for_algorithm(monkeypatch, algorithm):
    """Every layerwise-supporting algorithm must route through layerwise_calibrate.

    Stubs layerwise_calibrate to a spy so the dispatch contract is checked without
    running the algorithm's full calibration — lets ``gptq`` (CUDA-only at runtime)
    and other expensive algorithms participate in CPU unit tests.
    """
    captured: dict = {}

    def fake_layerwise_calibrate(model, forward_loop, calib_func, **kwargs):
        # Record what the mode handed us; never run the real calibration.
        captured["calib_func"] = calib_func
        captured["kwargs"] = kwargs

    monkeypatch.setattr(
        "modelopt.torch.quantization.mode.layerwise_calibrate", fake_layerwise_calibrate
    )

    config = (
        _awq_layerwise_config()
        if algorithm == "awq_lite"
        else _int8_layerwise_config({"method": algorithm, "layerwise": True})
    )

    torch.manual_seed(0)
    model = _SimpleTransformerModel(n_layers=2, dim=16)
    mtq.quantize(
        model,
        config,
        forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))),
    )

    assert "calib_func" in captured, f"{algorithm} did not dispatch through layerwise_calibrate"
    assert callable(captured["calib_func"])
| 714 | + |
| 715 | + |
def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm():
    """Modes with ``_supports_layerwise = False`` must raise a clear ValueError."""
    torch.manual_seed(0)
    model = _SimpleTransformerModel(n_layers=2, dim=16)
    config = _svdquant_layerwise_config()

    def single_batch(m):
        m(torch.randint(0, 32, (2, 8)))

    with pytest.raises(ValueError, match="does not support layerwise=True"):
        mtq.quantize(model, config, forward_loop=single_batch)
0 commit comments