Skip to content

Commit b777a05

Browse files
author
mnachin
committed
INT4 plain matmul: dp4a decode kernel + dequant dispatch
Adds executorch_cuda::int4_plain_mm custom op that reads Int4Tensor's plain [N, K//2] nibble-packed format directly. C shim (.pte runtime): W4A8 dp4a matvec with dynamic INT8 activation quantization, 16-byte vectorized loads, warp-cooperative quantization. No cuBLAS dependency. Eager dispatch: M<=4 routes through the custom op (dp4a in .pte, dequant + F.linear in eager). M>4 uses inline dequant + F.linear, which AOTI compiles into the .so using inductor's own cuBLAS codegen.
1 parent 78886eb commit b777a05

18 files changed

Lines changed: 1402 additions & 257 deletions

backends/cuda/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
110110
# Only build CUDA shims when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112112
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113-
runtime/shims/sort.cu runtime/shims/rand.cu
113+
runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu
114+
runtime/shims/rand.cu
114115
)
115116
endif()
116117

backends/cuda/cuda_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
226226
"at::_ops::_weight_int4pack_mm::call": None,
227227
"at::_ops::sort_stable::call": None,
228228
"aoti_torch_cuda_randint_low_out": None,
229+
"executorch_cuda::int4_plain_mm": None,
230+
"aoti_torch_cuda_int4_plain_mm": None,
229231
}
230232

231233
@classmethod
@@ -298,6 +300,20 @@ def get_aoti_compile_options(
298300
"aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel,
299301
}
300302

303+
try:
304+
import torch
305+
306+
options["aot_inductor.custom_ops_to_c_shims"] = {
307+
torch.ops.executorch_cuda.int4_plain_mm.default: [
308+
"AOTITorchError aoti_torch_cuda_int4_plain_mm("
309+
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
310+
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
311+
],
312+
}
313+
except AttributeError:
314+
# int4_dispatch.py not imported — op not registered, skip C shim mapping
315+
pass
316+
301317
# Parse compile_specs to check for platform
302318

303319
platform = "linux"

backends/cuda/int4_dispatch.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Int4Tensor F.linear dispatch for CUDA.
8+
9+
Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` — in eager this
10+
dequants + calls F.linear; in .pte runtime the C shim runs a
11+
W4A8 dp4a matvec kernel.
12+
Prefill (M>4): Inline dequant + F.linear — AOTI compiles this into the .so
13+
using inductor's own cuBLAS codegen, so no explicit cuBLAS
14+
dependency in our shim library.
15+
16+
Import this module before using nn.Linear with Int4Tensor weights::
17+
18+
import executorch.backends.cuda.int4_dispatch # noqa: F401
19+
"""
20+
21+
import torch
22+
import torch.nn.functional as F
23+
from torch.library import impl, Library
24+
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
25+
26+
# ---------------------------------------------------------------------------
27+
# Custom op for decode (M<=4): dp4a matvec in C shim, dequant+F.linear in eager
28+
# ---------------------------------------------------------------------------
29+
30+
_lib = Library("executorch_cuda", "DEF")
31+
_lib.define(
32+
"int4_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
33+
)
34+
35+
36+
@impl(_lib, "int4_plain_mm", "Meta")
def _meta(self, qdata, scale, zero, group_size):
    """Meta (shape-only) kernel for ``executorch_cuda::int4_plain_mm``.

    Produces an uninitialized ``[self.shape[0], qdata.shape[0]]`` tensor that
    inherits the activation's dtype and device. ``scale``, ``zero`` and
    ``group_size`` do not influence the output shape and are ignored here.
    """
    rows = self.shape[0]
    cols = qdata.shape[0]
    return torch.empty(rows, cols, dtype=self.dtype, device=self.device)
41+
42+
43+
@impl(_lib, "int4_plain_mm", "CUDA")
def _cuda(self, qdata, scale, zero, group_size):
    """Eager CUDA kernel for ``executorch_cuda::int4_plain_mm``.

    Delegates to ``_dequant_matmul``, i.e. dequantizes the packed weights and
    runs ``F.linear`` on the result.
    """
    result = _dequant_matmul(self, qdata, scale, zero, group_size)
    return result
46+
47+
48+
def _dequant_matmul(x, qdata, scale, zero, group_size):
49+
"""Dequant INT4 weights to input dtype and call F.linear."""
50+
N, K_half = qdata.shape
51+
K = K_half * 2
52+
n_groups = K // group_size
53+
gs_half = group_size // 2
54+
dtype = x.dtype
55+
56+
p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
57+
low = (p & 0x0F).to(dtype)
58+
high = ((p >> 4) & 0x0F).to(dtype)
59+
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)
60+
61+
s = scale.to(dtype).t().unsqueeze(-1)
62+
z = zero.to(dtype).t().unsqueeze(-1)
63+
w_deq = ((data - z) * s).reshape(N, K)
64+
65+
return F.linear(x, w_deq)
66+
67+
68+
# ---------------------------------------------------------------------------
69+
# Int4Tensor F.linear dispatch
70+
# ---------------------------------------------------------------------------
71+
72+
aten = torch.ops.aten
73+
_implements = Int4Tensor.implements
74+
_implements_torch_function = Int4Tensor.implements_torch_function
75+
76+
77+
@_implements([aten.linear.default])
@_implements_torch_function([F.linear])
def _(func, types, args, kwargs):
    """Route ``linear`` calls whose weight is an ``Int4Tensor``.

    The activation is flattened to 2-D ``[M, K]``. For ``M <= 4`` the call goes
    through the ``executorch_cuda::int4_plain_mm`` custom op; larger ``M`` uses
    the inline dequant + ``F.linear`` path. The result is reshaped back to the
    activation's leading dimensions and the bias, if any, is added.

    Args:
        func: The overridden callable (unused).
        types: Tensor subclass types involved in the call (unused).
        args: Positional arguments ``(input, weight[, bias])``.
        kwargs: Keyword arguments; ``bias`` is honoured when it is not given
            positionally.

    Returns:
        Tensor shaped ``input.shape[:-1] + (N,)``.
    """
    named = kwargs or {}
    input_tensor = args[0]
    weight_tensor = args[1]
    # F.linear(x, w, bias=b) delivers the bias via kwargs; do not drop it.
    bias = args[2] if len(args) > 2 else named.get("bias")

    orig_shape = input_tensor.shape
    x_2d = input_tensor.reshape(-1, orig_shape[-1])

    qdata = weight_tensor.qdata
    scale = weight_tensor.scale
    zero = weight_tensor.zero_point
    gs = weight_tensor.block_size[-1]

    M = x_2d.shape[0]
    if M <= 4:
        out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs)
    else:
        out = _dequant_matmul(x_2d, qdata, scale, zero, gs)

    # Spell out the last dim rather than using -1: reshape cannot infer -1
    # when the tensor has zero elements (e.g. an empty batch).
    out = out.reshape(*orig_shape[:-1], out.shape[-1])
    if bias is not None:
        out = out + bias
    return out
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cuda.h>
10+
#include <cuda_runtime.h>
11+
12+
#include <executorch/backends/aoti/utils.h>
13+
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.h>
14+
#include <executorch/backends/cuda/runtime/shims/int4_plain_mm.cuh>
15+
#include <executorch/backends/cuda/runtime/shims/memory.h>
16+
#include <executorch/runtime/platform/log.h>
17+
18+
namespace executorch::backends::cuda {
19+
#ifdef __cplusplus
20+
extern "C" {
21+
#endif
22+
23+
AOTITorchError aoti_torch_cuda_int4_plain_mm(
24+
Tensor* self,
25+
Tensor* qdata,
26+
Tensor* scale,
27+
Tensor* zero,
28+
int64_t group_size,
29+
Tensor** ret0) {
30+
ET_CHECK_OR_RETURN_ERROR(
31+
self != nullptr,
32+
InvalidArgument,
33+
"aoti_torch_cuda_int4_plain_mm: self is null");
34+
35+
ET_CHECK_OR_RETURN_ERROR(
36+
qdata != nullptr,
37+
InvalidArgument,
38+
"aoti_torch_cuda_int4_plain_mm: qdata is null");
39+
40+
ET_CHECK_OR_RETURN_ERROR(
41+
scale != nullptr,
42+
InvalidArgument,
43+
"aoti_torch_cuda_int4_plain_mm: scale is null");
44+
45+
ET_CHECK_OR_RETURN_ERROR(
46+
zero != nullptr,
47+
InvalidArgument,
48+
"aoti_torch_cuda_int4_plain_mm: zero is null");
49+
50+
ET_CHECK_OR_RETURN_ERROR(
51+
ret0 != nullptr,
52+
InvalidArgument,
53+
"aoti_torch_cuda_int4_plain_mm: ret0 is null");
54+
55+
int32_t M = self->size(0);
56+
int32_t N = qdata->size(0);
57+
Tensor* C = nullptr;
58+
std::array<int64_t, 2> c_shape = {M, N};
59+
std::array<int64_t, 2> c_stride = {N, 1};
60+
aoti_torch_empty_strided(
61+
2,
62+
c_shape.data(),
63+
c_stride.data(),
64+
static_cast<int32_t>(
65+
executorch::backends::aoti::slim::c10::ScalarType::BFloat16),
66+
static_cast<int32_t>(
67+
executorch::backends::aoti::slim::c10::DeviceType::CUDA),
68+
0,
69+
&C);
70+
71+
_int4_plain_mm_cuda(*self, *qdata, *scale, *zero, group_size, C);
72+
ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
73+
74+
*ret0 = C;
75+
return Error::Ok;
76+
}
77+
78+
#ifdef __cplusplus
79+
}
80+
#endif
81+
} // namespace executorch::backends::cuda

0 commit comments

Comments
 (0)