|
| 1 | +#include "../../../devices/nvidia/nvidia_handle.cuh" |
| 2 | +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" |
| 3 | +#include "dequantize_w42f16_iluvatar.cuh" |
| 4 | +#include "dequantize_w42f16_kernel.cuh" |
| 5 | + |
| 6 | +#include "../dequantize_awq.h" |
| 7 | +#include <cuda_fp16.h> |
| 8 | + |
/**
 * Dequantizes packed weights to fp16:
 *     C = (unpack(B) - unpack(zeros)) * scaling_factors
 *
 * Each thread reads one 32-bit word of B, expands it into eight halves with
 * dequantize_s4_to_fp16x2 (presumably eight packed 4-bit values, going by the
 * helper's name), subtracts the expanded zero point of the row's group,
 * multiplies by the group's scales and writes eight consecutive halves to C.
 *
 * Launch layout: 2D grid of 2D blocks. x indexes 32-bit words within a row,
 * y indexes rows. __launch_bounds__(64) caps the block at 64 threads.
 *
 * Preconditions (none of them are checked here):
 *   - The row pitch N is derived from blockDim.x * gridDim.x and there is no
 *     bounds guard, so the launch must tile the tensors exactly.
 *   - G must be non-zero (it is used as a divisor).
 *   - scaling_factors is read through a uint4 pointer, so each 8-half group
 *     must be 16-byte aligned.
 *
 * @param B                packed weights, one 32-bit word per thread
 * @param scaling_factors  fp16 scales, eight per (row / G, col)
 * @param zeros            packed zero points, one 32-bit word per (row / G, col)
 * @param C                fp16 output, eight halves per thread
 * @param G                number of consecutive rows sharing one zeros/scales row
 */
__global__ void __launch_bounds__(64)
    dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
                       int *__restrict__ zeros, half *__restrict__ C, int G) {
    const int N = blockDim.x * gridDim.x;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int group = row / G;

    // Offsets are computed in size_t so that 8 * row * N cannot overflow a
    // 32-bit int on large tensors.
    const size_t word_idx = static_cast<size_t>(row) * N + col;
    const size_t group_word_idx = static_cast<size_t>(group) * N + col;

    // Zero points and scales of the group this row belongs to.
    const uint32_t packed_zeros = *reinterpret_cast<const uint32_t *>(zeros + group_word_idx);
    uint4 zeros_fp16 = dequantize_s4_to_fp16x2(packed_zeros);
    uint4 scales_fp16 = *reinterpret_cast<const uint4 *>(scaling_factors + 8 * group_word_idx);

    // This thread's packed weights, expanded to eight halves.
    const uint32_t packed_weights = *reinterpret_cast<const uint32_t *>(B + word_idx);
    uint4 weights_fp16 = dequantize_s4_to_fp16x2(packed_weights);

    // View each uint4 as four __half2 lanes.
    __half2 *w = reinterpret_cast<__half2 *>(&weights_fp16);
    const __half2 *z = reinterpret_cast<const __half2 *>(&zeros_fp16);
    const __half2 *s = reinterpret_cast<const __half2 *>(&scales_fp16);

    // (w - z) * s. The multiply is kept as an FMA with a zero addend so the
    // result stays bit-identical to the fma.rn.f16x2 formulation.
    const __half2 zero2 = __float2half2_rn(0.0f);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        w[i] = __hfma2(__hsub2(w[i], z[i]), s[i], zero2);
    }

    // Write the eight results straight from the local uint4; no staging buffer
    // is needed. Stores are element-wise so C only needs half alignment.
    const half *result = reinterpret_cast<const half *>(&weights_fp16);
    half *dst = C + 8 * word_idx;
#pragma unroll
    for (int i = 0; i < 8; ++i) {
        dst[i] = result[i];
    }
}
| 62 | + |
namespace op::dequantize_awq::iluvatar {

// Backend-private descriptor state: a shared reference to the device handle
// internals returned by Handle::internal().
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Releases the Opaque allocated in Descriptor::create.
Descriptor::~Descriptor() {
    delete _opaque;
}

/**
 * Creates a dequantize-AWQ descriptor for the given tensor descriptors.
 *
 * @param handle_       device handle, reinterpreted as device::nvidia::Handle
 * @param desc_ptr      receives the newly allocated Descriptor
 * @param out_desc      descriptor of the output tensor
 * @param qweight_desc  descriptor of the packed weight tensor
 * @param scales_desc   descriptor of the scales tensor
 * @param zeros_desc    descriptor of the packed zero-point tensor
 * @return INFINI_STATUS_SUCCESS
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t qweight_desc,
    infiniopTensorDescriptor_t scales_desc,
    infiniopTensorDescriptor_t zeros_desc) {

    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);

    // NOTE(review): the result is consumed with take() without first checking
    // whether DequantizeAWQInfo::create succeeded - confirm whether it can
    // fail and, if so, propagate its status before building the descriptor.
    auto result = DequantizeAWQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc);

    *desc_ptr = new Descriptor(
        0, // presumably the workspace size (calculate does not use one) - TODO confirm
        new Opaque{handle->internal()},
        result.take(),
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

/**
 * Launches dequantize_weights on the given stream.
 *
 * The launch is asynchronous; this function does not synchronize.
 *
 * @param workspace       unused
 * @param workspace_size  unused
 * @param out             device pointer to the fp16 output
 * @param qweight         device pointer to the packed weights
 * @param scales          device pointer to the fp16 scales
 * @param zeros           device pointer to the packed zero points
 * @param stream          stream handle, reinterpreted as cudaStream_t
 * @return INFINI_STATUS_SUCCESS
 */
infiniStatus_t
Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *qweight,
    const void *scales,
    const void *zeros,
    void *stream) const {
    const int in_features = _info.in_features();
    const int out_features = _info.out_features();

    // An empty tensor has nothing to dequantize. Returning here avoids a
    // launch with a zero-sized grid (an invalid configuration whose error
    // would be left pending for a later, unrelated call) and avoids the
    // division below when the group count is zero as well.
    if (in_features == 0 || out_features == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    // NOTE(review): num_groups() == 0, or a resulting group_size of 0, is not
    // rejected here; both lead to a division by zero (host or device side).
    const int group_size = in_features / _info.num_groups();

    // Default configuration, fixed at 8: 8 x 8 = 64 threads per block, which
    // matches the kernel's __launch_bounds__(64).
    constexpr int BLOCK_X = 8;
    constexpr int BLOCK_Y = 8;

    // NOTE(review): the grid is rounded up, but dequantize_weights has no
    // bounds guard and takes its row pitch from the grid size, so feature
    // counts that are not multiples of 8 would access memory out of bounds.
    const int x_blocks = (out_features + BLOCK_X - 1) / BLOCK_X;
    const int y_blocks = (in_features + BLOCK_Y - 1) / BLOCK_Y;

    const dim3 num_blocks(x_blocks, y_blocks);
    const dim3 threads_per_block(BLOCK_X, BLOCK_Y);

    // The kernel takes non-const pointers, hence the const_casts on inputs.
    half *out_ = reinterpret_cast<half *>(out);
    int *qweight_ = const_cast<int *>(reinterpret_cast<const int *>(qweight));
    half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
    int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));

    // NOTE(review): the launch status is not checked; launch failures are
    // not reported to the caller.
    dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
        qweight_, scales_, zeros_, out_, group_size);
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::dequantize_awq::iluvatar
0 commit comments