Skip to content

Commit 44b02d9

Browse files
committed
Add NPU quant method coverage
1 parent 1002e69 commit 44b02d9

5 files changed

Lines changed: 352 additions & 16 deletions

File tree

gptqmodel/looper/awq_processor.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor
1919
from ..looper.named_module import NamedModule
2020
from ..models import BaseQModel
21-
from ..models._const import SUPPORTS_MODULE_TYPES
21+
from ..models._const import DEVICE, SUPPORTS_MODULE_TYPES
2222
from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
2323
PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
2424
from ..nn_modules.qlinear.gemm_awq import AwqGEMMLinear
2525
from ..nn_modules.qlinear.gemv_awq import AwqGEMVLinear
2626
from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastLinear, LLMAwqLinear
27+
from ..nn_modules.qlinear.torch_awq import AwqTorchLinear
2728
from ..quantization.awq.quantize.scale import apply_clip, apply_scale
2829
from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name
2930
from ..quantization.awq.utils.utils import get_best_device
@@ -272,12 +273,31 @@ def set_calibration_dataset(self, calibration_dataset):
272273

273274
raise NotImplementedError("AWQProcessor's calibration_dataset cannot be modified")
274275

276+
def _quant_device_is_npu(self) -> bool:
277+
"""Return whether this processor is quantizing for an Ascend NPU runtime."""
278+
279+
device = getattr(self.qcfg, "device", None)
280+
if isinstance(device, DEVICE):
281+
return device == DEVICE.NPU
282+
if isinstance(device, torch.device):
283+
return device.type == "npu"
284+
if isinstance(device, str):
285+
return device.split(":", 1)[0].lower() == "npu"
286+
return False
287+
275288
def _select_qlinear_kernel_for_format(self, format_value: FORMAT):
276289
"""Maps the resolved AWQ format to its concrete quantized linear kernel."""
277290

278291
fmt = FORMAT(format_value) if not isinstance(format_value, FORMAT) else format_value
279292
if fmt == FORMAT.GEMM:
293+
if self._quant_device_is_npu():
294+
return AwqTorchLinear
280295
return AwqGEMMLinear
296+
if self._quant_device_is_npu():
297+
raise ValueError(
298+
"NPU AWQ quantization requires FORMAT.GEMM so the AwqTorchLinear runtime can run on NPU; "
299+
f"actual format is `{fmt}`."
300+
)
281301
if fmt == FORMAT.GEMV:
282302
return AwqGEMVLinear
283303
if fmt == FORMAT.GEMV_FAST:
@@ -1820,11 +1840,7 @@ def preprocess(self, module: NamedModule, fallback=None, **kwargs):
18201840
def is_skipped(self, module: NamedModule) -> bool:
18211841
"""Reports whether preprocessing excluded this module from AWQ work."""
18221842

1823-
t = self.tasks.get(module.name, False)
1824-
if t == False:
1825-
return True
1826-
else:
1827-
return False
1843+
return not self.tasks.get(module.name, False)
18281844

18291845
def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]:
18301846
"""Returns the forward hook that caches module input activations for AWQ."""

gptqmodel/models/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,11 @@ def quantize(
781781
quant_device_type = str(quant_device).split(":")[0].lower()
782782

783783
if export_quant_method == METHOD.AWQ:
784+
if quant_device_type == "npu" and format_code != FORMAT.GEMM:
785+
raise ValueError(
786+
"NPU AWQ quantization requires FORMAT.GEMM so the AwqTorchLinear runtime can run on NPU; "
787+
f"actual format is `{format_code}`."
788+
)
784789
if format_code == FORMAT.GEMM:
785790
# Weight-only RTN->AWQ export should stay on the portable torch kernel.
786791
preferred_backend = (
@@ -2938,8 +2943,8 @@ def _linear_names(module):
29382943
def _find_parents(module, possible_names):
29392944
found = set()
29402945
for n, _ in module.named_children():
2941-
l = n.lower()
2942-
if any(k in l for k in possible_names):
2946+
lowered_name = n.lower()
2947+
if any(k in lowered_name for k in possible_names):
29432948
found.add(n)
29442949
return found
29452950

gptqmodel/utils/torch.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import contextlib
77
import importlib
8+
import os
89
import time
910
from contextlib import contextmanager
1011
from enum import Enum
@@ -417,6 +418,119 @@ def torch_devices() -> List[torch.device]:
417418
else:
418419
return [CPU]
419420

421+
422+
def _npu_device_pci_bus_id(index: int) -> Optional[str]:
423+
"""Best-effort PCI bus identifier for a visible NPU logical index."""
424+
425+
if not HAS_NPU:
426+
return None
427+
428+
npu = getattr(torch, "npu", None)
429+
get_props = getattr(npu, "get_device_properties", None)
430+
if not callable(get_props):
431+
return None
432+
433+
try:
434+
props = get_props(index)
435+
except Exception:
436+
return None
437+
438+
attr_names = (
439+
"pci_bus_id",
440+
"pci_busid",
441+
"bus_id",
442+
"busid",
443+
"pcie_bus_id",
444+
"pcie_id",
445+
)
446+
for attr_name in attr_names:
447+
if isinstance(props, dict):
448+
value = props.get(attr_name)
449+
else:
450+
value = getattr(props, attr_name, None)
451+
if callable(value):
452+
try:
453+
value = value()
454+
except Exception:
455+
value = None
456+
if value not in (None, ""):
457+
return str(value).strip().lower()
458+
return None
459+
460+
461+
def _parse_ascend_visible_devices() -> List[int]:
462+
"""Parse ASCEND_RT_VISIBLE_DEVICES without depending on torch-npu internals."""
463+
464+
visible = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
465+
if visible is None:
466+
return []
467+
result: List[int] = []
468+
for item in visible.split(","):
469+
item = item.strip()
470+
if not item:
471+
return []
472+
try:
473+
result.append(int(item))
474+
except ValueError:
475+
return []
476+
return result
477+
478+
479+
def npu_devices_by_pci_bus_order() -> List[torch.device]:
480+
"""Return visible NPU devices ordered by PCI bus id when the runtime exposes it.
481+
482+
If torch-npu does not expose a bus id, the visible logical order is used.
483+
With ASCEND_RT_VISIBLE_DEVICES set in PCI order, this keeps the requested
484+
bus ordering while still returning logical torch device indices.
485+
"""
486+
487+
if not HAS_NPU:
488+
return []
489+
490+
try:
491+
count = int(torch.npu.device_count())
492+
except Exception:
493+
return []
494+
if count <= 0:
495+
return []
496+
497+
bus_entries = []
498+
for logical_index in range(count):
499+
bus_id = _npu_device_pci_bus_id(logical_index)
500+
if bus_id:
501+
bus_entries.append((bus_id, logical_index))
502+
503+
if len(bus_entries) == count:
504+
entries = sorted(bus_entries, key=lambda item: (item[0], item[1]))
505+
else:
506+
visible_physical = _parse_ascend_visible_devices()
507+
entries = [
508+
(
509+
visible_physical[logical_index]
510+
if logical_index < len(visible_physical)
511+
else logical_index,
512+
logical_index,
513+
)
514+
for logical_index in range(count)
515+
]
516+
517+
devices: List[torch.device] = []
518+
for _, logical_index in entries:
519+
try:
520+
devices.append(torch.device("npu", logical_index))
521+
except (RuntimeError, ValueError):
522+
return []
523+
return devices
524+
525+
526+
def last_npu_device_by_pci_bus_order() -> Optional[torch.device]:
527+
"""Return the last visible NPU in PCI bus order, or None when unavailable."""
528+
529+
devices = npu_devices_by_pci_bus_order()
530+
if not devices:
531+
return None
532+
return devices[-1]
533+
420534
ALL_DEVICES = torch_devices()
421535

422536
if HAS_CUDA:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ test = [
4545
"parameterized",
4646
]
4747
quality = [
48-
"ruff==0.13.0",
48+
"ruff==0.14.2",
4949
# "isort==6.0.1",
5050
]
5151
vllm = [

0 commit comments

Comments
 (0)