Skip to content
This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit 83d1f4b

Browse files
New kernels 0222 (#12)
* New kernels 0222 * Test build * Add to setup * Fix cc * Separate pybindings * Fix casting * Fix name of pybindings * Fix pybindings * Build on tag * Bump to 0.0.6 * Upgrade to 2.2.1
1 parent 8714b16 commit 83d1f4b

10 files changed

Lines changed: 1595 additions & 5 deletions

File tree

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ jobs:
9191
# Install torch
9292
$cudaVersion = $env:CUDA_VERSION.Replace('.', '')
9393
$cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
94-
if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.0" } else {$pytorchVersion = "torch==2.0.1"}
94+
$pytorchVersion = "torch==2.2.1"
9595
python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
9696
python -m pip install build setuptools wheel ninja
9797

awq_ext/pybind_awq_v2.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#include <pybind11/pybind11.h>
2+
#include <torch/extension.h>
3+
#include "quantization_new/gemm/gemm_cuda.h"
4+
#include "quantization_new/gemv/gemv_cuda.h"
5+
6+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
7+
{
8+
m.def("gemm_forward_cuda_prefill", &gemm_forward_cuda_prefill, "New quantized GEMM kernel.");
9+
m.def("gemv_forward_cuda_decode", &gemv_forward_cuda_decode, "New quantized GEMM kernel.");
10+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
3+
4+
@article{lin2023awq,
5+
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
6+
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
7+
journal={arXiv},
8+
year={2023}
9+
}
10+
*/
11+
#include <cuda_fp16.h>
12+
#pragma once
13+
14+
__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
15+
{
16+
// uint4 result;
17+
18+
uint32_t *h = reinterpret_cast<uint32_t *>(result);
19+
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
20+
21+
// First, we extract the i4s and construct an intermediate fp16 number.
22+
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
23+
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
24+
static constexpr uint32_t TOP_MASK = 0x00f000f0;
25+
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
26+
27+
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
28+
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
29+
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
30+
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
31+
32+
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
33+
// immediately before required.
34+
const uint32_t top_i4s = i4s >> 8;
35+
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
36+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
37+
: "=r"(h[0])
38+
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
39+
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
40+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
41+
: "=r"(h[1])
42+
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
43+
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
44+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
45+
: "=r"(h[2])
46+
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
47+
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
48+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
49+
: "=r"(h[3])
50+
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
51+
52+
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
53+
// half2 ctor. In this case, I chose performance reliability over code readability.
54+
55+
// This is the half2 {1032, 1032} represented as an integer.
56+
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
57+
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
58+
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
59+
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
60+
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
61+
// This is the half2 {-72, -72} represented as an integer.
62+
// static constexpr uint32_t NEG_72 = 0xd480d480;
63+
// Haotian: Let's use {-64, -64}.
64+
static constexpr uint32_t NEG_64 = 0xd400d400;
65+
66+
// Finally, we construct the output numbers.
67+
// Convert elt_01
68+
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
69+
// Convert elt_23
70+
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
71+
// Convert elt_45
72+
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
73+
// Convert elt_67
74+
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
75+
76+
// return result;
77+
}

0 commit comments

Comments
 (0)