11import copy
22import os
3+ import sys
34import warnings
45
56import pytest
67import torch
78import torch .nn as nn
89
10+ from gptqmodel .looper .awq_processor import AWQProcessor
11+ from gptqmodel .looper .gptq_processor import GPTQProcessor
12+ from gptqmodel .looper .paroquant_processor import ParoQuantProcessor
13+ from gptqmodel .looper .qqq_processor import QQQProcessor
14+ from gptqmodel .looper .weight_only_processor import WeightOnlyProcessor
915from gptqmodel .models ._const import DEVICE , normalize_device
1016from gptqmodel .nn_modules .exllamav3_torch import ExllamaV3TorchLinear
1117from gptqmodel .nn_modules .qlinear .fp8 import TorchFP8Linear
1622from gptqmodel .nn_modules .qlinear .torch_awq import AwqTorchLinear
1723from gptqmodel .quantization import FORMAT , METHOD
1824from gptqmodel .quantization .awq .utils .packing_utils import unpack_awq
25+ from gptqmodel .quantization .config import AWQConfig , GGUFConfig , ParoConfig , QQQConfig , QuantizeConfig
1926from gptqmodel .utils import importer
2027from gptqmodel .utils .backend import BACKEND
2128from gptqmodel .utils .importer import auto_select_device , get_kernel_for_backend , select_quant_linear
22- from gptqmodel .utils .torch import HAS_NPU
29+ from gptqmodel .utils .torch import HAS_NPU , last_npu_device_by_pci_bus_order
2330
2431
25- NPU_TEST_DEVICE = os .environ .get ("GPTQMODEL_TEST_NPU_DEVICE" , "npu:0" )
def _default_npu_test_device() -> str:
    """Return the default NPU device string for these tests.

    Uses the device reported by ``last_npu_device_by_pci_bus_order()`` and
    falls back to ``"npu:0"`` when that helper returns ``None``.
    """
    candidate = last_npu_device_by_pci_bus_order()
    if candidate is None:
        return "npu:0"
    return str(candidate)
35+
36+
# Device string used by the hardware-backed NPU tests.
# GPTQMODEL_TEST_NPU_DEVICE overrides it. The default is only computed when
# the variable is unset or empty, so an explicit override no longer triggers
# the default-device lookup at import time.
NPU_TEST_DEVICE = os.environ.get("GPTQMODEL_TEST_NPU_DEVICE") or _default_npu_test_device()

# Message fragments that indicate an op was run on the CPU instead of the NPU.
# NOTE(review): presumably matched against captured warnings by
# _assert_no_npu_cpu_fallback -- confirm against that helper's body.
NPU_CPU_FALLBACK_MARKERS = (
    "not currently supported on the NPU backend",
    "fall back to run on the CPU",
)
3042
3143
def _test_npu_device() -> torch.device:
    """Build the ``torch.device`` for ``NPU_TEST_DEVICE``.

    When ``HAS_NPU`` is true the device is also made current through
    ``torch.npu.set_device`` before it is returned.
    """
    npu_device = torch.device(NPU_TEST_DEVICE)
    if not HAS_NPU:
        return npu_device
    torch.npu.set_device(npu_device)
    return npu_device
3449
3550
3651def _assert_no_npu_cpu_fallback (caught : list [warnings .WarningMessage ]) -> None :
@@ -253,6 +268,58 @@ def _make_exllamav3_torch_module(*, device: torch.device | str = "cpu") -> Exlla
253268 ).eval ()
254269
255270
class _NpuProcessorModelStub:
    """Bare-bones stand-in passed as ``gptq_model`` when building processors.

    It carries only four attributes: the given ``qlinear_kernel``, a ``None``
    rotary embedding, the ``"lm_head"`` name and an empty ``nn.Sequential``.
    """

    def __init__(self, qlinear_kernel=None):
        # An empty container is enough; no layers are registered here.
        self.model = nn.Sequential()
        self.lm_head = "lm_head"
        self.rotary_embedding = None
        self.qlinear_kernel = qlinear_kernel
277+
278+
def _processor_common_kwargs(qcfg):
    """Return a fresh dict of the constructor kwargs shared by the processor tests.

    Only ``qcfg`` varies; every calibration-related entry is ``None`` and
    ``batch_size`` is fixed at 1.
    """
    return dict(
        tokenizer=None,
        qcfg=qcfg,
        calibration=None,
        prepare_dataset_func=None,
        calibration_concat_size=None,
        calibration_sort=None,
        batch_size=1,
    )
289+
290+
def _npu_select_quant_linear(qcfg, *, method: METHOD, fmt: FORMAT):
    """Call ``select_quant_linear`` for ``DEVICE.NPU`` with ``BACKEND.AUTO``.

    Bits, group size, ``desc_act``, ``sym`` and pack dtype are read from
    ``qcfg``; the quant method and format come from the keyword arguments.
    Returns whatever ``select_quant_linear`` returns.
    """
    selection_kwargs = {
        "bits": qcfg.runtime_bits,
        "group_size": qcfg.group_size,
        "desc_act": qcfg.desc_act,
        "sym": qcfg.sym,
        "device": DEVICE.NPU,
        "backend": BACKEND.AUTO,
        "format": fmt,
        "quant_method": method,
        "pack_dtype": qcfg.pack_dtype,
    }
    return select_quant_linear(**selection_kwargs)
303+
304+
def test_last_npu_device_by_pci_bus_order_uses_visible_logical_order(monkeypatch):
    """With a fake ``torch.npu`` reporting three devices, expect ``npu:2``."""
    # Skip on PyTorch builds where the "npu" device string cannot be parsed.
    try:
        torch.device("npu:0")
    except (RuntimeError, ValueError):
        pytest.skip("This PyTorch build does not register the npu device type")

    class _ThreeDeviceNpu:
        """Stand-in for ``torch.npu`` that only exposes ``device_count``."""

        @staticmethod
        def device_count():
            return 3

    # Patch the module that defines the helper, so the helper itself sees
    # the faked HAS_NPU flag and the faked torch.npu attribute.
    helper_module = sys.modules[last_npu_device_by_pci_bus_order.__module__]
    monkeypatch.setattr(helper_module, "HAS_NPU", True)
    monkeypatch.setattr(helper_module.torch, "npu", _ThreeDeviceNpu())

    resolved = last_npu_device_by_pci_bus_order()
    assert str(resolved) == "npu:2"
321+
322+
256323def test_npu_device_normalization ():
257324 assert normalize_device ("npu" ) is DEVICE .NPU
258325 assert normalize_device ("npu:3" ) is DEVICE .NPU
@@ -358,13 +425,111 @@ def test_qqq_torch_backend_selects_torch_kernel():
358425 assert get_kernel_for_backend (BACKEND .QQQ_TORCH , METHOD .QQQ , FORMAT .QQQ ) is QQQTorchLinear
359426
360427
def test_npu_gptq_processor_has_torch_runtime_kernel():
    """GPTQ on NPU: processor identity, forward requirement and TorchLinear selection."""
    npu_cfg = QuantizeConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
    gptq_processor = GPTQProcessor(**_processor_common_kwargs(npu_cfg))

    assert gptq_processor.name() == "gptq"
    assert gptq_processor.execution_config.require_fwd is True

    selected_kernel = _npu_select_quant_linear(npu_cfg, method=METHOD.GPTQ, fmt=FORMAT.GPTQ)
    assert selected_kernel is TorchLinear
435+
436+
def test_npu_awq_processor_selects_torch_runtime_kernel():
    """AWQ on NPU: both the processor and AUTO selection resolve to AwqTorchLinear."""
    npu_cfg = AWQConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
    stub = _NpuProcessorModelStub()
    shared_kwargs = _processor_common_kwargs(npu_cfg)
    awq_processor = AWQProcessor(gptq_model=stub, model=stub.model, **shared_kwargs)

    assert awq_processor.name() == "awq"
    assert awq_processor.execution_config.enable_activation_capture is True
    assert awq_processor.qlinear_kernel is AwqTorchLinear

    selected_kernel = _npu_select_quant_linear(npu_cfg, method=METHOD.AWQ, fmt=FORMAT.GEMM)
    assert selected_kernel is AwqTorchLinear
450+
451+
def test_npu_paroquant_processor_has_torch_runtime_kernel():
    """ParoQuant on NPU: both the processor and AUTO selection resolve to ParoLinear."""
    config_kwargs = {
        "bits": 4,
        "group_size": 128,
        "device": DEVICE.NPU,
        "opt_rotation_epochs": 1,
        "opt_finetune_epochs": 1,
        "offload_to_disk": False,
    }
    npu_cfg = ParoConfig(**config_kwargs)
    stub = _NpuProcessorModelStub()
    shared_kwargs = _processor_common_kwargs(npu_cfg)
    paro_processor = ParoQuantProcessor(gptq_model=stub, model=stub.model, **shared_kwargs)

    assert paro_processor.name() == "paroquant"
    assert paro_processor.execution_config.enable_activation_capture is True
    assert paro_processor.qlinear_kernel is ParoLinear

    selected_kernel = _npu_select_quant_linear(npu_cfg, method=METHOD.PARO, fmt=FORMAT.PAROQUANT)
    assert selected_kernel is ParoLinear
472+
473+
def test_npu_qqq_processor_selects_torch_runtime_kernel():
    """QQQ on NPU: the processor and AUTO selection both pick QQQTorchLinear."""
    npu_cfg = QQQConfig(bits=4, group_size=128, device=DEVICE.NPU, offload_to_disk=False)
    qqq_processor = QQQProcessor(**_processor_common_kwargs(npu_cfg))
    # The processor reports its kernel as a (class, backend) pair.
    kernel_cls, kernel_backend = qqq_processor._quant_linear_kernel()

    assert qqq_processor.name() == "qqq"
    assert kernel_cls is QQQTorchLinear
    assert kernel_backend is BACKEND.QQQ_TORCH

    selected_kernel = _npu_select_quant_linear(npu_cfg, method=METHOD.QQQ, fmt=FORMAT.QQQ)
    assert selected_kernel is QQQTorchLinear
483+
484+
def test_npu_gguf_weight_only_processor_has_torch_runtime_kernel():
    """GGUF q4_0 on NPU: weight-only processor needs no forward; kernel is GGUFTorchLinear."""
    npu_cfg = GGUFConfig(bits="q4_0", device=DEVICE.NPU, offload_to_disk=False)
    weight_only = WeightOnlyProcessor(tokenizer=None, qcfg=npu_cfg)

    assert weight_only.name() == "weight_only_gguf"
    assert weight_only.execution_config.require_fwd is False

    selected_kernel = _npu_select_quant_linear(npu_cfg, method=METHOD.GGUF, fmt=FORMAT.GGUF)
    assert selected_kernel is GGUFTorchLinear
492+
493+
def test_npu_supported_quant_methods_have_torch_runnable_kernel():
    """Each listed method/format pair must auto-select its expected kernel on NPU.

    The selected class must also list ``DEVICE.ALL`` or ``DEVICE.NPU`` in its
    ``SUPPORTS_DEVICES``.
    """
    # Columns: quant method, format, bits, group size, expected kernel class.
    expectations = [
        (METHOD.GPTQ, FORMAT.GPTQ, 4, 128, TorchLinear),
        (METHOD.AWQ, FORMAT.GEMM, 4, 128, AwqTorchLinear),
        (METHOD.PARO, FORMAT.PAROQUANT, 4, 128, ParoLinear),
        (METHOD.GGUF, FORMAT.GGUF, "q4_0", -1, GGUFTorchLinear),
        (METHOD.QQQ, FORMAT.QQQ, 4, 128, QQQTorchLinear),
    ]
    for quant_method, quant_format, bits, group_size, kernel_cls in expectations:
        selection_kwargs = {
            "bits": bits,
            "group_size": group_size,
            "desc_act": False,
            "sym": True,
            "device": DEVICE.NPU,
            "backend": BACKEND.AUTO,
            "format": quant_format,
            "quant_method": quant_method,
            "pack_dtype": torch.int32,
        }
        selected = select_quant_linear(**selection_kwargs)
        assert selected is kernel_cls
        assert DEVICE.ALL in selected.SUPPORTS_DEVICES or DEVICE.NPU in selected.SUPPORTS_DEVICES
516+
517+
def test_npu_exl3_has_torch_runtime_kernel():
    """The EXL3 helper module is an ExllamaV3TorchLinear tagged ``"exl3"``."""
    # NOTE(review): the module is built with the helper's default device, so
    # nothing here runs on an NPU -- confirm that this is the intent.
    exl3_module = _make_exllamav3_torch_module()

    assert isinstance(exl3_module, ExllamaV3TorchLinear)
    quant_type = exl3_module.QUANT_TYPE
    assert quant_type == "exl3"
523+
524+
def test_npu_does_not_advertise_fp8_torch_until_cann_supports_float8():
    """TorchFP8Linear must list neither ``DEVICE.ALL`` nor ``DEVICE.NPU`` as supported."""
    advertised = TorchFP8Linear.SUPPORTS_DEVICES
    for excluded in (DEVICE.ALL, DEVICE.NPU):
        assert excluded not in advertised
364528
365529
366530@pytest .mark .skipif (not HAS_NPU , reason = "NPU is not available" )
367531def test_npu_awq_unpack_preserves_pack_dimension ():
532+ device = _test_npu_device ()
368533 qweight_cpu = torch .tensor (
369534 [[0 , 1 , - 1 ], [- 2147483648 , 2147483647 , - 123456789 ]],
370535 dtype = torch .int32 ,
@@ -373,8 +538,8 @@ def test_npu_awq_unpack_preserves_pack_dimension():
373538 [[- 1 , 0 , 123456789 ], [2147483647 , - 2147483648 , 7 ]],
374539 dtype = torch .int32 ,
375540 )
376- qweight = qweight_cpu .to ("npu:0" )
377- qzeros = qzeros_cpu .to ("npu:0" )
541+ qweight = qweight_cpu .to (device )
542+ qzeros = qzeros_cpu .to (device )
378543
379544 iweight , izeros = unpack_awq (qweight , qzeros , bits = 4 )
380545 shifts = torch .arange (0 , 32 , 4 , dtype = torch .int32 )
@@ -391,6 +556,7 @@ def test_npu_awq_unpack_preserves_pack_dimension():
391556
392557@pytest .mark .skipif (not HAS_NPU , reason = "NPU is not available" )
393558def test_npu_torch_gptq_unpack_preserves_pack_dimension ():
559+ device = _test_npu_device ()
394560 qweight_cpu = torch .tensor (
395561 [
396562 [0 , 1 , - 1 ],
@@ -400,8 +566,8 @@ def test_npu_torch_gptq_unpack_preserves_pack_dimension():
400566 ],
401567 dtype = torch .int32 ,
402568 )
403- qweight = qweight_cpu .to ("npu:0" )
404- shifts = torch .arange (0 , 32 , 4 , dtype = torch .int32 , device = "npu:0" ).view (1 , 8 , 1 )
569+ qweight = qweight_cpu .to (device )
570+ shifts = torch .arange (0 , 32 , 4 , dtype = torch .int32 , device = device ).view (1 , 8 , 1 )
405571
406572 unpacked = _right_shift_unpack (
407573 qweight .unsqueeze (1 ).expand (- 1 , 8 , - 1 ),
0 commit comments