Commit e7375a1

Harden int4_plain_mm: dtype checks, scale hoist, docstrings

- Add dtype checks for qdata (uint8/int8), scale (bf16), zero (bf16) in C shim
- Hoist weight scale/zero loads outside inner loop (reload only on group change)
- Clarify int4_dispatch.py docblock: runs at eager/trace time, not .pte runtime
- Clarify test docblock: tests eager dispatch, not C shim runtime

1 parent 6273bb2 commit e7375a1

3 files changed: 40 additions & 12 deletions

File tree

backends/cuda/int4_dispatch.py

Lines changed: 16 additions & 8 deletions
@@ -4,14 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-"""Int4Tensor F.linear dispatch for CUDA.
-
-Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm`` — in eager this
-               dequants + calls F.linear; in .pte runtime the C shim runs a
-               W4A8 dp4a matvec kernel.
-Prefill (M>4): Inline dequant + F.linear — AOTI compiles this into the .so
-               using inductor's own cuBLAS codegen, so no explicit cuBLAS
-               dependency in our shim library.
+"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time.
+
+This module overrides Int4Tensor's F.linear dispatch so that torch.export
+traces through our custom op and dequant logic instead of torchao's default
+(mslk/tinygemm). The code here executes during eager inference and during
+AOTI export tracing — it does NOT run at .pte runtime.
+
+At .pte runtime, the captured graph is executed by the AOTI-generated .so:
+- The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that
+  runs the W4A8 dp4a matvec kernel (backends/cuda/runtime/shims/).
+- The inline dequant + F.linear is compiled by inductor into fused Triton
+  dequant + cuBLAS matmul kernels.
+
+Dispatch strategy (determines what gets captured in the export graph):
+Decode (M<=4): Custom op ``executorch_cuda::int4_plain_mm``
+Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)

 Import this module before using nn.Linear with Int4Tensor weights::
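The decode/prefill split in the new docblock can be modeled as a shape-only sketch. The helper name `pick_branch` is hypothetical; the real dispatch additionally dequantizes and calls F.linear in eager mode, or emits the custom op / inline ops into the export graph.

```python
import torch

DECODE_MAX_M = 4  # boundary from the docblock: M <= 4 -> decode, M > 4 -> prefill


def pick_branch(x: torch.Tensor) -> str:
    """Return which dispatch branch the docblock's strategy would capture.

    Hypothetical shape-only model of the Int4Tensor F.linear override:
    batch dims are flattened, and the resulting row count M selects the path.
    """
    m = x.reshape(-1, x.shape[-1]).shape[0]  # rows after flattening batch dims
    return "int4_plain_mm" if m <= DECODE_MAX_M else "inline_dequant_linear"


print(pick_branch(torch.zeros(1, 4096)))     # single decode token -> custom op
print(pick_branch(torch.zeros(2, 7, 4096)))  # 14 rows -> inline dequant path
```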

backends/cuda/runtime/shims/int4_plain_mm.cuh

Lines changed: 14 additions & 2 deletions
@@ -130,6 +130,9 @@ __global__ void __launch_bounds__(MV_THREADS)

   float sum = 0.0f;

+  int32_t prev_g = -1;
+  float ws = 0.0f, wz = 0.0f;
+
   for (int32_t i = lane_id; i < K_half_16; i += MV_WARP_SIZE) {
     uint4 packed16 = __ldg(&qrow16[i]);
     int32_t k_base = i * 32;
@@ -141,6 +144,12 @@
       int32_t k_word = k_base + w * 8;
       int32_t g = k_word >> gs_shift;

+      if (g != prev_g) {
+        ws = __bfloat162float(__ldg(&scale_base[g * scale_stride]));
+        wz = __bfloat162float(__ldg(&zero_base[g * scale_stride]));
+        prev_g = g;
+      }
+
       int32_t vi_lo = packed & 0x0F0F0F0F;
       int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F;

@@ -156,8 +165,6 @@
       int32_t dp = __dp4a(vi_lo, a_even, 0);
       dp = __dp4a(vi_hi, a_odd, dp);

-      float ws = __bfloat162float(__ldg(&scale_base[g * scale_stride]));
-      float wz = __bfloat162float(__ldg(&zero_base[g * scale_stride]));
       float a_scale = qb->d;

       int32_t a_sum8 = __dp4a(0x01010101, a_even, 0);
@@ -212,6 +219,11 @@ void _int4_plain_mm_cuda(
  int32_t N = qdata.size(0);

  ET_CHECK(A.dtype() == c10::ScalarType::BFloat16);
+  ET_CHECK(
+      qdata.dtype() == c10::ScalarType::Byte ||
+      qdata.dtype() == c10::ScalarType::Char);
+  ET_CHECK(scale.dtype() == c10::ScalarType::BFloat16);
+  ET_CHECK(zero.dtype() == c10::ScalarType::BFloat16);
  ET_CHECK(A.dim() == 2);
  ET_CHECK(qdata.dim() == 2);
  ET_CHECK(qdata.size(1) == K / 2);
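A reference model of the per-group dequant that the hoisted `ws`/`wz` loads implement. This is a sketch: the affine form `w = q * scale[g] + zero[g]` is inferred from the kernel's `a_sum8` zero-correction term, and the nibble layout (low nibble holds even k, high nibble odd k) is an assumption, not the confirmed packing. Since `g = k // group_size` changes only once per group, caching the last (scale, zero) pair and reloading on group change avoids redundant global loads, which is exactly what the hunk above does.

```python
import numpy as np


def int4_matvec_ref(qdata, scale, zero, a, group_size):
    """Reference for the W4A8-style matvec: y[n] = sum_k w[n, k] * a[k].

    qdata: (N, K//2) uint8, two 4-bit weights per byte (assumed layout:
    low nibble = even k, high nibble = odd k).
    scale, zero: (N, K // group_size) per-group dequant parameters.
    """
    N, K_half = qdata.shape
    K = K_half * 2
    w = np.empty((N, K), dtype=np.float32)
    w[:, 0::2] = (qdata & 0x0F).astype(np.float32)  # low nibbles -> even k
    w[:, 1::2] = (qdata >> 4).astype(np.float32)    # high nibbles -> odd k
    g = np.arange(K) // group_size                  # group index per column
    w = w * scale[:, g] + zero[:, g]                # assumed affine dequant
    return w @ a


# Tiny example: N=1, K=8, group_size=4 (weights 1..8 packed pairwise).
q = np.array([[0x21, 0x43, 0x65, 0x87]], dtype=np.uint8)
s = np.array([[1.0, 2.0]], dtype=np.float32)
z = np.array([[0.0, -1.0]], dtype=np.float32)
a = np.ones(8, dtype=np.float32)
print(int4_matvec_ref(q, s, z, a, group_size=4))  # -> [58.]
```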

backends/cuda/tests/test_int4_dispatch.py

Lines changed: 10 additions & 2 deletions
@@ -7,10 +7,18 @@

 """Tests for Int4Tensor F.linear dispatch via int4_dispatch.

+These tests validate the eager / trace-time dispatch path — the same code
+that torch.export traces through when building the AOTI graph. They do NOT
+test the .pte runtime C shim (dp4a kernel); that is covered by
+test_aoti_torch_cuda_int4_plain_mm.cpp (C++ unit tests) and
+test_cuda_pipeline.py::TestCudaExport (end-to-end export + lower).
+
 The API contract: after importing int4_dispatch, F.linear and nn.Linear
 with Int4Tensor weights produce numerically correct results. Tests verify
-this across decode (M=1), prefill (M>1), batched (3D), bias, group sizes,
-and symmetric/asymmetric quantization.
+this across decode (M<=4), prefill (M>4), batched (3D), bias, group sizes,
+and symmetric/asymmetric quantization. Correctness is measured as mean
+relative error against the unquantized bf16 reference (not per-element
+atol/rtol, which is too strict for INT4 quantization noise).

 Usage:
     python -m pytest backends/cuda/tests/test_int4_dispatch.py -v
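The mean-relative-error criterion from the test docblock can be sketched as follows. The helper name and the epsilon guard are hypothetical; the actual threshold used by the tests is not shown in this diff.

```python
import torch


def mean_rel_error(actual: torch.Tensor, ref: torch.Tensor, eps: float = 1e-6) -> float:
    """Mean of |actual - ref| / (|ref| + eps) over all elements.

    Unlike per-element atol/rtol checks (torch.testing.assert_close), this
    averages the quantization noise, so a handful of poorly represented
    elements cannot fail an otherwise accurate INT4 result.
    """
    return ((actual - ref).abs() / (ref.abs() + eps)).mean().item()


ref = torch.ones(8)
quantized = torch.full((8,), 1.02)  # stand-in for a dequantized INT4 output
print(mean_rel_error(quantized, ref))  # ≈ 0.02
```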
