Skip to content

Commit fbfb0ef

Browse files
feat: add AWQ dequantize in moore gpu, with test pass (#485)
2 parents 2f3fd75 + 8896615 commit fbfb0ef

5 files changed

Lines changed: 210 additions & 1 deletion

File tree

src/infiniop/devices/moore/moore_kernel_common.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ exp_(const float val) {
3737
return expf(val);
3838
}
3939

40+
// Computes exp for long double on Moore GPU,
41+
// casts to double to resolve ambiguous exp call,
42+
// due to conflicting double/float definitions in MUSA math libraries.
4043
__forceinline__ __device__ long double
4144
exp_(const long double val) {
42-
return exp(val);
45+
return static_cast<long double>(exp(static_cast<double>(val)));
4346
}
4447

4548
__forceinline__ __device__ double
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#pragma once
2+
#include <musa_fp16.h> // 需要此头文件来支持 __half 和 __half2 类型
3+
4+
/**
5+
* @brief 将一个包含8个4-bit整数的uint32_t反量化为8个half精度浮点数。
6+
*
7+
* 这是一个通用的 CUDA C++ 实现,用于替代原有的 PTX 汇编版本,
8+
* 以便在不支持高级 PTX 指令(如 lop3.b32)的 GPU 上运行。
9+
* 输出顺序匹配 PTX 的交错打包:v0, v4, v1, v5, v2, v6, v3, v7(经 signed 调整后)。
10+
*
11+
* @param source 输入的32位无符号整数,它打包了8个4-bit的数据。
12+
* @return 一个 uint4 变量,其中包含8个反量化后的 half 值。
13+
*/
14+
__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
15+
// 步骤 1: 从一个 32-bit 源数据中解包出 8 个 4-bit 无符号整数。
16+
// 源数据的内存布局被假定为 [v7, v6, v5, v4, v3, v2, v1, v0],
17+
// 其中每个 'v' 都是一个 4-bit 的半字节 (nibble)。
18+
const unsigned int v0 = (source >> 0) & 0x0F;
19+
const unsigned int v1 = (source >> 4) & 0x0F;
20+
const unsigned int v2 = (source >> 8) & 0x0F;
21+
const unsigned int v3 = (source >> 12) & 0x0F;
22+
const unsigned int v4 = (source >> 16) & 0x0F;
23+
const unsigned int v5 = (source >> 20) & 0x0F;
24+
const unsigned int v6 = (source >> 24) & 0x0F;
25+
const unsigned int v7 = (source >> 28) & 0x0F;
26+
27+
// 步骤 2: 对于 signed 4-bit (s4),减去 8 以映射到 [-8, 7] 范围。
28+
// 定义偏移量
29+
__half offset = __half(8);
30+
31+
// 计算 signed 值
32+
__half hv0 = __half(v0) - offset;
33+
__half hv1 = __half(v1) - offset;
34+
__half hv2 = __half(v2) - offset;
35+
__half hv3 = __half(v3) - offset;
36+
__half hv4 = __half(v4) - offset;
37+
__half hv5 = __half(v5) - offset;
38+
__half hv6 = __half(v6) - offset;
39+
__half hv7 = __half(v7) - offset;
40+
41+
// 步骤 3: 将 half 值按 PTX 交错顺序打包成 __half2 并存入 result 中。
42+
// 顺序:result_ptr[0]: low=hv0, high=hv4
43+
// result_ptr[1]: low=hv1, high=hv5
44+
// result_ptr[2]: low=hv2, high=hv6
45+
// result_ptr[3]: low=hv3, high=hv7
46+
// __halves2half2 函数:low 为第一个参数,high 为第二个参数。
47+
uint4 result;
48+
__half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
49+
50+
result_ptr[0] = __halves2half2(hv0, hv4);
51+
result_ptr[1] = __halves2half2(hv1, hv5);
52+
result_ptr[2] = __halves2half2(hv2, hv6);
53+
result_ptr[3] = __halves2half2(hv3, hv7);
54+
55+
return result;
56+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __DEQUANTIZE_AWQ_MOORE_H__
2+
#define __DEQUANTIZE_AWQ_MOORE_H__
3+
4+
#include "../dequantize_awq.h"
5+
6+
DESCRIPTOR(moore)
7+
8+
#endif // __DEQUANTIZE_AWQ_MOORE_H__
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include "../../../devices/moore/moore_handle.h"
2+
#include "../../../devices/moore/moore_kernel_common.h"
3+
#include "dequantize_w42f16_moore.h"
4+
#include "dequantize_w42f16_kernel.h"
5+
6+
#include "../dequantize_awq.h"
7+
#include <musa_fp16.h>
8+
9+
__global__ void __launch_bounds__(64)
10+
dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
11+
int *__restrict__ zeros, half *__restrict__ C, int G) {
12+
// static constexpr uint32_t ZERO = 0x0;
13+
half B_shared[32 * (128 + 8)];
14+
15+
half *B_shared_ptr2 = B_shared;
16+
17+
int N = blockDim.x * gridDim.x; // 2
18+
int col = (blockIdx.x * blockDim.x + threadIdx.x);
19+
int row = (blockIdx.y * blockDim.y + threadIdx.y);
20+
int index1 = 8 * col + 8 * row * N;
21+
half *C_ptr2 = C + index1;
22+
23+
int index2 = col + row * N;
24+
int *B_ptr2 = B + index2;
25+
26+
int index3 = col + (int)(row / G) * N;
27+
int *zeros_ptr2 = zeros + index3;
28+
int index4 = 8 * col + (int)(row / G) * N * 8;
29+
half *scaling_factors_ptr2 = scaling_factors + index4;
30+
31+
uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
32+
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
33+
uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2);
34+
35+
uint32_t B_loaded = *(uint32_t *)B_ptr2;
36+
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
37+
38+
// Reinterpret uint4 components as __half2
39+
__half2 *B_loaded_fp16_h2 = reinterpret_cast<__half2 *>(&B_loaded_fp16);
40+
__half2 *B_loaded_zero_h2 = reinterpret_cast<__half2 *>(&B_loaded_zero);
41+
__half2 *B_loaded_scale_h2 = reinterpret_cast<__half2 *>(&B_loaded_scale);
42+
43+
// Replace PTX sub.f16x2 with __hsub2 for each component
44+
B_loaded_fp16_h2[0] = __hsub2(B_loaded_fp16_h2[0], B_loaded_zero_h2[0]);
45+
B_loaded_fp16_h2[1] = __hsub2(B_loaded_fp16_h2[1], B_loaded_zero_h2[1]);
46+
B_loaded_fp16_h2[2] = __hsub2(B_loaded_fp16_h2[2], B_loaded_zero_h2[2]);
47+
B_loaded_fp16_h2[3] = __hsub2(B_loaded_fp16_h2[3], B_loaded_zero_h2[3]);
48+
49+
// Replace PTX fma.rn.f16x2 with __hfma2 for each component
50+
B_loaded_fp16_h2[0] = __hfma2(B_loaded_fp16_h2[0], B_loaded_scale_h2[0], __float2half2_rn(0.0f));
51+
B_loaded_fp16_h2[1] = __hfma2(B_loaded_fp16_h2[1], B_loaded_scale_h2[1], __float2half2_rn(0.0f));
52+
B_loaded_fp16_h2[2] = __hfma2(B_loaded_fp16_h2[2], B_loaded_scale_h2[2], __float2half2_rn(0.0f));
53+
B_loaded_fp16_h2[3] = __hfma2(B_loaded_fp16_h2[3], B_loaded_scale_h2[3], __float2half2_rn(0.0f));
54+
55+
// Store back to shared memory
56+
*(uint4 *)B_shared_ptr2 = B_loaded_fp16;
57+
58+
for (int i = 0; i < 8; ++i) {
59+
*(C_ptr2 + i) = B_shared[i];
60+
}
61+
}
62+
63+
namespace op::dequantize_awq::moore {
64+
65+
struct Descriptor::Opaque {
66+
std::shared_ptr<device::moore::Handle::Internal> internal;
67+
};
68+
69+
Descriptor::~Descriptor() {
70+
delete _opaque;
71+
}
72+
73+
infiniStatus_t Descriptor::create(
74+
infiniopHandle_t handle_,
75+
Descriptor **desc_ptr,
76+
infiniopTensorDescriptor_t out_desc,
77+
infiniopTensorDescriptor_t qweight_desc,
78+
infiniopTensorDescriptor_t scales_desc,
79+
infiniopTensorDescriptor_t zeros_desc) {
80+
81+
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
82+
auto result = DequantizeAWQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc);
83+
84+
*desc_ptr = new Descriptor(
85+
0,
86+
new Opaque{handle->internal()},
87+
result.take(),
88+
handle->device, handle->device_id);
89+
return INFINI_STATUS_SUCCESS;
90+
}
91+
92+
infiniStatus_t
93+
Descriptor::calculate(
94+
void *workspace,
95+
size_t workspace_size,
96+
void *out,
97+
const void *qweight,
98+
const void *scales,
99+
const void *zeros,
100+
void *stream) const {
101+
int in_features = _info.in_features();
102+
int out_features = _info.out_features();
103+
int group_size = in_features / _info.num_groups();
104+
105+
// ==================== 默认配置, 固定为 8 ====================
106+
constexpr int BLOCK_X = 8;
107+
constexpr int BLOCK_Y = 8;
108+
109+
int x_blocks = (out_features + BLOCK_X - 1) / BLOCK_X;
110+
int y_blocks = (in_features + BLOCK_Y - 1) / BLOCK_Y;
111+
112+
dim3 num_blocks(x_blocks, y_blocks);
113+
dim3 threads_per_block(BLOCK_X, BLOCK_Y);
114+
// =====================================================
115+
116+
half *out_ = reinterpret_cast<half *>(out);
117+
118+
int *qweight_ = const_cast<int *>(reinterpret_cast<const int *>(qweight));
119+
half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
120+
int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
121+
122+
dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<musaStream_t>(stream)>>>(
123+
qweight_, scales_, zeros_, out_, group_size);
124+
return INFINI_STATUS_SUCCESS;
125+
}
126+
127+
} // namespace op::dequantize_awq::moore

src/infiniop/ops/dequantize_awq/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#ifdef ENABLE_NVIDIA_API
66
#include "nvidia/dequantize_w42f16_nvidia.cuh"
77
#endif
8+
#ifdef ENABLE_MOORE_API
9+
#include "moore/dequantize_w42f16_moore.h"
10+
#endif
811

912
__C infiniStatus_t infiniopCreateDequantizeAWQDescriptor(
1013
infiniopHandle_t handle,
@@ -27,6 +30,9 @@ __C infiniStatus_t infiniopCreateDequantizeAWQDescriptor(
2730
switch (handle->device) {
2831
#ifdef ENABLE_NVIDIA_API
2932
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
33+
#endif
34+
#ifdef ENABLE_MOORE_API
35+
CREATE(INFINI_DEVICE_MOORE, moore);
3036
#endif
3137
default:
3238
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetDequantizeAWQWorkspaceSize(infiniopDequantizeAWQDe
4551
switch (desc->device_type) {
4652
#ifdef ENABLE_NVIDIA_API
4753
GET(INFINI_DEVICE_NVIDIA, nvidia);
54+
#endif
55+
#ifdef ENABLE_MOORE_API
56+
GET(INFINI_DEVICE_MOORE, moore);
4857
#endif
4958
default:
5059
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -70,6 +79,9 @@ __C infiniStatus_t infiniopDequantizeAWQ(
7079
switch (desc->device_type) {
7180
#ifdef ENABLE_NVIDIA_API
7281
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
82+
#endif
83+
#ifdef ENABLE_MOORE_API
84+
CALCULATE(INFINI_DEVICE_MOORE, moore);
7385
#endif
7486
default:
7587
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -89,6 +101,9 @@ infiniopDestroyDequantizeAWQDescriptor(infiniopDequantizeAWQDescriptor_t desc) {
89101
switch (desc->device_type) {
90102
#ifdef ENABLE_NVIDIA_API
91103
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
104+
#endif
105+
#ifdef ENABLE_MOORE_API
106+
DELETE(INFINI_DEVICE_MOORE, moore);
92107
#endif
93108
default:
94109
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

0 commit comments

Comments
 (0)