Skip to content

Commit 9697f87

Browse files
authored
Add NPU quant method coverage (#2845)
* Add NPU quant method coverage * Simplify NPU device ordering * Resolve PR code quality comments
1 parent 1002e69 commit 9697f87

5 files changed

Lines changed: 240 additions & 17 deletions

File tree

gptqmodel/looper/awq_processor.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor
1919
from ..looper.named_module import NamedModule
2020
from ..models import BaseQModel
21-
from ..models._const import SUPPORTS_MODULE_TYPES
21+
from ..models._const import DEVICE, SUPPORTS_MODULE_TYPES
2222
from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
2323
PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
2424
from ..nn_modules.qlinear.gemm_awq import AwqGEMMLinear
2525
from ..nn_modules.qlinear.gemv_awq import AwqGEMVLinear
2626
from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastLinear, LLMAwqLinear
27+
from ..nn_modules.qlinear.torch_awq import AwqTorchLinear
2728
from ..quantization.awq.quantize.scale import apply_clip, apply_scale
2829
from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name
2930
from ..quantization.awq.utils.utils import get_best_device
@@ -272,12 +273,31 @@ def set_calibration_dataset(self, calibration_dataset):
272273

273274
raise NotImplementedError("AWQProcessor's calibration_dataset cannot be modified")
274275

276+
def _quant_device_is_npu(self) -> bool:
277+
"""Return whether this processor is quantizing for an Ascend NPU runtime."""
278+
279+
device = getattr(self.qcfg, "device", None)
280+
if isinstance(device, DEVICE):
281+
return device == DEVICE.NPU
282+
if isinstance(device, torch.device):
283+
return device.type == "npu"
284+
if isinstance(device, str):
285+
return device.split(":", 1)[0].lower() == "npu"
286+
return False
287+
275288
def _select_qlinear_kernel_for_format(self, format_value: FORMAT):
276289
"""Maps the resolved AWQ format to its concrete quantized linear kernel."""
277290

278291
fmt = FORMAT(format_value) if not isinstance(format_value, FORMAT) else format_value
279292
if fmt == FORMAT.GEMM:
293+
if self._quant_device_is_npu():
294+
return AwqTorchLinear
280295
return AwqGEMMLinear
296+
if self._quant_device_is_npu():
297+
raise ValueError(
298+
"NPU AWQ quantization requires FORMAT.GEMM so the AwqTorchLinear runtime can run on NPU; "
299+
f"actual format is `{fmt}`."
300+
)
281301
if fmt == FORMAT.GEMV:
282302
return AwqGEMVLinear
283303
if fmt == FORMAT.GEMV_FAST:
@@ -1302,7 +1322,7 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
13021322

13031323
scales = scales.view(org_w_shape[0], -1)
13041324

1305-
# Symmetric quantization produces signed int values (e.g. int4 [-8, 7]),
1325+
# Symmetric quantization produces signed int values (e.g. int4 in [-8, 7]),
13061326
# which cannot be packed directly. To make it packable, we shift the signed
13071327
# representation to unsigned by adding 2^(bits-1), i.e. q_u = q_s + 2^(bits-1).
13081328
# This is equivalent to using an affine form with zero_point = 2^(bits-1),
@@ -1820,11 +1840,7 @@ def preprocess(self, module: NamedModule, fallback=None, **kwargs):
18201840
def is_skipped(self, module: NamedModule) -> bool:
18211841
"""Reports whether preprocessing excluded this module from AWQ work."""
18221842

1823-
t = self.tasks.get(module.name, False)
1824-
if t == False:
1825-
return True
1826-
else:
1827-
return False
1843+
return not self.tasks.get(module.name, False)
18281844

18291845
def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
18301846
"""Returns the forward hook that caches module input activations for AWQ."""

gptqmodel/models/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,11 @@ def quantize(
781781
quant_device_type = str(quant_device).split(":")[0].lower()
782782

783783
if export_quant_method == METHOD.AWQ:
784+
if quant_device_type == "npu" and format_code != FORMAT.GEMM:
785+
raise ValueError(
786+
"NPU AWQ quantization requires FORMAT.GEMM so the AwqTorchLinear runtime can run on NPU; "
787+
f"actual format is `{format_code}`."
788+
)
784789
if format_code == FORMAT.GEMM:
785790
# Weight-only RTN->AWQ export should stay on the portable torch kernel.
786791
preferred_backend = (
@@ -2938,8 +2943,8 @@ def _linear_names(module):
29382943
def _find_parents(module, possible_names):
29392944
found = set()
29402945
for n, _ in module.named_children():
2941-
l = n.lower()
2942-
if any(k in l for k in possible_names):
2946+
lowered_name = n.lower()
2947+
if any(k in lowered_name for k in possible_names):
29432948
found.add(n)
29442949
return found
29452950

gptqmodel/utils/torch.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,42 @@ def torch_devices() -> List[torch.device]:
417417
else:
418418
return [CPU]
419419

420+
421+
def npu_devices_by_pci_bus_order() -> List[torch.device]:
422+
"""Return visible NPU devices in torch runtime order.
423+
424+
Ascend exposes process-level NPU visibility through ASCEND_RT_VISIBLE_DEVICES.
425+
torch-npu remaps that visible set to logical indices, so callers should set
426+
the env var before process start and then use the resulting torch order.
427+
"""
428+
429+
if not HAS_NPU:
430+
return []
431+
432+
try:
433+
count = int(torch.npu.device_count())
434+
except Exception:
435+
return []
436+
if count <= 0:
437+
return []
438+
439+
devices: List[torch.device] = []
440+
for logical_index in range(count):
441+
try:
442+
devices.append(torch.device("npu", logical_index))
443+
except (RuntimeError, ValueError):
444+
return []
445+
return devices
446+
447+
448+
def last_npu_device_by_pci_bus_order() -> Optional[torch.device]:
449+
"""Return the last visible NPU in torch runtime order, or None when unavailable."""
450+
451+
devices = npu_devices_by_pci_bus_order()
452+
if not devices:
453+
return None
454+
return devices[-1]
455+
420456
ALL_DEVICES = torch_devices()
421457

422458
if HAS_CUDA:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ test = [
4545
"parameterized",
4646
]
4747
quality = [
48-
"ruff==0.13.0",
48+
"ruff==0.14.2",
4949
# "isort==6.0.1",
5050
]
5151
vllm = [

tests/test_npu_support.py

Lines changed: 173 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import copy
22
import os
3+
import sys
34
import warnings
45

56
import pytest
67
import torch
78
import torch.nn as nn
89

10+
from gptqmodel.looper.awq_processor import AWQProcessor
11+
from gptqmodel.looper.gptq_processor import GPTQProcessor
12+
from gptqmodel.looper.paroquant_processor import ParoQuantProcessor
13+
from gptqmodel.looper.qqq_processor import QQQProcessor
14+
from gptqmodel.looper.weight_only_processor import WeightOnlyProcessor
915
from gptqmodel.models._const import DEVICE, normalize_device
1016
from gptqmodel.nn_modules.exllamav3_torch import ExllamaV3TorchLinear
1117
from gptqmodel.nn_modules.qlinear.fp8 import TorchFP8Linear
@@ -16,21 +22,30 @@
1622
from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear
1723
from gptqmodel.quantization import FORMAT, METHOD
1824
from gptqmodel.quantization.awq.utils.packing_utils import unpack_awq
25+
from gptqmodel.quantization.config import AWQConfig, GGUFConfig, ParoConfig, QQQConfig, QuantizeConfig
1926
from gptqmodel.utils import importer
2027
from gptqmodel.utils.backend import BACKEND
2128
from gptqmodel.utils.importer import auto_select_device, get_kernel_for_backend, select_quant_linear
22-
from gptqmodel.utils.torch import HAS_NPU
29+
from gptqmodel.utils.torch import HAS_NPU, last_npu_device_by_pci_bus_order
2330

2431

25-
NPU_TEST_DEVICE = os.environ.get("GPTQMODEL_TEST_NPU_DEVICE", "npu:0")
32+
def _default_npu_test_device() -> str:
33+
selected = last_npu_device_by_pci_bus_order()
34+
return str(selected) if selected is not None else "npu:0"
35+
36+
37+
NPU_TEST_DEVICE = os.environ.get("GPTQMODEL_TEST_NPU_DEVICE", _default_npu_test_device())
2638
NPU_CPU_FALLBACK_MARKERS = (
2739
"not currently supported on the NPU backend",
2840
"fall back to run on the CPU",
2941
)
3042

3143

3244
def _test_npu_device() -> torch.device:
33-
return torch.device(NPU_TEST_DEVICE)
45+
device = torch.device(NPU_TEST_DEVICE)
46+
if HAS_NPU:
47+
torch.npu.set_device(device)
48+
return device
3449

3550

3651
def _assert_no_npu_cpu_fallback(caught: list[warnings.WarningMessage]) -> None:
@@ -253,6 +268,58 @@ def _make_exllamav3_torch_module(*, device: torch.device | str = "cpu") -> Exlla
253268
).eval()
254269

255270

271+
class _NpuProcessorModelStub:
272+
def __init__(self, qlinear_kernel=None):
273+
self.qlinear_kernel = qlinear_kernel
274+
self.rotary_embedding = None
275+
self.lm_head = "lm_head"
276+
self.model = nn.Sequential()
277+
278+
279+
def _processor_common_kwargs(qcfg):
280+
return {
281+
"tokenizer": None,
282+
"qcfg": qcfg,
283+
"calibration": None,
284+
"prepare_dataset_func": None,
285+
"calibration_concat_size": None,
286+
"calibration_sort": None,
287+
"batch_size": 1,
288+
}
289+
290+
291+
def _npu_select_quant_linear(qcfg, *, method: METHOD, fmt: FORMAT):
292+
return select_quant_linear(
293+
bits=qcfg.runtime_bits,
294+
group_size=qcfg.group_size,
295+
desc_act=qcfg.desc_act,
296+
sym=qcfg.sym,
297+
device=DEVICE.NPU,
298+
backend=BACKEND.AUTO,
299+
format=fmt,
300+
quant_method=method,
301+
pack_dtype=qcfg.pack_dtype,
302+
)
303+
304+
305+
def test_last_npu_device_by_pci_bus_order_uses_visible_logical_order(monkeypatch):
306+
try:
307+
torch.device("npu:0")
308+
except (RuntimeError, ValueError):
309+
pytest.skip("This PyTorch build does not register the npu device type")
310+
311+
class _FakeNpu:
312+
@staticmethod
313+
def device_count():
314+
return 3
315+
316+
torch_utils = sys.modules[last_npu_device_by_pci_bus_order.__module__]
317+
monkeypatch.setattr(torch_utils, "HAS_NPU", True)
318+
monkeypatch.setattr(torch_utils.torch, "npu", _FakeNpu())
319+
320+
assert str(last_npu_device_by_pci_bus_order()) == "npu:2"
321+
322+
256323
def test_npu_device_normalization():
257324
assert normalize_device("npu") is DEVICE.NPU
258325
assert normalize_device("npu:3") is DEVICE.NPU
@@ -358,13 +425,111 @@ def test_qqq_torch_backend_selects_torch_kernel():
358425
assert get_kernel_for_backend(BACKEND.QQQ_TORCH, METHOD.QQQ, FORMAT.QQQ) is QQQTorchLinear
359426

360427

428+
def test_npu_gptq_processor_has_torch_runtime_kernel():
429+
qcfg = QuantizeConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
430+
processor = GPTQProcessor(**_processor_common_kwargs(qcfg))
431+
432+
assert processor.name() == "gptq"
433+
assert processor.execution_config.require_fwd is True
434+
assert _npu_select_quant_linear(qcfg, method=METHOD.GPTQ, fmt=FORMAT.GPTQ) is TorchLinear
435+
436+
437+
def test_npu_awq_processor_selects_torch_runtime_kernel():
438+
qcfg = AWQConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
439+
model_stub = _NpuProcessorModelStub()
440+
processor = AWQProcessor(
441+
**_processor_common_kwargs(qcfg),
442+
gptq_model=model_stub,
443+
model=model_stub.model,
444+
)
445+
446+
assert processor.name() == "awq"
447+
assert processor.execution_config.enable_activation_capture is True
448+
assert processor.qlinear_kernel is AwqTorchLinear
449+
assert _npu_select_quant_linear(qcfg, method=METHOD.AWQ, fmt=FORMAT.GEMM) is AwqTorchLinear
450+
451+
452+
def test_npu_paroquant_processor_has_torch_runtime_kernel():
453+
qcfg = ParoConfig(
454+
bits=4,
455+
group_size=128,
456+
device=DEVICE.NPU,
457+
opt_rotation_epochs=1,
458+
opt_finetune_epochs=1,
459+
offload_to_disk=False,
460+
)
461+
model_stub = _NpuProcessorModelStub()
462+
processor = ParoQuantProcessor(
463+
**_processor_common_kwargs(qcfg),
464+
gptq_model=model_stub,
465+
model=model_stub.model,
466+
)
467+
468+
assert processor.name() == "paroquant"
469+
assert processor.execution_config.enable_activation_capture is True
470+
assert processor.qlinear_kernel is ParoLinear
471+
assert _npu_select_quant_linear(qcfg, method=METHOD.PARO, fmt=FORMAT.PAROQUANT) is ParoLinear
472+
473+
474+
def test_npu_qqq_processor_selects_torch_runtime_kernel():
475+
qcfg = QQQConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
476+
processor = QQQProcessor(**_processor_common_kwargs(qcfg))
477+
qlinear_cls, backend = processor._quant_linear_kernel()
478+
479+
assert processor.name() == "qqq"
480+
assert qlinear_cls is QQQTorchLinear
481+
assert backend is BACKEND.QQQ_TORCH
482+
assert _npu_select_quant_linear(qcfg, method=METHOD.QQQ, fmt=FORMAT.QQQ) is QQQTorchLinear
483+
484+
485+
def test_npu_gguf_weight_only_processor_has_torch_runtime_kernel():
486+
qcfg = GGUFConfig(bits="q4_0", device=DEVICE.NPU, offload_to_disk=False)
487+
processor = WeightOnlyProcessor(tokenizer=None, qcfg=qcfg)
488+
489+
assert processor.name() == "weight_only_gguf"
490+
assert processor.execution_config.require_fwd is False
491+
assert _npu_select_quant_linear(qcfg, method=METHOD.GGUF, fmt=FORMAT.GGUF) is GGUFTorchLinear
492+
493+
494+
def test_npu_supported_quant_methods_have_torch_runnable_kernel():
495+
cases = [
496+
(METHOD.GPTQ, FORMAT.GPTQ, 4, 128, TorchLinear),
497+
(METHOD.AWQ, FORMAT.GEMM, 4, 128, AwqTorchLinear),
498+
(METHOD.PARO, FORMAT.PAROQUANT, 4, 128, ParoLinear),
499+
(METHOD.GGUF, FORMAT.GGUF, "q4_0", -1, GGUFTorchLinear),
500+
(METHOD.QQQ, FORMAT.QQQ, 4, 128, QQQTorchLinear),
501+
]
502+
for method, fmt, bits, group_size, expected_cls in cases:
503+
qlinear_cls = select_quant_linear(
504+
bits=bits,
505+
group_size=group_size,
506+
desc_act=False,
507+
sym=True,
508+
device=DEVICE.NPU,
509+
backend=BACKEND.AUTO,
510+
format=fmt,
511+
quant_method=method,
512+
pack_dtype=torch.int32,
513+
)
514+
assert qlinear_cls is expected_cls
515+
assert DEVICE.ALL in qlinear_cls.SUPPORTS_DEVICES or DEVICE.NPU in qlinear_cls.SUPPORTS_DEVICES
516+
517+
518+
def test_npu_exl3_has_torch_runtime_kernel():
519+
module = _make_exllamav3_torch_module()
520+
521+
assert isinstance(module, ExllamaV3TorchLinear)
522+
assert module.QUANT_TYPE == "exl3"
523+
524+
361525
def test_npu_does_not_advertise_fp8_torch_until_cann_supports_float8():
362526
assert DEVICE.ALL not in TorchFP8Linear.SUPPORTS_DEVICES
363527
assert DEVICE.NPU not in TorchFP8Linear.SUPPORTS_DEVICES
364528

365529

366530
@pytest.mark.skipif(not HAS_NPU, reason="NPU is not available")
367531
def test_npu_awq_unpack_preserves_pack_dimension():
532+
device = _test_npu_device()
368533
qweight_cpu = torch.tensor(
369534
[[0, 1, -1], [-2147483648, 2147483647, -123456789]],
370535
dtype=torch.int32,
@@ -373,8 +538,8 @@ def test_npu_awq_unpack_preserves_pack_dimension():
373538
[[-1, 0, 123456789], [2147483647, -2147483648, 7]],
374539
dtype=torch.int32,
375540
)
376-
qweight = qweight_cpu.to("npu:0")
377-
qzeros = qzeros_cpu.to("npu:0")
541+
qweight = qweight_cpu.to(device)
542+
qzeros = qzeros_cpu.to(device)
378543

379544
iweight, izeros = unpack_awq(qweight, qzeros, bits=4)
380545
shifts = torch.arange(0, 32, 4, dtype=torch.int32)
@@ -391,6 +556,7 @@ def test_npu_awq_unpack_preserves_pack_dimension():
391556

392557
@pytest.mark.skipif(not HAS_NPU, reason="NPU is not available")
393558
def test_npu_torch_gptq_unpack_preserves_pack_dimension():
559+
device = _test_npu_device()
394560
qweight_cpu = torch.tensor(
395561
[
396562
[0, 1, -1],
@@ -400,8 +566,8 @@ def test_npu_torch_gptq_unpack_preserves_pack_dimension():
400566
],
401567
dtype=torch.int32,
402568
)
403-
qweight = qweight_cpu.to("npu:0")
404-
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device="npu:0").view(1, 8, 1)
569+
qweight = qweight_cpu.to(device)
570+
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=device).view(1, 8, 1)
405571

406572
unpacked = _right_shift_unpack(
407573
qweight.unsqueeze(1).expand(-1, 8, -1),

0 commit comments

Comments
 (0)