|
18 | 18 | from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor |
19 | 19 | from ..looper.named_module import NamedModule |
20 | 20 | from ..models import BaseQModel |
21 | | -from ..models._const import SUPPORTS_MODULE_TYPES |
| 21 | +from ..models._const import DEVICE, SUPPORTS_MODULE_TYPES |
22 | 22 | from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, |
23 | 23 | PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) |
24 | 24 | from ..nn_modules.qlinear.gemm_awq import AwqGEMMLinear |
25 | 25 | from ..nn_modules.qlinear.gemv_awq import AwqGEMVLinear |
26 | 26 | from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastLinear, LLMAwqLinear |
| 27 | +from ..nn_modules.qlinear.torch_awq import AwqTorchLinear |
27 | 28 | from ..quantization.awq.quantize.scale import apply_clip, apply_scale |
28 | 29 | from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name |
29 | 30 | from ..quantization.awq.utils.utils import get_best_device |
@@ -272,12 +273,31 @@ def set_calibration_dataset(self, calibration_dataset): |
272 | 273 |
|
273 | 274 | raise NotImplementedError("AWQProcessor's calibration_dataset cannot be modified") |
274 | 275 |
|
| 276 | + def _quant_device_is_npu(self) -> bool: |
| 277 | + """Return whether this processor is quantizing for an Ascend NPU runtime.""" |
| 278 | + |
| 279 | + device = getattr(self.qcfg, "device", None) |
| 280 | + if isinstance(device, DEVICE): |
| 281 | + return device == DEVICE.NPU |
| 282 | + if isinstance(device, torch.device): |
| 283 | + return device.type == "npu" |
| 284 | + if isinstance(device, str): |
| 285 | + return device.split(":", 1)[0].lower() == "npu" |
| 286 | + return False |
| 287 | + |
275 | 288 | def _select_qlinear_kernel_for_format(self, format_value: FORMAT): |
276 | 289 | """Maps the resolved AWQ format to its concrete quantized linear kernel.""" |
277 | 290 |
|
278 | 291 | fmt = FORMAT(format_value) if not isinstance(format_value, FORMAT) else format_value |
279 | 292 | if fmt == FORMAT.GEMM: |
| 293 | + if self._quant_device_is_npu(): |
| 294 | + return AwqTorchLinear |
280 | 295 | return AwqGEMMLinear |
| 296 | + if self._quant_device_is_npu(): |
| 297 | + raise ValueError( |
| 298 | + "NPU AWQ quantization requires FORMAT.GEMM so the AwqTorchLinear runtime can run on NPU; " |
| 299 | + f"actual format is `{fmt}`." |
| 300 | + ) |
281 | 301 | if fmt == FORMAT.GEMV: |
282 | 302 | return AwqGEMVLinear |
283 | 303 | if fmt == FORMAT.GEMV_FAST: |
@@ -1820,11 +1840,7 @@ def preprocess(self, module: NamedModule, fallback=None, **kwargs): |
1820 | 1840 | def is_skipped(self, module: NamedModule) -> bool: |
1821 | 1841 | """Reports whether preprocessing excluded this module from AWQ work.""" |
1822 | 1842 |
|
1823 | | - t = self.tasks.get(module.name, False) |
1824 | | - if t == False: |
1825 | | - return True |
1826 | | - else: |
1827 | | - return False |
| 1843 | + return not self.tasks.get(module.name, False) |
1828 | 1844 |
|
1829 | 1845 | def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: |
1830 | 1846 | """Returns the forward hook that caches module input activations for AWQ.""" |
|
0 commit comments