2727
2828log = setup_logger ()
2929
30+
31+ # Packed quantized weights are unpacked through shift operations in several
32+ # kernels. Keep those shifts behind helpers so Ascend 910B/CANN 9.1 beta can use
33+ # arithmetic equivalents for operators that torch-npu does not expose as native
34+ # device kernels, while CPU/CUDA/XPU keep the standard bitwise ops.
35+ def _torch_shift_factor (shifts : int | t .Tensor , device : t .device ) -> int | t .Tensor :
36+ if t .is_tensor (shifts ):
37+ # Tensor shifts must be materialized on the target device; otherwise the
38+ # NPU arithmetic fallback would introduce cross-device operands.
39+ shifts_i64 = shifts .to (device = device , dtype = t .int64 )
40+ return t .pow (t .full_like (shifts_i64 , 2 ), shifts_i64 )
41+ return 1 << int (shifts )
42+
43+
44+ def _torch_right_shift (values : t .Tensor , shifts : int | t .Tensor ) -> t .Tensor :
45+ if values .device .type != "npu" :
46+ return t .bitwise_right_shift (values , shifts )
47+
48+ # CANN 9.1 beta on Ascend 910B may not provide bitwise_right_shift kernels
49+ # for these tensor paths. floor_divide by powers of two preserves arithmetic
50+ # right-shift behavior for signed packed int tensors and stays on-device.
51+ shifted = t .floor_divide (values .to (t .int64 ), _torch_shift_factor (shifts , values .device ))
52+ return shifted .to (values .dtype )
53+
54+
55+ def _torch_left_shift (values : t .Tensor , shifts : int | t .Tensor ) -> t .Tensor :
56+ if values .device .type != "npu" :
57+ return t .bitwise_left_shift (values , shifts )
58+
59+ # Mirror left shift as multiplication by powers of two on NPU to avoid
60+ # missing-kernel or CPU-fallback paths in torch-npu.
61+ shifted = values .to (t .int64 ) * _torch_shift_factor (shifts , values .device )
62+ return shifted .to (values .dtype )
63+
64+
3065class BaseQuantLinear (nn .Module ):
3166 SUPPORTS_BACKENDS : List [BACKEND ] = None
3267 SUPPORTS_METHODS : List [METHOD ] = None
@@ -806,14 +841,14 @@ def dequantize_weight(self, num_itr: int = 1):
806841 )
807842
808843 if self .bits in [2 , 4 , 8 ]:
809- zeros = t . bitwise_right_shift (
844+ zeros = _torch_right_shift (
810845 t .unsqueeze (self .qzeros , 2 ).expand (- 1 , - 1 , self .pack_factor ),
811846 self .wf_unsqueeze_zero # self.wf.unsqueeze(0),
812847 ).to (self .dequant_dtype )
813848 zeros = t .bitwise_and (zeros , self .maxq ).reshape (self .scales .shape )
814849
815850 weight = t .bitwise_and (
816- t . bitwise_right_shift (
851+ _torch_right_shift (
817852 t .unsqueeze (self .qweight , 1 ).expand (- 1 , self .pack_factor , - 1 ),
818853 self .wf_unsqueeze_neg_one # self.wf.unsqueeze(-1)
819854 ).to (self .dequant_dtype ),
@@ -823,9 +858,9 @@ def dequantize_weight(self, num_itr: int = 1):
823858 zeros = self .qzeros .reshape (self .qzeros .shape [0 ], self .qzeros .shape [1 ] // 3 , 3 , 1 ).expand (
824859 - 1 , - 1 , - 1 , 12
825860 )
826- zeros = zeros >> self .wf_unsqueeze_zero # self.wf.unsqueeze(0)
827- zeros [:, :, 0 , 10 ] = (zeros [:, :, 0 , 10 ] & 0x3 ) | ((zeros [:, :, 1 , 0 ] << 2 ) & 0x4 )
828- zeros [:, :, 1 , 11 ] = (zeros [:, :, 1 , 11 ] & 0x1 ) | ((zeros [:, :, 2 , 0 ] << 1 ) & 0x6 )
861+ zeros = _torch_right_shift ( zeros , self .wf_unsqueeze_zero ) # self.wf.unsqueeze(0)
862+ zeros [:, :, 0 , 10 ] = (zeros [:, :, 0 , 10 ] & 0x3 ) | (_torch_left_shift (zeros [:, :, 1 , 0 ], 2 ) & 0x4 )
863+ zeros [:, :, 1 , 11 ] = (zeros [:, :, 1 , 11 ] & 0x1 ) | (_torch_left_shift (zeros [:, :, 2 , 0 ], 1 ) & 0x6 )
829864 zeros = zeros & 0x7
830865 zeros = t .cat (
831866 [zeros [:, :, 0 , :11 ], zeros [:, :, 1 , 1 :12 ], zeros [:, :, 2 , 1 :11 ]],
@@ -835,9 +870,9 @@ def dequantize_weight(self, num_itr: int = 1):
835870 weight = self .qweight .reshape (self .qweight .shape [0 ] // 3 , 3 , 1 , self .qweight .shape [1 ]).expand (
836871 - 1 , - 1 , 12 , - 1
837872 )
838- weight = (weight >> self .wf_unsqueeze_neg_one ) & 0x7 # self.wf.unsqueeze(-1)
839- weight [:, 0 , 10 ] = (weight [:, 0 , 10 ] & 0x3 ) | ((weight [:, 1 , 0 ] << 2 ) & 0x4 )
840- weight [:, 1 , 11 ] = (weight [:, 1 , 11 ] & 0x1 ) | ((weight [:, 2 , 0 ] << 1 ) & 0x6 )
873+ weight = _torch_right_shift (weight , self .wf_unsqueeze_neg_one ) & 0x7 # self.wf.unsqueeze(-1)
874+ weight [:, 0 , 10 ] = (weight [:, 0 , 10 ] & 0x3 ) | (_torch_left_shift (weight [:, 1 , 0 ], 2 ) & 0x4 )
875+ weight [:, 1 , 11 ] = (weight [:, 1 , 11 ] & 0x1 ) | (_torch_left_shift (weight [:, 2 , 0 ], 1 ) & 0x6 )
841876 weight = weight & 0x7
842877 weight = t .cat ([weight [:, 0 , :11 ], weight [:, 1 , 1 :12 ], weight [:, 2 , 1 :11 ]], dim = 1 )
843878 weight = weight .reshape (weight .shape [0 ] * weight .shape [1 ], weight .shape [2 ])
0 commit comments