Skip to content

Commit 3ff0371

Browse files
authored
ascend kernel compat update: cann 9.1beta1 (#2906)
* ascend kernel compat update: cann 9.1beta1 * document ascend shift compatibility helpers
1 parent 97a2958 commit 3ff0371

3 files changed

Lines changed: 154 additions & 28 deletions

File tree

gptqmodel/nn_modules/exllamav3_torch.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,40 @@ def _half_scalar_from_bits(bits: int) -> float:
7070
return float(struct.unpack("<e", packed)[0])
7171

7272

73+
# ExLlamaV3's torch fallback decodes packed codebook values with shift
74+
# operations. Keep those shifts behind helpers so Ascend 910B/CANN 9.1 beta can
75+
# use arithmetic equivalents for operators that torch-npu does not expose as
76+
# native device kernels, while CPU/CUDA/XPU keep the standard bitwise ops.
77+
def _torch_shift_factor(shifts: int | torch.Tensor, device: torch.device) -> int | torch.Tensor:
78+
if torch.is_tensor(shifts):
79+
# Tensor shifts must be materialized on the target device; otherwise the
80+
# NPU arithmetic fallback would introduce cross-device operands.
81+
shifts_i64 = shifts.to(device=device, dtype=torch.int64)
82+
return torch.pow(torch.full_like(shifts_i64, 2), shifts_i64)
83+
return 1 << int(shifts)
84+
85+
86+
def _torch_right_shift(values: torch.Tensor, shifts: int | torch.Tensor) -> torch.Tensor:
87+
if values.device.type != "npu":
88+
return torch.bitwise_right_shift(values, shifts)
89+
90+
# CANN 9.1 beta on Ascend 910B may not provide bitwise_right_shift kernels
91+
# for these tensor paths. floor_divide by powers of two preserves arithmetic
92+
# right-shift behavior for signed packed int tensors and stays on-device.
93+
shifted = torch.floor_divide(values.to(torch.int64), _torch_shift_factor(shifts, values.device))
94+
return shifted.to(values.dtype)
95+
96+
97+
def _torch_left_shift(values: torch.Tensor, shifts: int | torch.Tensor) -> torch.Tensor:
98+
if values.device.type != "npu":
99+
return torch.bitwise_left_shift(values, shifts)
100+
101+
# Mirror left shift as multiplication by powers of two on NPU to avoid
102+
# missing-kernel or CPU-fallback paths in torch-npu.
103+
shifted = values.to(torch.int64) * _torch_shift_factor(shifts, values.device)
104+
return shifted.to(values.dtype)
105+
106+
73107
_EXL3_MUL1_INV = _half_scalar_from_bits(0x1EEE)
74108
_EXL3_MUL1_BIAS = _half_scalar_from_bits(0xC931)
75109

@@ -117,7 +151,7 @@ def _codebook_lut(
117151
halves = torch.stack(
118152
(
119153
(raw & 0xFFFF).to(torch.uint16),
120-
((raw >> 16) & 0xFFFF).to(torch.uint16),
154+
(_torch_right_shift(raw, 16) & 0xFFFF).to(torch.uint16),
121155
),
122156
dim=-1,
123157
).contiguous()
@@ -130,7 +164,7 @@ def _codebook_lut(
130164
halves = torch.stack(
131165
(
132166
(raw & 0xFFFF).to(torch.uint16),
133-
((raw >> 16) & 0xFFFF).to(torch.uint16),
167+
(_torch_right_shift(raw, 16) & 0xFFFF).to(torch.uint16),
134168
),
135169
dim=-1,
136170
).contiguous()
@@ -141,9 +175,9 @@ def _codebook_lut(
141175
raw = (values * _EXL3_MUL1_MULT) & 0xFFFFFFFF
142176
byte_sum = (
143177
(raw & 0xFF)
144-
+ ((raw >> 8) & 0xFF)
145-
+ ((raw >> 16) & 0xFF)
146-
+ ((raw >> 24) & 0xFF)
178+
+ (_torch_right_shift(raw, 8) & 0xFF)
179+
+ (_torch_right_shift(raw, 16) & 0xFF)
180+
+ (_torch_right_shift(raw, 24) & 0xFF)
147181
)
148182
accum = (byte_sum + _EXL3_MUL1_ACC).to(torch.uint16).contiguous()
149183
floats = accum.view(torch.float16).to(torch.float32)
@@ -297,23 +331,23 @@ def _unpack_indices(self) -> torch.Tensor:
297331
bit_in_word = bit_offset % 16
298332
if bit_in_word + bits <= 16:
299333
shift = 16 - bit_in_word - bits
300-
value = (words[..., word_idx] >> shift) & mask
334+
value = _torch_right_shift(words[..., word_idx], shift) & mask
301335
else:
302336
bits_first = 16 - bit_in_word
303337
bits_second = bits - bits_first
304-
high = (words[..., word_idx] & ((1 << bits_first) - 1)) << bits_second
305-
low = words[..., word_idx + 1] >> (16 - bits_second)
338+
high = _torch_left_shift(words[..., word_idx] & ((1 << bits_first) - 1), bits_second)
339+
low = _torch_right_shift(words[..., word_idx + 1], 16 - bits_second)
306340
value = (high | low) & mask
307341
symbols[..., pos::16] = value.to(torch.long)
308342

309343
warmup = (16 + bits - 1) // bits - 1
310344
state = torch.zeros_like(symbols[..., 0], dtype=torch.long)
311345
for idx in range(256 - warmup, 256):
312-
state = ((state << bits) | symbols[..., idx]) & 0xFFFF
346+
state = (_torch_left_shift(state, bits) | symbols[..., idx]) & 0xFFFF
313347

314348
encoded = torch.empty_like(symbols)
315349
for idx in range(256):
316-
state = ((state << bits) | symbols[..., idx]) & 0xFFFF
350+
state = (_torch_left_shift(state, bits) | symbols[..., idx]) & 0xFFFF
317351
encoded[..., idx] = state
318352

319353
return encoded

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,41 @@
2727

2828
log = setup_logger()
2929

30+
31+
# Packed quantized weights are unpacked through shift operations in several
32+
# kernels. Keep those shifts behind helpers so Ascend 910B/CANN 9.1 beta can use
33+
# arithmetic equivalents for operators that torch-npu does not expose as native
34+
# device kernels, while CPU/CUDA/XPU keep the standard bitwise ops.
35+
def _torch_shift_factor(shifts: int | t.Tensor, device: t.device) -> int | t.Tensor:
36+
if t.is_tensor(shifts):
37+
# Tensor shifts must be materialized on the target device; otherwise the
38+
# NPU arithmetic fallback would introduce cross-device operands.
39+
shifts_i64 = shifts.to(device=device, dtype=t.int64)
40+
return t.pow(t.full_like(shifts_i64, 2), shifts_i64)
41+
return 1 << int(shifts)
42+
43+
44+
def _torch_right_shift(values: t.Tensor, shifts: int | t.Tensor) -> t.Tensor:
45+
if values.device.type != "npu":
46+
return t.bitwise_right_shift(values, shifts)
47+
48+
# CANN 9.1 beta on Ascend 910B may not provide bitwise_right_shift kernels
49+
# for these tensor paths. floor_divide by powers of two preserves arithmetic
50+
# right-shift behavior for signed packed int tensors and stays on-device.
51+
shifted = t.floor_divide(values.to(t.int64), _torch_shift_factor(shifts, values.device))
52+
return shifted.to(values.dtype)
53+
54+
55+
def _torch_left_shift(values: t.Tensor, shifts: int | t.Tensor) -> t.Tensor:
56+
if values.device.type != "npu":
57+
return t.bitwise_left_shift(values, shifts)
58+
59+
# Mirror left shift as multiplication by powers of two on NPU to avoid
60+
# missing-kernel or CPU-fallback paths in torch-npu.
61+
shifted = values.to(t.int64) * _torch_shift_factor(shifts, values.device)
62+
return shifted.to(values.dtype)
63+
64+
3065
class BaseQuantLinear(nn.Module):
3166
SUPPORTS_BACKENDS: List[BACKEND] = None
3267
SUPPORTS_METHODS: List[METHOD] = None
@@ -806,14 +841,14 @@ def dequantize_weight(self, num_itr: int = 1):
806841
)
807842

808843
if self.bits in [2, 4, 8]:
809-
zeros = t.bitwise_right_shift(
844+
zeros = _torch_right_shift(
810845
t.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor),
811846
self.wf_unsqueeze_zero # self.wf.unsqueeze(0),
812847
).to(self.dequant_dtype)
813848
zeros = t.bitwise_and(zeros, self.maxq).reshape(self.scales.shape)
814849

815850
weight = t.bitwise_and(
816-
t.bitwise_right_shift(
851+
_torch_right_shift(
817852
t.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
818853
self.wf_unsqueeze_neg_one # self.wf.unsqueeze(-1)
819854
).to(self.dequant_dtype),
@@ -823,9 +858,9 @@ def dequantize_weight(self, num_itr: int = 1):
823858
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(
824859
-1, -1, -1, 12
825860
)
826-
zeros = zeros >> self.wf_unsqueeze_zero # self.wf.unsqueeze(0)
827-
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
828-
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
861+
zeros = _torch_right_shift(zeros, self.wf_unsqueeze_zero) # self.wf.unsqueeze(0)
862+
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | (_torch_left_shift(zeros[:, :, 1, 0], 2) & 0x4)
863+
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | (_torch_left_shift(zeros[:, :, 2, 0], 1) & 0x6)
829864
zeros = zeros & 0x7
830865
zeros = t.cat(
831866
[zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]],
@@ -835,9 +870,9 @@ def dequantize_weight(self, num_itr: int = 1):
835870
weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(
836871
-1, -1, 12, -1
837872
)
838-
weight = (weight >> self.wf_unsqueeze_neg_one) & 0x7 # self.wf.unsqueeze(-1)
839-
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
840-
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
873+
weight = _torch_right_shift(weight, self.wf_unsqueeze_neg_one) & 0x7 # self.wf.unsqueeze(-1)
874+
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | (_torch_left_shift(weight[:, 1, 0], 2) & 0x4)
875+
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | (_torch_left_shift(weight[:, 2, 0], 1) & 0x6)
841876
weight = weight & 0x7
842877
weight = t.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
843878
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

gptqmodel/utils/torch.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import contextlib
77
import importlib
8+
import os
89
import time
910
from contextlib import contextmanager
1011
from enum import Enum
@@ -84,9 +85,21 @@ def timed_gc_collect() -> int:
8485
except Exception:
8586
HAS_MPS = False
8687

88+
89+
def _ascend_runtime_env_ready() -> bool:
    """Report whether at least one standard Ascend path variable is set.

    torch_npu may report itself available before the CANN environment has
    been sourced in the current shell. Checking these variables first is
    meant to avoid import-time lazy-initialization failures when the Python
    package is installed but the runtime is not usable.

    Returns:
        ``True`` if any of ``ASCEND_HOME_PATH``, ``ASCEND_TOOLKIT_HOME`` or
        ``ASCEND_OPP_PATH`` has a non-empty value, otherwise ``False``.
    """
    for variable in ("ASCEND_HOME_PATH", "ASCEND_TOOLKIT_HOME", "ASCEND_OPP_PATH"):
        # An unset variable and an empty string both count as "not configured".
        if os.environ.get(variable):
            return True
    return False
98+
99+
87100
try:
88101
importlib.import_module("torch_npu")
89-
HAS_NPU = torch.npu.is_available()
102+
HAS_NPU = _ascend_runtime_env_ready() and torch.npu.is_available()
90103
except Exception:
91104
HAS_NPU = False
92105

@@ -455,20 +468,62 @@ def last_npu_device_by_pci_bus_order() -> Optional[torch.device]:
455468

456469
ALL_DEVICES = torch_devices()
457470

458-
if HAS_CUDA:
459-
ALL_STREAMS = [torch.cuda.Stream(device=device) for device in ALL_DEVICES]
460-
elif HAS_XPU:
461-
ALL_STREAMS = [torch.xpu.Stream(device=device) for device in ALL_DEVICES]
462-
elif HAS_NPU:
463-
ALL_STREAMS = [torch.npu.Stream(device=device) for device in ALL_DEVICES]
464-
else:
465-
ALL_STREAMS = [contextlib.nullcontext()]
471+
472+
class _LazyAcceleratorStreams:
    """List-like collection of per-device streams, each built on first access.

    The collection stands in for a plain list of streams, so indexing follows
    list conventions: negative indices count from the end and slices return a
    list of streams. When no devices are given, the collection behaves like a
    one-element list holding a ``contextlib.nullcontext()``.
    """

    def __init__(self, devices: List[torch.device]):
        self._devices = devices
        # One cache slot per device; ``None`` marks a stream not yet created.
        self._streams: List[Optional[object]] = [None] * len(devices)

    def __len__(self):
        # Never report zero: the no-device case still exposes one null context.
        return max(1, len(self._devices))

    def __iter__(self):
        for index in range(len(self)):
            yield self[index]

    def __getitem__(self, index: "int | slice"):
        """Return the stream at ``index``, creating it if needed.

        Args:
            index: Integer position (negative values count from the end) or
                a slice, which yields a list of the selected streams.

        Raises:
            IndexError: If an integer position is outside ``range(len(self))``.
        """
        if isinstance(index, slice):
            # Resolve each selected position through the lazy path; handing
            # back the raw cache would leak ``None`` placeholders.
            return [self[position] for position in range(*index.indices(len(self)))]

        size = len(self)
        position = index + size if index < 0 else index
        if not 0 <= position < size:
            raise IndexError(index)

        if not self._devices:
            # No accelerator devices: the single entry is a no-op context.
            return contextlib.nullcontext()

        stream = self._streams[position]
        if stream is not None:
            return stream

        device = self._devices[position]
        if device.type == "cuda":
            stream = torch.cuda.Stream(device=device)
        elif device.type == "xpu":
            stream = torch.xpu.Stream(device=device)
        elif device.type == "npu":
            stream = torch.npu.Stream(device=device)
        else:
            stream = contextlib.nullcontext()
        # Cache so repeated lookups return the same stream object.
        self._streams[position] = stream
        return stream
507+
508+
509+
class _LazyAcceleratorStreamRef:
    """Deferred handle to one entry of a lazy stream collection.

    Building the handle only records the collection and the position; the
    collection is not indexed until ``get()`` is called.
    """

    def __init__(self, streams: _LazyAcceleratorStreams, index: int):
        # Store the lookup ingredients only; no stream is touched here.
        self._index = index
        self._streams = streams

    def get(self):
        """Index the stored collection at the stored position and return the result."""
        collection, position = self._streams, self._index
        return collection[position]
518+
519+
520+
ALL_STREAMS = _LazyAcceleratorStreams(ALL_DEVICES)
466521

467522
DEVICE_0 = auto_select_torch_device(index=0)
468523
# device_1 may be same as device_0 if there is only 1 visible/active device
469524
DEVICE_1 = auto_select_torch_device(index=1)
470525

471-
DEVICE_0_STREAM = ALL_STREAMS[0]
526+
DEVICE_0_STREAM = _LazyAcceleratorStreamRef(ALL_STREAMS, 0)
472527

473528
NEXT_DEVICE_INDEX = 0
474529

@@ -494,6 +549,8 @@ def last_npu_device_by_pci_bus_order() -> Optional[torch.device]:
494549
# return device
495550

496551
def torch_streamCtx(stream) -> StreamContext:
552+
if isinstance(stream, _LazyAcceleratorStreamRef):
553+
stream = stream.get()
497554
if HAS_CUDA:
498555
return torch.cuda.stream(stream)
499556
if HAS_XPU:

0 commit comments

Comments
 (0)