Commit 264bdc4

Add Ascend NPU support (#2831)
* Add Ascend NPU support
* Document Ascend NPU CUDA API equivalents
* Validate torch NPU qlinear parity
* Add NPU-native GPTQ Hessian inverse
* Add torch QQQ backend for NPU
* Prepare README for 7.0.0 release
* Address NPU quality review comments
* Clean up 7.0 README release note
* Fix device normalization in kernel auto selection
1 parent aeb641f commit 264bdc4

36 files changed

Lines changed: 1593 additions & 151 deletions

README.md

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
</div>
55
<h1 align="center">GPT-QModel</h1>
66
</p>
7-
<p align="center">LLM model quantization (compression) toolkit with hw acceleration support for NVIDIA CUDA, AMD ROCm, Intel XPU, and Intel/AMD/Apple CPUs via HF, vLLM, and SGLang.</p>
7+
<p align="center">LLM model quantization (compression) toolkit with hw acceleration support for NVIDIA CUDA, AMD ROCm, Huawei Ascend NPU, Intel XPU, and Intel/AMD/Apple CPUs via HF, vLLM, and SGLang.</p>
88
<p align="center">
99
<a href="https://github.com/ModelCloud/GPTQModel/releases" style="text-decoration:none;"><img alt="GitHub release" src="https://img.shields.io/github/release/ModelCloud/GPTQModel.svg"></a>
1010
<a href="https://pypi.org/project/gptqmodel/" style="text-decoration:none;"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/gptqmodel"></a>
@@ -20,8 +20,7 @@
2020
</p>
2121

2222
## Latest News
23-
* 04/27/2026 6.1.0-dev `main`: ✨ Added `internvl_chat` model support.
24-
* 04/23/2026 6.1.0-dev `main`: ✨ Added `gemma3n``GLM-OCR``GLM-ASR` and `falcon_mamba` model support.
23+
* 04/28/2026 [7.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v7.0.0): 🚀 Added Huawei Ascend NPU support through native torch kernels for GPTQ, AWQ, ParoQuant, GGUF, QQQ, and EXL3. Added `internvl_chat`, `gemma3n`, `GLM-OCR`, `GLM-ASR`, and `falcon_mamba` model support.
2524
* 04/16/2026 [6.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.1.0): 🚀🔥⚡ CUDA kernels are now fully JIT-compiled, shrinking the wheel by about 300x and building only what you use; Marlin now supports NVIDIA `Turing+` GPUs, Machete kernel validation now covers supported GPUs, `GLM 5/5.1` joins the lineup, and LazyTurtle plus AWQ / multi-GPU MoE fixes make large-model quantization easier, lighter, and smoother.
2625
* 04/03/2026 [6.0.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.3): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MoE Lite` model support.
2726
* 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ.
@@ -191,14 +190,14 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat
191190

192191
### Quant Method / Format / Backend Matrix
193192

194-
Canonical backend names are shown below. Legacy aliases such as `BACKEND.TORCH`, `BACKEND.MARLIN`, `BACKEND.GEMM`, and `BACKEND.PARO` are still accepted and normalized to the matching canonical backend for the selected quant method.
193+
Canonical backend names are shown below. Method-specific aliases are only accepted where explicitly implemented by that quant method.
195194

196195
| Quant Method | Formats | Backends / Kernels |
197196
| --- | --- | --- |
198197
| `METHOD.GPTQ` | `FORMAT.GPTQ`, `FORMAT.GPTQ_V2`, `FORMAT.MARLIN`, `FORMAT.BITBLAS` | `FORMAT.GPTQ`: `BACKEND.GPTQ_TORCH_ATEN`, `BACKEND.GPTQ_MACHETE`, `BACKEND.GPTQ_MARLIN`, `BACKEND.GPTQ_EXLLAMA_V2`, `BACKEND.GPTQ_TORCH_FUSED`, `BACKEND.GPTQ_TRITON`, `BACKEND.GPTQ_BITBLAS`, `BACKEND.GPTQ_TORCH`, `BACKEND.GPTQ_TORCH_INT8`<br>`FORMAT.GPTQ_V2`: `BACKEND.GPTQ_TORCH_ATEN`, `BACKEND.GPTQ_EXLLAMA_V2`, `BACKEND.GPTQ_TORCH_FUSED`, `BACKEND.GPTQ_TRITON`, `BACKEND.GPTQ_BITBLAS`, `BACKEND.GPTQ_TORCH`, `BACKEND.GPTQ_TORCH_INT8`<br>`FORMAT.MARLIN`: `BACKEND.GPTQ_MARLIN`<br>`FORMAT.BITBLAS`: `BACKEND.GPTQ_BITBLAS` |
199198
| `METHOD.AWQ` | `FORMAT.GEMM`, `FORMAT.GEMV`, `FORMAT.GEMV_FAST`, `FORMAT.LLM_AWQ`, `FORMAT.MARLIN`, `FORMAT.BITBLAS` | `FORMAT.GEMM`: `BACKEND.AWQ_TORCH_ATEN`, `BACKEND.AWQ_MACHETE`, `BACKEND.AWQ_MARLIN`, `BACKEND.AWQ_EXLLAMA_V2`, `BACKEND.AWQ_GEMM`, `BACKEND.AWQ_GEMM_TRITON`, `BACKEND.AWQ_TORCH_FUSED`, `BACKEND.AWQ_TORCH`, `BACKEND.AWQ_TORCH_INT8`, `BACKEND.AWQ_BITBLAS`<br>`FORMAT.GEMV`: `BACKEND.AWQ_GEMV`<br>`FORMAT.GEMV_FAST`: `BACKEND.AWQ_GEMV_FAST`<br>`FORMAT.LLM_AWQ`: `BACKEND.AWQ_GEMV_FAST`<br>`FORMAT.MARLIN`: `BACKEND.AWQ_MACHETE`, `BACKEND.AWQ_MARLIN`<br>`FORMAT.BITBLAS`: `BACKEND.AWQ_BITBLAS` |
200199
| `METHOD.PARO` | `FORMAT.PAROQUANT` | `BACKEND.PAROQUANT_CUDA`, `BACKEND.PAROQUANT_TRITON` |
201-
| `METHOD.QQQ` | `FORMAT.QQQ` | `BACKEND.QQQ` |
200+
| `METHOD.QQQ` | `FORMAT.QQQ` | `BACKEND.QQQ`, `BACKEND.QQQ_TORCH` |
202201
| `METHOD.GGUF` | `FORMAT.GGUF` | `BACKEND.GGUF_TRITON`, `BACKEND.GGUF_CPP_CUDA`, `BACKEND.GGUF_CPP_CPU`, `BACKEND.GGUF_TORCH` |
203202
| `METHOD.FP8` | `FORMAT.FP8` | `BACKEND.FP8_TORCH` |
204203
| `METHOD.BITSANDBYTES` | `FORMAT.BITSANDBYTES` | `BACKEND.BITSANDBYTES` |
@@ -216,7 +215,7 @@ Marlin uses `GPTQMODEL_MARLIN_USE_FP32` (default: enabled) to control fp32 accum
216215
* 🚀 Quantize MoE models with ease even with extreme routing activation bias via `Moe.Routing` and/or `FailSafe`.
217216
* 🚀 Data Parallelism for 80%+ quantization speed reduction with Multi-GPU.
218217
* 🚀 Optimized for Python >= 3.13t (free threading) with lock-free threading.
219-
* ✨ Linux, macOS, Windows platform support for CUDA (NVIDIA), XPU (Intel), ROCm (AMD), MPS (Apple Silicon), CPU (Intel/AMD/Apple Silicon).
218+
* ✨ Linux, macOS, Windows platform support for CUDA (NVIDIA), NPU (Huawei Ascend), XPU (Intel), ROCm (AMD), MPS (Apple Silicon), CPU (Intel/AMD/Apple Silicon).
220219
*`Dynamic` per-module mixed quantization control: each layer/module can have a unique quantization config or be excluded from quantization.
221220
* 🚀 Intel Torch 2.8 fused kernel support for XPU [`Arc` + `Datacenter Max`] and CPU [`avx`, `amx`].
222221
* 🚀 Python 3.13.3t (free-threading, GIL disabled) support for multi-GPU accelerated quantization for MoE models and multi-core CPU boost for packing.

gptqmodel/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -173,6 +173,7 @@ def _build_device_thread_pool():
173173
workers={
174174
"cuda:per": 4,
175175
"xpu:per": 1,
176+
"npu:per": 1,
176177
"mps": 8,
177178
"cpu": min(12, max(1, (os.cpu_count() or 1) + 1 // 2)), # count + 1, fixed pool size > 1 check when count=3
178179
"model_loader:cpu": 2,

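A note on the `"cpu"` worker expression in the context lines above: in Python, `//` binds tighter than `+`, so `(os.cpu_count() or 1) + 1 // 2` evaluates as `count + 0`. The sketch below compares the expression as written with a parenthesized variant. The function names are hypothetical, and the idea that half the core count (rounded up) was intended is only an assumption drawn from the inline comment, not something the diff states.

```python
def workers_as_written(count):
    # Mirrors the expression in the diff: `1 // 2` is evaluated first and equals 0,
    # so the result is simply the core count, clamped to the range [1, 12].
    return min(12, max(1, (count or 1) + 1 // 2))


def workers_half_rounded_up(count):
    # Hypothetical alternative: add 1 first, then floor-divide by 2,
    # which yields half the core count rounded up, clamped to [1, 12].
    return min(12, max(1, ((count or 1) + 1) // 2))


if __name__ == "__main__":
    for count in (None, 1, 3, 8, 64):
        print(count, workers_as_written(count), workers_half_rounded_up(count))
```

The two forms agree only for a count of one (or an unknown count); for three cores they return 3 and 2 respectively.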
gptqmodel/looper/forward_executor.py

Lines changed: 1 addition & 1 deletion
@@ -531,7 +531,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
531531
replica = module_replicas[device]
532532
submitter = (
533533
device_thread_pool.submit_serial
534-
if device.type in ("cuda", "xpu", "mps")
534+
if device.type in ("cuda", "xpu", "npu", "mps")
535535
else device_thread_pool.submit
536536
)
537537

gptqmodel/looper/gptq_processor.py

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,7 @@
2424
from ..utils.device import get_device
2525
from ..utils.model import create_quant_module, find_modules, pack_module
2626
from ..utils.module_locks import parent_module_lock
27+
from ..utils.torch import HAS_NPU
2728

2829
log = setup_logger()
2930
lock = threading.Lock()
@@ -276,6 +277,12 @@ def process(
276277
f"CUDA thread context {current_cuda_device} does not match expected device {expected_device} "
277278
f"while processing '{module.full_name}'."
278279
)
280+
if expected_device.type == "npu" and HAS_NPU:
281+
current_npu_device = torch.device("npu", torch.npu.current_device())
282+
assert current_npu_device == expected_device, (
283+
f"NPU thread context {current_npu_device} does not match expected device {expected_device} "
284+
f"while processing '{module.full_name}'."
285+
)
279286

280287
wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()
281288

gptqmodel/looper/loop_processor.py

Lines changed: 8 additions & 1 deletion
@@ -26,7 +26,7 @@
2626
from ..quantization.config import QuantizeConfig
2727
from ..utils.colors import ANSIColor, color_text
2828
from ..utils.logger import setup_logger
29-
from ..utils.torch import CPU, DEVICE_0, DEVICE_1
29+
from ..utils.torch import CPU, DEVICE_0, DEVICE_1, HAS_NPU
3030

3131
log = setup_logger()
3232

@@ -524,6 +524,13 @@ def _discover_accelerator_devices(self) -> List[str]:
524524
except Exception: # pragma: no cover - defensive, XPU runtime differences
525525
pass
526526

527+
if HAS_NPU:
528+
try:
529+
for idx in range(torch.npu.device_count()):
530+
devices.append(f"npu:{idx}")
531+
except Exception: # pragma: no cover - defensive, NPU runtime differences
532+
pass
533+
527534
return devices
528535

529536
def _safe_query_metric(self, device_key: str, handle: Device):

gptqmodel/looper/qqq_processor.py

Lines changed: 18 additions & 4 deletions
@@ -12,12 +12,14 @@
1212
from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor
1313
from ..looper.named_module import NamedModule
1414
from ..models import BaseQModel
15+
from ..models._const import DEVICE
1516
from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
1617
PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
17-
from ..nn_modules.qlinear.qqq import QQQLinear
18+
from ..nn_modules.qlinear.qqq import QQQLinear, QQQTorchLinear
1819
from ..quantization.config import METHOD, QuantizeConfig, resolve_quant_format
1920
from ..utils.fallback import normalize_fallback
2021
from ..quantization.qqq import QQQ
22+
from ..utils.backend import BACKEND
2123
from ..utils.logger import setup_logger, log_time_block
2224
from ..utils.model import create_quant_module, find_modules, move_to, pack_module
2325
from ..utils.torch import CPU
@@ -57,6 +59,16 @@ def __init__(
5759
self.calculate_w_wq_diff = calculate_w_wq_diff
5860
self.avg_losses = []
5961

62+
def _quant_linear_kernel(self):
63+
device = self.qcfg.device
64+
if isinstance(device, DEVICE):
65+
return (QQQTorchLinear, BACKEND.QQQ_TORCH) if device == DEVICE.NPU else (QQQLinear, BACKEND.QQQ)
66+
if isinstance(device, torch.device):
67+
return (QQQTorchLinear, BACKEND.QQQ_TORCH) if device.type == "npu" else (QQQLinear, BACKEND.QQQ)
68+
if isinstance(device, str) and device.split(":")[0].lower() == "npu":
69+
return QQQTorchLinear, BACKEND.QQQ_TORCH
70+
return QQQLinear, BACKEND.QQQ
71+
6072
def set_calibration_dataset(self, calibration_dataset):
6173
"""Rejects dataset replacement because QQQ capture is fixed at construction."""
6274

@@ -252,6 +264,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
252264

253265
layers = find_modules(model.model)
254266
module_label = getattr(module, "full_name", getattr(module, "name", ""))
267+
quant_linear_cls, backend = self._quant_linear_kernel()
255268

256269
# replace module with quantized module
257270
with log_time_block(
@@ -261,7 +274,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
261274
):
262275
create_quant_module(
263276
name=module.full_name,
264-
linear_cls=QQQLinear,
277+
linear_cls=quant_linear_cls,
265278
bits=self.qcfg.runtime_bits,
266279
desc_act=self.qcfg.desc_act,
267280
dynamic=self.qcfg.dynamic,
@@ -273,13 +286,14 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
273286
lm_head_name=model.lm_head,
274287
pack_dtype=self.qcfg.pack_dtype,
275288
format=resolve_quant_format(self.qcfg.format, self.qcfg.method),
289+
backend=backend,
276290
register_buffers=False,
277291
)
278292

279293
# pack module
280294
qModules = {
281295
name: submodule
282-
for name, submodule in find_modules(model.model, [QQQLinear]).items()
296+
for name, submodule in find_modules(model.model, [quant_linear_cls]).items()
283297
if name == module.full_name
284298
}
285299
with log_time_block(
@@ -294,7 +308,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
294308
q_zeros=q_zeros,
295309
q_g_idx=q_g_idx,
296310
layers=layers,
297-
quant_linear_cls=QQQLinear,
311+
quant_linear_cls=quant_linear_cls,
298312
lock=self.lock,
299313
q_scales_extra=q_scales_extra,
300314
quantize_config=self.qcfg,
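The `_quant_linear_kernel` helper in this file accepts the configured device as a `DEVICE` enum member, a `torch.device`, or a plain string, and returns the torch QQQ kernel only for `npu`. The following is a minimal, torch-free sketch of that normalization. `Device`, `device_type`, and `pick_qqq_backend` are hypothetical stand-ins for illustration, and the backend is returned as a plain string rather than a `BACKEND` member.

```python
from enum import Enum
from types import SimpleNamespace


class Device(str, Enum):
    # Hypothetical stand-in for gptqmodel's DEVICE enum.
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"


def device_type(spec) -> str:
    """Reduce an enum member, a 'type:index' string, or an object with `.type` to a lowercase type."""
    # Check the enum first: a str-based Enum is also an instance of str.
    if isinstance(spec, Device):
        return spec.value
    if isinstance(spec, str):
        return spec.split(":")[0].lower()
    # Covers torch.device-like objects; anything else falls back to its string form.
    return str(getattr(spec, "type", spec)).split(":")[0].lower()


def pick_qqq_backend(spec) -> str:
    # NPU gets the portable torch kernel; every other device keeps the default kernel.
    return "QQQ_TORCH" if device_type(spec) == "npu" else "QQQ"


if __name__ == "__main__":
    torch_like = SimpleNamespace(type="npu", index=0)  # stands in for torch.device("npu", 0)
    for spec in (Device.NPU, "NPU:1", torch_like, "cuda:0", None):
        print(repr(spec), "->", pick_qqq_backend(spec))
```

Reducing every input to a single lowercase type string up front keeps the selection itself to one comparison, which is roughly what the `quant_device_type` variable in `gptqmodel/models/base.py` does in this commit.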

gptqmodel/models/_const.py

Lines changed: 16 additions & 2 deletions
@@ -13,7 +13,14 @@
1313

1414
from ..utils import BACKEND
1515
from ..utils.rocm import IS_ROCM
16-
from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU
16+
from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_NPU, HAS_XPU
17+
18+
19+
def _optional_torch_device(spec: str, fallback: torch.device) -> torch.device:
20+
try:
21+
return device(spec)
22+
except (RuntimeError, ValueError):
23+
return fallback
1724

1825

1926
CPU = device("cpu")
@@ -22,6 +29,8 @@
2229
CUDA_0 = device("cuda:0")
2330
XPU = device("xpu")
2431
XPU_0 = device("xpu:0")
32+
NPU = _optional_torch_device("npu", CPU)
33+
NPU_0 = _optional_torch_device("npu:0", CPU)
2534
MPS = device("mps")
2635
ROCM = device("cuda:0") # rocm maps to fake cuda
2736

@@ -34,6 +43,7 @@ class DEVICE(str, Enum):
3443
CPU = "cpu" # All CPU: Optimized for IPEX is CPU has AVX, AVX512, AMX, or XMX instructions
3544
CUDA = "cuda" # Nvidia GPU: Optimized for Ampere+
3645
XPU = "xpu" # Intel GPU: Datacenter Max + Arc
46+
NPU = "npu" # Ascend NPU: portable Torch kernels
3747
MPS = "mps" # MacOS GPU: Apple Silicon/Metal)
3848
ROCM = "rocm" # AMD GPU: ROCm maps to fake cuda
3949

@@ -54,7 +64,7 @@ def type(self) -> str:
5464
@property
5565
def index(self) -> int | None:
5666
"""Default index used when materialising a torch.device from this enum."""
57-
if self in (DEVICE.CUDA, DEVICE.ROCM, DEVICE.XPU):
67+
if self in (DEVICE.CUDA, DEVICE.ROCM, DEVICE.XPU, DEVICE.NPU):
5868
return 0
5969
return None
6070

@@ -96,6 +106,8 @@ def normalize_device(type_value: str | DEVICE | int | torch.device) -> DEVICE:
96106
return DEVICE.CUDA
97107
elif HAS_XPU:
98108
return DEVICE.XPU
109+
elif HAS_NPU:
110+
return DEVICE.NPU
99111
elif HAS_MPS:
100112
return DEVICE.MPS
101113
else:
@@ -123,6 +135,8 @@ def get_best_device(backend: BACKEND = BACKEND.AUTO) -> torch.device:
123135
return CUDA_0
124136
elif HAS_XPU:
125137
return XPU_0
138+
elif HAS_NPU:
139+
return NPU_0
126140
elif HAS_MPS:
127141
return MPS
128142
else:
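The `_optional_torch_device` helper added in this file wraps device construction in a `try`/`except` so that importing the module does not fail when the `npu` device type cannot be constructed. A torch-free sketch of the pattern follows. `fake_device` is a hypothetical stand-in for `torch.device`, and the set of known device types is invented for the example.

```python
KNOWN_TYPES = {"cpu", "cuda"}  # assumption: device types this toy "runtime" understands


def fake_device(spec: str) -> str:
    # Hypothetical stand-in for torch.device: rejects unknown device types.
    kind = spec.split(":")[0]
    if kind not in KNOWN_TYPES:
        raise RuntimeError(f"unknown device type: {kind}")
    return spec


def optional_device(make_device, spec: str, fallback):
    """Build a device from `spec`, returning `fallback` if construction is rejected."""
    try:
        return make_device(spec)
    except (RuntimeError, ValueError):
        return fallback


if __name__ == "__main__":
    cpu = fake_device("cpu")
    print(optional_device(fake_device, "cuda:0", cpu))
    print(optional_device(fake_device, "npu:0", cpu))
```

One consequence of this design is that the fallback constant compares equal to the CPU device when construction fails, so callers still need an availability flag such as `HAS_NPU` before treating it as a real accelerator.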

gptqmodel/models/base.py

Lines changed: 10 additions & 2 deletions
@@ -772,12 +772,20 @@ def quantize(
772772

773773
preferred_backend = requested_backend
774774
if preferred_backend in (None, BACKEND.AUTO):
775+
quant_device = self.quantize_config.device
776+
if isinstance(quant_device, DEVICE):
777+
quant_device_type = quant_device.type
778+
elif isinstance(quant_device, torch.device):
779+
quant_device_type = quant_device.type
780+
else:
781+
quant_device_type = str(quant_device).split(":")[0].lower()
782+
775783
if export_quant_method == METHOD.AWQ:
776784
if format_code == FORMAT.GEMM:
777785
# Weight-only RTN->AWQ export should stay on the portable torch kernel.
778786
preferred_backend = (
779787
BACKEND.AWQ_TORCH
780-
if self.quantize_config.uses_weight_only_lifecycle()
788+
if self.quantize_config.uses_weight_only_lifecycle() or quant_device_type == "npu"
781789
else BACKEND.AWQ_GEMM
782790
)
783791
elif format_code == FORMAT.BITBLAS:
@@ -789,7 +797,7 @@ def quantize(
789797
else:
790798
raise ValueError(f"Unsupported FORMAT: `{self.quantize_config.format}` with `METHOD.AWQ`")
791799
elif self.quantize_config.method == METHOD.QQQ:
792-
preferred_backend = BACKEND.QQQ
800+
preferred_backend = BACKEND.QQQ_TORCH if quant_device_type == "npu" else BACKEND.QQQ
793801
elif self.quantize_config.method == METHOD.PARO:
794802
preferred_backend = BACKEND.PAROQUANT_CUDA
795803
elif self.quantize_config.method == METHOD.EXL3:

gptqmodel/models/loader.py

Lines changed: 8 additions & 1 deletion
@@ -75,7 +75,7 @@
7575
make_quant,
7676
simple_dispatch_model,
7777
)
78-
from ._const import DEVICE, normalize_device
78+
from ._const import DEVICE, HAS_NPU, normalize_device
7979

8080

8181
log = setup_logger()
@@ -1161,6 +1161,11 @@ def build_layerwise_device_map(
11611161
device_strs = [f"xpu:{i}" for i in range(num_gpus)]
11621162
else:
11631163
raise RuntimeError("XPU is not available")
1164+
elif device == DEVICE.NPU:
1165+
if HAS_NPU:
1166+
device_strs = [f"npu:{i}" for i in range(num_gpus)]
1167+
else:
1168+
raise RuntimeError("NPU is not available")
11641169
else:
11651170
device_strs = ["cpu"] * num_gpus
11661171

@@ -1311,6 +1316,8 @@ def assign(mod, device_id):
13111316
num_gpus = torch.cuda.device_count()
13121317
elif device is DEVICE.XPU:
13131318
num_gpus = torch.xpu.device_count()
1319+
elif device is DEVICE.NPU:
1320+
num_gpus = torch.npu.device_count()
13141321
device_map = build_layerwise_device_map(model, device, layers, ignore_modules, num_gpus)
13151322
else:
13161323
device_map = dict(explicit_device_map)

gptqmodel/nn_modules/exllamav3_torch.py

Lines changed: 30 additions & 19 deletions
@@ -26,6 +26,33 @@
2626
_EXL3_LOP3_BIAS = 0x3B603B60
2727

2828

29+
def _tensor_core_perm_values() -> list[int]:
30+
perm = [0] * 256
31+
for t in range(32):
32+
r0 = (t % 4) * 2
33+
r1 = r0 + 1
34+
r2 = r0 + 8
35+
r3 = r0 + 9
36+
c0 = t // 4
37+
c1 = c0 + 8
38+
perm[t * 8 + 0] = r0 * 16 + c0
39+
perm[t * 8 + 1] = r1 * 16 + c0
40+
perm[t * 8 + 2] = r2 * 16 + c0
41+
perm[t * 8 + 3] = r3 * 16 + c0
42+
perm[t * 8 + 4] = r0 * 16 + c1
43+
perm[t * 8 + 5] = r1 * 16 + c1
44+
perm[t * 8 + 6] = r2 * 16 + c1
45+
perm[t * 8 + 7] = r3 * 16 + c1
46+
return perm
47+
48+
49+
def _inverse_perm_values(perm: list[int]) -> list[int]:
50+
inv = [0] * len(perm)
51+
for index, value in enumerate(perm):
52+
inv[value] = index
53+
return inv
54+
55+
2956
def _half_scalar_from_bits(bits: int) -> float:
3057
# Convert a uint16 bit pattern to its IEEE-754 binary16 (float16) value
3158
# without allocating a torch tensor. The previous implementation used
@@ -50,29 +77,13 @@ def _half_scalar_from_bits(bits: int) -> float:
5077
@lru_cache(maxsize=None)
5178
def _tensor_core_perm(device_type: str, device_index: int | None) -> torch.Tensor:
5279
device = torch.device(device_type, device_index)
53-
perm = [0] * 256
54-
for t in range(32):
55-
r0 = (t % 4) * 2
56-
r1 = r0 + 1
57-
r2 = r0 + 8
58-
r3 = r0 + 9
59-
c0 = t // 4
60-
c1 = c0 + 8
61-
perm[t * 8 + 0] = r0 * 16 + c0
62-
perm[t * 8 + 1] = r1 * 16 + c0
63-
perm[t * 8 + 2] = r2 * 16 + c0
64-
perm[t * 8 + 3] = r3 * 16 + c0
65-
perm[t * 8 + 4] = r0 * 16 + c1
66-
perm[t * 8 + 5] = r1 * 16 + c1
67-
perm[t * 8 + 6] = r2 * 16 + c1
68-
perm[t * 8 + 7] = r3 * 16 + c1
69-
return torch.tensor(perm, dtype=torch.long, device=device)
80+
return torch.tensor(_tensor_core_perm_values(), dtype=torch.long, device=device)
7081

7182

7283
@lru_cache(maxsize=None)
7384
def _tensor_core_perm_i(device_type: str, device_index: int | None) -> torch.Tensor:
74-
perm = _tensor_core_perm(device_type, device_index)
75-
return torch.argsort(perm)
85+
device = torch.device(device_type, device_index)
86+
return torch.tensor(_inverse_perm_values(_tensor_core_perm_values()), dtype=torch.long, device=device)
7687

7788

7889
@lru_cache(maxsize=None)
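In this file the inverse permutation is now built on the host with `_inverse_perm_values` instead of calling `torch.argsort` on the device tensor. For a true permutation of `0..n-1` the two give the same result, which the pure-Python sketch below checks. The function names are local to the sketch, and the nested loops are a compact restatement of the eight explicit assignments in the diff, not the repository's code.

```python
def tensor_core_perm_values() -> list[int]:
    # Same index pattern as the diff: 32 threads, 8 entries each, over a 16x16 tile.
    perm = [0] * 256
    for t in range(32):
        r0 = (t % 4) * 2
        rows = (r0, r0 + 1, r0 + 8, r0 + 9)
        c0 = t // 4
        for j, col in enumerate((c0, c0 + 8)):
            for i, row in enumerate(rows):
                perm[t * 8 + j * 4 + i] = row * 16 + col
    return perm


def inverse_perm_values(perm: list[int]) -> list[int]:
    # inv[value] = index, i.e. the position at which each value occurs in perm.
    inv = [0] * len(perm)
    for index, value in enumerate(perm):
        inv[value] = index
    return inv


if __name__ == "__main__":
    perm = tensor_core_perm_values()
    inv = inverse_perm_values(perm)
    # Pure-Python argsort: indices ordered by the value they hold in perm.
    argsort = sorted(range(len(perm)), key=perm.__getitem__)
    print(sorted(perm) == list(range(256)), inv == argsort)
```

The equivalence relies on `perm` containing every value from 0 to 255 exactly once; with duplicates, the scatter-style inverse would silently overwrite entries while `argsort` would not.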
