diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh
index 4fb4a04f296..4476a403540 100755
--- a/.ci/scripts/export_model_artifact.sh
+++ b/.ci/scripts/export_model_artifact.sh
@@ -419,6 +419,7 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
   python -m executorch.examples.models.qwen3_5_moe.export \
     --prequantized "$LOCAL_MODEL_DIR" \
     --output-dir "${OUTPUT_DIR}" \
+    --dense-prefill dequant \
     --moe-activation-dtype int8
 
   echo "::endgroup::"
diff --git a/backends/cuda/tests/test_int4_matmul.py b/backends/cuda/tests/test_int4_matmul.py
new file mode 100644
index 00000000000..895da240413
--- /dev/null
+++ b/backends/cuda/tests/test_int4_matmul.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functional correctness tests for INT4 matmul and dequant Triton kernels.
+
+Tests both int4_matmul (fused W4A16 GEMM) and dequant_w4_to_bf16 (weight
+dequantization) against eager PyTorch references. Uses 0.01 absolute
+tolerance to account for INT4 quantization noise and bf16 rounding.
+
+Usage:
+    python -m pytest backends/cuda/tests/test_int4_matmul.py -v
+"""
+
+import unittest
+
+import torch
+import torch.nn as nn
+
+from executorch.backends.cuda.triton.kernels.int4_matmul import (
+    dequant_w4_to_bf16,
+    int4_matmul,
+    int4_matvec,
+)
+
+ATOL = 0.01
+DEVICE = "cuda"
+
+
+def _quantize_simple(w_bf16, group_size):
+    """Quantize [N, K] bf16 weight to simple packed INT4 + per-group scales.
+
+    Returns:
+        w_packed: [N, K//2] int8 — two INT4 values per byte
+        w_scale: [N, K//group_size] bf16 — symmetric scales
+        w_ref: [N, K] bf16 — dequantized reference matching kernel's computation
+    """
+    N, K = w_bf16.shape
+    w = w_bf16.float()
+    w_grouped = w.reshape(N, K // group_size, group_size)
+    scale = w_grouped.abs().amax(dim=-1, keepdim=True) / 7.0
+    scale = scale.clamp(min=1e-10)
+    int_data = (w_grouped / scale).round().clamp(-8, 7).to(torch.int8)
+    # Kernel dequant: (uint4 - 8) * scale = int_data * scale
+    scale_bf16 = scale.to(torch.bfloat16)
+    w_ref = ((int_data.float()) * scale_bf16.float()).reshape(N, K).to(torch.bfloat16)
+    scale_bf16 = scale_bf16.reshape(N, K // group_size)
+    int_data = int_data.reshape(N, K)
+    uint4 = (int_data + 8).to(torch.int16)
+    packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8)
+    return packed.to(DEVICE), scale_bf16.to(DEVICE), w_ref.to(DEVICE)
+
+
+def _eager_int4_matmul(x, w_ref):
+    """Reference matmul: x @ w_ref.T in float32, cast to bf16."""
+    return (x.float() @ w_ref.float().T).to(torch.bfloat16)
+
+
+class TestDequantW4ToBf16(unittest.TestCase):
+    """Tests for dequant_w4_to_bf16 Triton kernel."""
+
+    def _run_dequant(self, N, K, group_size):
+        torch.manual_seed(42)
+        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
+        packed, scale, w_ref = _quantize_simple(w, group_size)
+
+        out = dequant_w4_to_bf16(packed, scale, group_size)
+
+        self.assertEqual(out.shape, (N, K))
+        self.assertEqual(out.dtype, torch.bfloat16)
+        max_err = (out.float() - w_ref.float()).abs().max().item()
+        self.assertLess(
+            max_err, ATOL, f"dequant [{N}x{K}] gs={group_size}: max_err={max_err}"
+        )
+
+    def test_square(self):
+        self._run_dequant(256, 256, 32)
+
+    def test_tall(self):
+        self._run_dequant(2048, 256, 32)
+
+    def test_wide(self):
+        self._run_dequant(256, 2048, 128)
+
+    def test_production_qkv(self):
+        self._run_dequant(2048, 2048, 128)
+
+    def test_production_shared_expert(self):
+        self._run_dequant(1024, 2048, 128)
+
+    def test_group_size_32(self):
+        self._run_dequant(512, 512, 32)
+
+    def test_group_size_128(self):
+        self._run_dequant(512, 2048, 128)
+
+    def test_non_power_of_two_N(self):
+        self._run_dequant(12352, 2048, 128)
+
+    def test_small(self):
+        self._run_dequant(16, 64, 32)
+
+
+class TestInt4Matmul(unittest.TestCase):
+    """Tests for int4_matmul Triton kernel (fused W4A16 GEMM)."""
+
+    def _run_matmul(self, M, N, K, group_size):
+        torch.manual_seed(42)
+        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
+        packed, scale, w_ref = _quantize_simple(w, group_size)
+        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
+
+        out = int4_matmul(x, packed, scale, group_size)
+        ref = _eager_int4_matmul(x, w_ref)
+
+        self.assertEqual(out.shape, (M, N))
+        self.assertEqual(out.dtype, torch.bfloat16)
+        self.assertTrue(
+            torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
+            f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
+            f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
+            f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
+        )
+
+    # --- Decode (M=1) ---
+    def test_decode_square(self):
+        self._run_matmul(1, 256, 256, 32)
+
+    def test_decode_qkv(self):
+        self._run_matmul(1, 2048, 2048, 128)
+
+    def test_decode_kv_proj(self):
+        self._run_matmul(1, 256, 2048, 128)
+
+    def test_decode_shared_expert(self):
+        self._run_matmul(1, 1024, 2048, 128)
+
+    def test_decode_large_N(self):
+        self._run_matmul(1, 12352, 2048, 128)
+
+    # --- Small prefill ---
+    def test_prefill_4(self):
+        self._run_matmul(4, 2048, 2048, 128)
+
+    def test_prefill_16(self):
+        self._run_matmul(16, 2048, 2048, 128)
+
+    def test_prefill_64(self):
+        self._run_matmul(64, 2048, 2048, 128)
+
+    # --- Large prefill ---
+    def test_prefill_256(self):
+        self._run_matmul(256, 2048, 2048, 128)
+
+    def test_prefill_1024(self):
+        self._run_matmul(1024, 2048, 2048, 128)
+
+    def test_prefill_4095(self):
+        self._run_matmul(4095, 2048, 2048, 128)
+
+    # --- Edge cases ---
+    def test_group_size_32(self):
+        self._run_matmul(4, 512, 512, 32)
+
+    def test_non_power_of_two_M(self):
+        self._run_matmul(7, 256, 256, 32)
+
+    def test_non_power_of_two_N(self):
+        self._run_matmul(4, 12352, 2048, 128)
+
+    def test_small(self):
+        self._run_matmul(1, 16, 64, 32)
+
+
+class TestInt4Matvec(unittest.TestCase):
+    """Tests for int4_matvec Triton kernel (M=1 decode)."""
+
+    def _run_matvec(self, N, K, group_size):
+        torch.manual_seed(42)
+        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
+        packed, scale, w_ref = _quantize_simple(w, group_size)
+        x = torch.randn(K, dtype=torch.bfloat16, device=DEVICE)
+
+        out = int4_matvec(x.unsqueeze(0), packed, scale, group_size)
+        ref = int4_matmul(x.unsqueeze(0), packed, scale, group_size)
+
+        self.assertEqual(out.shape, (1, N))
+        self.assertEqual(out.dtype, torch.bfloat16)
+        # atol=1.0 for large accumulation across K, rtol=0.01 for relative
+        self.assertTrue(
+            torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
+            f"int4_matvec [{N}x{K}] gs={group_size}: "
+            f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
+            f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}",
+        )
+
+    def test_qkv_proj(self):
+        self._run_matvec(2048, 2048, 128)
+
+    def test_kv_proj(self):
+        self._run_matvec(256, 2048, 128)
+
+    def test_shared_expert(self):
+        self._run_matvec(1024, 2048, 128)
+
+    def test_large_N(self):
+        self._run_matvec(12352, 2048, 128)
+
+    def test_group_size_32(self):
+        self._run_matvec(512, 512, 32)
+
+    def test_small(self):
+        self._run_matvec(16, 64, 32)
+
+    def test_matches_int4_matmul(self):
+        """Matvec output matches int4_matmul at M=1."""
+        torch.manual_seed(42)
+        N, K, gs = 2048, 2048, 128
+        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
+        packed, scale, _ = _quantize_simple(w, gs)
+        x = torch.randn(1, K, dtype=torch.bfloat16, device=DEVICE)
+
+        out_mv = int4_matvec(x, packed, scale, gs)
+        out_mm = int4_matmul(x, packed, scale, gs)
+
+        self.assertTrue(
+            torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
+            f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
+        )
+
+
+class TestDequantThenMatmul(unittest.TestCase):
+    """Tests that dequant + F.linear matches int4_matmul (both paths should agree)."""
+
+    def _run(self, M, N, K, group_size):
+        torch.manual_seed(42)
+        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
+        packed, scale, w_ref = _quantize_simple(w, group_size)
+        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
+
+        # Path A: fused int4_matmul
+        out_fused = int4_matmul(x, packed, scale, group_size)
+
+        # Path B: dequant + F.linear
+        w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
+        out_dequant = torch.nn.functional.linear(x, w_bf16)
+
+        self.assertTrue(
+            torch.allclose(
+                out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
+            ),
+            f"fused vs dequant M={M} [{N}x{K}]: "
+            f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
+        )
+
+    def test_decode(self):
+        self._run(1, 2048, 2048, 128)
+
+    def test_prefill_short(self):
+        self._run(64, 2048, 2048, 128)
+
+    def test_prefill_long(self):
+        self._run(1024, 2048, 2048, 128)
+
+    def test_large_N(self):
+        self._run(4, 12352, 2048, 128)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py
index 85593caa4a1..8e7e7c58862 100644
--- a/backends/cuda/triton/kernels/__init__.py
+++ b/backends/cuda/triton/kernels/__init__.py
@@ -10,10 +10,16 @@
     fused_moe_batched_gemm,
     moe_align_block_size,
 )
+
+from executorch.backends.cuda.triton.kernels.int4_matmul import (
+    dequant_w4_to_bf16,
+    int4_matvec,
+)
 from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
 from executorch.backends.cuda.triton.kernels.topk import topk
 
 __all__ = [
+    "dequant_w4_to_bf16",
     "fused_moe",
     "fused_moe_batched",
     "fused_moe_batched_gemm",
diff --git a/backends/cuda/triton/kernels/int4_matmul.py b/backends/cuda/triton/kernels/int4_matmul.py
new file mode 100644
index 00000000000..0e7ea916ec1
--- /dev/null
+++ b/backends/cuda/triton/kernels/int4_matmul.py
@@ -0,0 +1,530 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Triton W4A16 matmul kernel for dense linear projections.
+
+Replaces PyTorch's _weight_int4pack_mm (tinygemm) for prefill where large M
+makes tinygemm's 16×8 tiles inefficient. Uses 128×128+ tiles for better
+tensor core utilization.
+
+Weight format: same as fused_moe experts:
+    w_packed: [N, K//2] int8 — two INT4 values packed per byte
+    w_scale: [N, K//group_size] bf16 — symmetric dequant: (uint4 - 8) * scale
+
+Registered as triton_op for AOTInductor export.
+"""
+
+import torch
+import triton
+import triton.language as tl
+from torch.library import triton_op, wrap_triton
+
+
+# -- Autotune configs ---------------------------------------------------------
+
+_INT4_MATMUL_CONFIGS = [
+    # Large-M prefill configs (tensor core saturated)
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=8,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=8,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=4,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=4,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=4,
+    ),
+    # Small-M decode configs (bandwidth-bound, wide N tiles)
+    triton.Config(
+        {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=4,
+        num_stages=4,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=4,
+    ),
+]
+
+
+# -- Triton kernel ------------------------------------------------------------
+
+
+@triton.autotune(configs=_INT4_MATMUL_CONFIGS, key=["M", "N", "K"])
+@triton.jit
+def _int4_matmul_kernel(
+    # Pointers
+    A,  # [M, K] bf16 activations
+    B,  # [N, K//2] int8 packed INT4 weights
+    C,  # [M, N] bf16 output
+    B_scale,  # [N, K//group_size] bf16 per-group scales
+    # Dimensions
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    # Strides
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    stride_cm,
+    stride_cn,
+    stride_bsn,
+    stride_bsk,
+    # Config
+    group_size: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """W4A16 matmul: C[M,N] = A[M,K] × dequant(B[N,K//2]).T
+
+    Each program computes one (BLOCK_M, BLOCK_N) output tile.
+    INT4 weights are unpacked and dequantized per-group inside the K-loop.
+    """
+    pid = tl.program_id(0)
+    num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N)
+    m_block = pid // num_n_blocks
+    n_block = pid % num_n_blocks
+
+    # M and N offsets for this block
+    offs_m = m_block * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    m_mask = offs_m < M
+    n_mask = offs_n < N
+
+    # A pointers: [BLOCK_M, BLOCK_K] — rows of activations
+    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+
+    # B pointers: [BLOCK_K, BLOCK_N] — weight is [N, K//2], need transposed access
+    # B is stored as [N, K//2] with two INT4 per byte along K dim.
+    # For the dot product A[M,K] @ B_dequant[K,N], we need B transposed:
+    # b_ptrs indexes B[offs_n, offs_k // 2] as a [BLOCK_K, BLOCK_N] tile, so each
+    # packed byte is addressed once per nibble; the per-k shift below selects it.
+    b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk
+    b_shifter = (offs_k[:, None] % 2) * 4
+
+    # Accumulator [BLOCK_M, BLOCK_N] in float32
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        k_remaining = K - k_step * BLOCK_SIZE_K
+        k_mask = offs_k < k_remaining
+
+        # Load A tile [BLOCK_M, BLOCK_K]
+        a = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
+
+        # Load B tile [BLOCK_K, BLOCK_N] and unpack INT4
+        b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
+        b = (b >> b_shifter) & 0xF
+
+        # Per-group scale dequantization
+        if BLOCK_SIZE_K <= group_size:
+            # One scale per column per tile — broadcast
+            group_idx = (BLOCK_SIZE_K * k_step) // group_size
+            scale_ptrs = B_scale + offs_n[None, :] * stride_bsn + group_idx * stride_bsk
+            b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
+                tl.float32
+            )
+        else:
+            # Multiple groups per tile — per-element scale
+            scale_ptrs = (
+                B_scale
+                + offs_n[None, :] * stride_bsn
+                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+            )
+            b_scale = tl.load(
+                scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
+            ).to(tl.float32)
+
+        # Dequantize: (uint4 - 8) * scale → bf16
+        b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(tl.bfloat16)
+
+        # Tensor core matmul: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] → f32 acc
+        acc += tl.dot(a.to(tl.bfloat16), b_dequant)
+
+        # Advance K pointers
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+
+    # Write output [BLOCK_M, BLOCK_N]
+    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=m_mask[:, None] & n_mask[None, :])
+
+
+# -- triton_op wrapper --------------------------------------------------------
+
+
+@triton_op("triton::int4_matmul", mutates_args={})
+def int4_matmul(
+    x: torch.Tensor,
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    """W4A16 matmul: output = x @ dequant(w_packed).T
+
+    Args:
+        x: [M, K] bf16 activations
+        w_packed: [N, K//2] int8 — two INT4 values packed per byte
+        w_scale: [N, K//group_size] bf16 per-group scales
+        group_size: quantization group size
+
+    Returns:
+        [M, N] bf16
+    """
+    M, K = x.shape
+    N = w_packed.shape[0]
+
+    assert x.dtype == torch.bfloat16
+    assert w_packed.dtype == torch.int8
+    assert w_scale.dtype == torch.bfloat16
+    assert w_packed.shape == (
+        N,
+        K // 2,
+    ), f"w_packed shape {w_packed.shape} != ({N}, {K // 2})"
+    assert w_scale.shape == (
+        N,
+        K // group_size,
+    ), f"w_scale shape {w_scale.shape} != ({N}, {K // group_size})"
+
+    output = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+
+    def _grid(meta):
+        return (
+            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+        )
+
+    wrap_triton(_int4_matmul_kernel)[_grid](
+        x,
+        w_packed,
+        output,
+        w_scale,
+        M,
+        N,
+        K,
+        x.stride(0),
+        x.stride(1),
+        w_packed.stride(0),
+        w_packed.stride(1),
+        output.stride(0),
+        output.stride(1),
+        w_scale.stride(0),
+        w_scale.stride(1),
+        group_size,
+    )
+    return output
+
+
+@int4_matmul.register_fake
+def _int4_matmul_fake(
+    x: torch.Tensor,
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    M, K = x.shape
+    N = w_packed.shape[0]
+    return torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
+
+
+# -- INT4 matvec kernel (M=1 decode) ------------------------------------------
+
+_MATVEC_CONFIGS = [
+    triton.Config({"BLOCK_N": 4, "BLOCK_K": 128}, num_warps=2, num_stages=3),
+    triton.Config({"BLOCK_N": 8, "BLOCK_K": 128}, num_warps=2, num_stages=3),
+    triton.Config({"BLOCK_N": 8, "BLOCK_K": 256}, num_warps=2, num_stages=3),
+    triton.Config({"BLOCK_N": 4, "BLOCK_K": 256}, num_warps=2, num_stages=3),
+]
+
+
+@triton.autotune(configs=_MATVEC_CONFIGS, key=["N", "K"])
+@triton.jit
+def _int4_matvec_kernel(
+    X,  # [K] bf16 input vector
+    W,  # [N, K//2] int8 packed INT4 weights
+    Out,  # [N] bf16 output
+    W_scale,  # [N, K//group_size] bf16 per-group scales
+    N: tl.constexpr,
+    K: tl.constexpr,
+    stride_wn,
+    stride_wk,
+    stride_sn,
+    stride_sk,
+    group_size: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    nb = tl.program_id(0)
+    offs_n = nb * BLOCK_N + tl.arange(0, BLOCK_N)
+    nm = offs_n < N
+    offs_k = tl.arange(0, BLOCK_K)
+    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+    for ks in range(tl.cdiv(K, BLOCK_K)):
+        abs_k = ks * BLOCK_K + offs_k
+        km = abs_k < K
+
+        x_val = tl.load(X + abs_k, mask=km, other=0.0).to(tl.float32)
+
+        w_ptrs = W + offs_n[:, None] * stride_wn + (abs_k[None, :] // 2) * stride_wk
+        w_shift = (abs_k[None, :] % 2) * 4
+        w_raw = tl.load(w_ptrs, mask=nm[:, None] & km[None, :], other=0)
+        w_uint4 = (w_raw >> w_shift) & 0xF
+
+        if BLOCK_K <= group_size:
+            gi = (ks * BLOCK_K) // group_size
+            scale = tl.load(
+                W_scale + offs_n * stride_sn + gi * stride_sk, mask=nm, other=0.0
+            ).to(tl.float32)
+            w_dq = (w_uint4.to(tl.float32) - 8.0) * scale[:, None]
+        else:
+            scale_ptrs = (
+                W_scale
+                + offs_n[:, None] * stride_sn
+                + (abs_k[None, :] // group_size) * stride_sk
+            )
+            scale = tl.load(scale_ptrs, mask=nm[:, None] & km[None, :], other=0.0).to(
+                tl.float32
+            )
+            w_dq = (w_uint4.to(tl.float32) - 8.0) * scale
+
+        acc += tl.sum(w_dq * x_val[None, :], axis=1)
+
+    tl.store(Out + offs_n, acc.to(tl.bfloat16), mask=nm)
+
+
+@triton_op("triton::int4_matvec", mutates_args={})
+def int4_matvec(
+    x: torch.Tensor,
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    """W4A16 matvec for M=1 decode: out[1,N] = x[1,K] @ dequant(w[N,K]).T
+
+    Args:
+        x: [1, K] bf16 input (M=1)
+        w_packed: [N, K//2] int8 packed INT4 weights
+        w_scale: [N, K//group_size] bf16 per-group scales
+        group_size: quantization group size
+
+    Returns:
+        [1, N] bf16
+    """
+    K = x.shape[-1]
+    N = w_packed.shape[0]
+
+    output = torch.empty(1, N, dtype=torch.bfloat16, device=x.device)
+
+    def _grid(meta):
+        return (triton.cdiv(N, meta["BLOCK_N"]),)
+
+    wrap_triton(_int4_matvec_kernel)[_grid](
+        x.reshape(-1),
+        w_packed,
+        output.reshape(-1),
+        w_scale,
+        N,
+        K,
+        w_packed.stride(0),
+        w_packed.stride(1),
+        w_scale.stride(0),
+        w_scale.stride(1),
+        group_size,
+    )
+    return output
+
+
+@int4_matvec.register_fake
+def _int4_matvec_fake(
+    x: torch.Tensor,
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    N = w_packed.shape[0]
+    return torch.empty(1, N, dtype=torch.bfloat16, device=x.device)
+
+
+# -- Dequant W4 → BF16 kernel ------------------------------------------------
+
+_DEQUANT_CONFIGS = [
+    triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=4),
+    triton.Config({"BLOCK_N": 128, "BLOCK_K": 128}, num_warps=4, num_stages=3),
+    triton.Config({"BLOCK_N": 64, "BLOCK_K": 256}, num_warps=4, num_stages=3),
+]
+
+
+@triton.autotune(configs=_DEQUANT_CONFIGS, key=["N", "K"])
+@triton.jit
+def _dequant_w4_to_bf16_kernel(
+    # Pointers
+    W_packed,  # [N, K//2] int8 packed INT4 weights
+    W_scale,  # [N, K//group_size] bf16 per-group scales
+    Out,  # [N, K] bf16 output
+    # Dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    # Strides
+    stride_wn,
+    stride_wk,
+    stride_sn,
+    stride_sk,
+    stride_on,
+    stride_ok,
+    # Config
+    group_size: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """Dequantize packed INT4 weights to BF16.
+
+    Each program processes a (BLOCK_N, BLOCK_K) tile of the output [N, K].
+    INT4 pairs are unpacked from [N, K//2] int8 and dequantized per-group.
+    """
+    pid = tl.program_id(0)
+    num_k_blocks = tl.cdiv(K, BLOCK_K)
+    n_block = pid // num_k_blocks
+    k_block = pid % num_k_blocks
+
+    offs_n = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_k = k_block * BLOCK_K + tl.arange(0, BLOCK_K)
+    n_mask = offs_n < N
+    k_mask = offs_k < K
+
+    # Load packed bytes [BLOCK_N, BLOCK_K//2] — each byte has two int4 values
+    packed_ptrs = (
+        W_packed + offs_n[:, None] * stride_wn + (offs_k[None, :] // 2) * stride_wk
+    )
+    packed = tl.load(packed_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0)
+
+    # Unpack: even k → low nibble, odd k → high nibble
+    shift = (offs_k[None, :] % 2) * 4
+    uint4 = (packed >> shift) & 0xF
+
+    # Load per-group scales [BLOCK_N, ceil(BLOCK_K/group_size)]
+    scale_ptrs = (
+        W_scale
+        + offs_n[:, None] * stride_sn
+        + (offs_k[None, :] // group_size) * stride_sk
+    )
+    scale = tl.load(scale_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
+
+    # Dequantize: (uint4 - 8) * scale → bf16
+    result = ((uint4.to(tl.float32) - 8.0) * scale.to(tl.float32)).to(tl.bfloat16)
+
+    # Store [BLOCK_N, BLOCK_K]
+    out_ptrs = Out + offs_n[:, None] * stride_on + offs_k[None, :] * stride_ok
+    tl.store(out_ptrs, result, mask=n_mask[:, None] & k_mask[None, :])
+
+
+# -- triton_op wrapper --------------------------------------------------------
+
+
+@triton_op("triton::dequant_w4_to_bf16", mutates_args={})
+def dequant_w4_to_bf16(
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    """Dequantize packed INT4 weights to BF16.
+
+    Args:
+        w_packed: [N, K//2] int8 — two INT4 values packed per byte
+        w_scale: [N, K//group_size] bf16 per-group scales
+        group_size: quantization group size
+
+    Returns:
+        [N, K] bf16 dequantized weight matrix
+    """
+    N, K_half = w_packed.shape
+    K = K_half * 2
+
+    output = torch.empty(N, K, dtype=torch.bfloat16, device=w_packed.device)
+
+    def _grid(meta):
+        return (triton.cdiv(N, meta["BLOCK_N"]) * triton.cdiv(K, meta["BLOCK_K"]),)
+
+    wrap_triton(_dequant_w4_to_bf16_kernel)[_grid](
+        w_packed,
+        w_scale,
+        output,
+        N,
+        K,
+        w_packed.stride(0),
+        w_packed.stride(1),
+        w_scale.stride(0),
+        w_scale.stride(1),
+        output.stride(0),
+        output.stride(1),
+        group_size,
+    )
+    return output
+
+
+@dequant_w4_to_bf16.register_fake
+def _dequant_w4_to_bf16_fake(
+    w_packed: torch.Tensor,
+    w_scale: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    N, K_half = w_packed.shape
+    K = K_half * 2
+    return torch.empty(N, K, dtype=torch.bfloat16, device=w_packed.device)
diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py
index 8e12d0236dd..5dccb1dfcc4 100644
--- a/examples/models/qwen3_5_moe/export.py
+++ b/examples/models/qwen3_5_moe/export.py
@@ -21,6 +21,7 @@
     FusedMoEExperts,
     Qwen35MoE,
     Qwen35MoEConfig,
+    W4DequantLinear,
 )
 
 
@@ -465,6 +466,137 @@ def _quantize(model, config, args):
     print(f"Quantized linear layers ({args.qlinear})")
 
 
+def _replace_dense_with_w4dequant(model, group_size=128, use_hqq=False):
+    """Replace quantized dense linears with W4DequantLinear.
+
+    Dequantizes Int4TilePackedTo4dTensor weights to BF16, re-quantizes to
+    simple [N, K//2] packed INT4 format (same as MoE experts), and wraps in
+    W4DequantLinear for dual decode/prefill dispatch.
+
+    MoE expert weights (FusedMoEExperts) are left unchanged.
+    """
+    from torch.nn import functional as F
+
+    if use_hqq:
+        from torchao.quantization.quant_primitives import (
+            _choose_qparams_and_quantize_scale_only_hqq,
+        )
+    else:
+        from torchao.quantization.quant_primitives import (
+            choose_qparams_affine,
+            MappingType,
+            quantize_affine,
+        )
+
+    count = 0
+
+    def _infer_group_size(linear):
+        """Infer quantization group size from the weight's metadata.
+
+        For prequantized bundles, the weight may be a torchao
+        AffineQuantizedTensor whose block_size encodes the original group
+        size. Trusting the CLI default would silently re-quantize to the
+        wrong group size (e.g. gs32 when the bundle used gs128).
+ """ + w = linear.weight + if hasattr(w, "block_size"): + return w.block_size[-1] + return None + + def _convert_one(linear): + nonlocal count + N, K = linear.out_features, linear.in_features + + effective_gs = _infer_group_size(linear) + if effective_gs is not None and effective_gs != group_size: + print( + f"\n Detected group_size={effective_gs} from quantized weight " + f"(overriding CLI value {group_size})" + ) + elif effective_gs is None: + effective_gs = group_size + + linear_cuda = linear.cuda() + with torch.no_grad(): + eye = torch.eye(K, dtype=torch.bfloat16, device="cuda") + w_bf16 = F.linear(eye, linear_cuda.weight).T.contiguous() + del linear_cuda + + w_float = w_bf16.float() + del w_bf16 + if use_hqq: + int_data, scale = _choose_qparams_and_quantize_scale_only_hqq( + w_float, + block_size=[1, effective_gs], + qmin=-8, + qmax=7, + ) + int_data = int_data.to(torch.int8).view(N, K) + scale = scale.view(N, -1) + else: + block_size = (1, effective_gs) + scale, zero_point = choose_qparams_affine( + w_float, + MappingType.SYMMETRIC, + block_size, + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + ) + int_data = quantize_affine( + w_float, + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=-8, + quant_max=7, + ) + scale = scale.reshape(N, -1) + del w_float + + uint4 = (int_data + 8).to(torch.int16) + packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8) + del int_data, uint4 + + w4 = W4DequantLinear(K, N, effective_gs) + w4.w_packed = packed.cpu() + w4.w_scale = scale.to(torch.bfloat16).cpu() + count += 1 + torch.cuda.empty_cache() + return w4 + + replacements = [] + for parent in model.modules(): + if isinstance(parent, FusedMoEExperts): + continue + for name, child in parent.named_children(): + if isinstance(child, nn.Linear): + replacements.append((parent, name, child)) + + if not replacements: + raise RuntimeError( + "--dense-prefill=dequant found no nn.Linear modules to convert. " + "Ensure the model has W4-quantized dense layers." + ) + + for i, (parent, name, linear) in enumerate(replacements): + setattr(parent, name, _convert_one(linear)) + print( + f" Converted {name} ({i + 1}/{len(replacements)})", + end="\r", + ) + print() + print(f"Replaced {count} dense linears with W4DequantLinear") + + +def _set_dequant_prefill(model, enabled): + """Toggle dequant+BF16 prefill path for all W4DequantLinear modules.""" + for mod in model.modules(): + if isinstance(mod, W4DequantLinear): + mod.use_dequant_prefill = enabled + + def _materialize_buffers(model, config): """Materialize meta-device buffers before torch.export. @@ -766,6 +898,7 @@ def _export_cuda(model, config, args): # --- Decode method (T=1, static shape, vec-mat MoE kernel) --- _set_batched_moe(model, False) + _set_dequant_prefill(model, False) print("Exporting decode method...") decode_tokens = torch.tensor([[0]], dtype=torch.long) decode_pos = torch.tensor([0], dtype=torch.long) @@ -785,6 +918,8 @@ def _export_cuda(model, config, args): # that reject longer prompts at runtime. moe_activation_dtype = getattr(args, "moe_activation_dtype", "bf16") _set_batched_moe(model, True, moe_activation_dtype=moe_activation_dtype) + dense_prefill = getattr(args, "dense_prefill", "tinygemm") + _set_dequant_prefill(model, dense_prefill == "dequant") print("Exporting prefill method...") example_prefill_len = config.max_seq_len - 1 @@ -954,6 +1089,14 @@ def main(): # noqa: C901 default="bf16", help="MoE activation dtype for prefill only. Decode always uses bf16. 
     )
+    parser.add_argument(
+        "--dense-prefill",
+        choices=["tinygemm", "dequant"],
+        default="tinygemm",
+        help="Dense linear prefill kernel. Decode always uses int4_matvec (Triton W4A16 vec-mat). "
+        "tinygemm (default): W4A16 _weight_int4pack_mm. "
+        "dequant: dequant W4→BF16 + cuBLAS GEMM.",
+    )
 
     args = parser.parse_args()
 
     if args.model_id:
@@ -988,12 +1131,31 @@ def main():  # noqa: C901
     if args.qlinear == "fpa4w" and args.backend != "metal":
         parser.error("--qlinear=fpa4w can only be used with --backend=metal")
 
+    if args.dense_prefill == "dequant":
+        if args.backend != "cuda":
+            parser.error("--dense-prefill dequant requires --backend cuda")
+        if not args.prequantized and args.qlinear != "4w":
+            parser.error(
+                "--dense-prefill dequant requires --qlinear=4w or --prequantized "
+                "(dense weights must be W4 quantized)"
+            )
+
+    if args.moe_activation_dtype != "bf16" and args.backend != "cuda":
+        parser.error("--moe-activation-dtype int8 requires --backend cuda")
+
     model, config = load_and_quantize(args)
 
     if args.backend == "cuda":
         _materialize_buffers(model, config)
         if args.turboquant:
             _apply_turboquant(model, config)
+        if args.dense_prefill == "dequant":
+            print("Converting dense linears to W4DequantLinear...")
+            _replace_dense_with_w4dequant(
+                model,
+                group_size=getattr(args, "qlinear_group_size", 32),
+                use_hqq=getattr(args, "hqq", False),
+            )
 
     export_and_lower(model, config, args)
diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py
index f187ddb8c15..f7363179375 100644
--- a/examples/models/qwen3_5_moe/model.py
+++ b/examples/models/qwen3_5_moe/model.py
@@ -537,6 +537,47 @@ def forward(self, x, expert_weights, expert_indices, top_k):
         )
 
 
+class W4DequantLinear(nn.Module):
+    """Dense W4 linear with dual decode/prefill dispatch.
+
+    Replaces tinygemm-format dense linears with simple [N, K//2] packed INT4
+    weights (same format as MoE experts). The prefill/decode path is baked at
+    export time via use_dequant_prefill:
+
+        False → decode path: Triton int4_matvec (bandwidth-optimized vec-mat)
+        True → prefill path: dequant_w4_to_bf16 + F.linear (Inductor Triton mm)
+    """
+
+    def __init__(self, in_features, out_features, group_size=32):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.group_size = group_size
+        self.use_dequant_prefill = False
+        self.register_buffer("w_packed", None)  # [N, K//2] int8
+        self.register_buffer("w_scale", None)  # [N, K//gs] bf16
+
+    def forward(self, x):
+        orig_shape = x.shape
+        x_2d = x.reshape(-1, self.in_features)
+
+        if self.use_dequant_prefill:
+            w_bf16 = torch.ops.triton.dequant_w4_to_bf16(
+                self.w_packed, self.w_scale, self.group_size
+            )
+            out = F.linear(x_2d, w_bf16)
+        else:
+            assert x_2d.shape[0] == 1, (
+                f"int4_matvec decode path requires M=1, got M={x_2d.shape[0]}. "
+                f"Set use_dequant_prefill=True for M>1."
+            )
+            out = torch.ops.triton.int4_matvec(
+                x_2d, self.w_packed, self.w_scale, self.group_size
+            )
+
+        return out.reshape(*orig_shape[:-1], self.out_features)
+
+
 class SwiGLU(nn.Module):
     """SwiGLU MLP with fused gate+up projection."""
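
Usage sketch (illustrative, not part of the patch): how one might pack a bf16 weight into the [N, K//2] INT4 layout these kernels expect and check that the fused, dequant+GEMM, and matvec paths agree, mirroring _quantize_simple and TestDequantThenMatmul from the new test file. It assumes a CUDA device with Triton available and that the int4_matmul module added above is importable; the shapes are arbitrary examples.

import torch
from executorch.backends.cuda.triton.kernels.int4_matmul import (
    dequant_w4_to_bf16,
    int4_matmul,
    int4_matvec,
)

N, K, group_size = 1024, 2048, 128  # example shapes, not taken from the patch
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Symmetric per-group INT4 quantization, two values packed per byte
# (even k -> low nibble, odd k -> high nibble), matching _quantize_simple.
w_grouped = w.float().reshape(N, K // group_size, group_size)
scale = (w_grouped.abs().amax(dim=-1, keepdim=True) / 7.0).clamp(min=1e-10)
int_data = (w_grouped / scale).round().clamp(-8, 7).reshape(N, K).to(torch.int16)
uint4 = int_data + 8
w_packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8)
w_scale = scale.to(torch.bfloat16).reshape(N, K // group_size)

x = torch.randn(16, K, dtype=torch.bfloat16, device="cuda")

# Prefill-style paths: fused W4A16 GEMM vs. dequant to bf16 + dense GEMM.
out_fused = int4_matmul(x, w_packed, w_scale, group_size)
w_bf16 = dequant_w4_to_bf16(w_packed, w_scale, group_size)
out_dequant = torch.nn.functional.linear(x, w_bf16)

# Decode-style path: M=1 vec-mat kernel on a single row.
out_decode = int4_matvec(x[:1], w_packed, w_scale, group_size)

print(torch.allclose(out_fused.float(), out_dequant.float(), atol=0.01, rtol=0.01))
print(torch.allclose(out_decode.float(), out_fused[:1].float(), atol=1.0, rtol=0.01))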