|
| 1 | +// NF4 反量化 CUDA 程序 |
| 2 | +// 用法: ./nf4_dequant <weight_file> <output_file> [bf16|fp16] [warmup] [repeats] |
| 3 | + |
| 4 | +#include <cstdio> |
| 5 | +#include <cstdlib> |
| 6 | +#include <cstdint> |
| 7 | +#include <cstring> |
| 8 | +#include <cmath> |
| 9 | +#include <vector> |
| 10 | +#include <string> |
| 11 | +#include <chrono> |
| 12 | +#include <algorithm> |
| 13 | + |
| 14 | +#include <cuda_runtime.h> |
| 15 | + |
| 16 | +#include "nf4_dequant_kernel.cuh" |
| 17 | + |
// Abort with a file:line diagnostic whenever a CUDA runtime call fails.
// Wrap every cudaXxx() call site; kernel launches are checked separately
// via cudaGetLastError().
#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t cuda_check_err_ = (call);                         \
        if (cuda_check_err_ != cudaSuccess) {                         \
            fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__,    \
                    __LINE__, cudaGetErrorString(cuda_check_err_));   \
            exit(EXIT_FAILURE);                                       \
        }                                                             \
    } while (0)
| 28 | + |
// Binary weight-file layout: header (rows, cols, blocksize) + packed_weights
// + absmax_q + absmax2 + code2 + offset.
// Host-side container for a double-quantized NF4 tensor loaded from disk.
struct NF4Data {
    int64_t num_rows;   // matrix rows (from file header)
    int64_t num_cols;   // matrix cols (from file header)
    int32_t blocksize;  // elements per first-level quantization block

    std::vector<uint8_t> packed_weights;  // two 4-bit NF4 codes per byte
    std::vector<uint8_t> absmax_q;        // 8-bit quantized per-block absmax, one per block
    std::vector<uint16_t> absmax2;        // fp16 raw bits: second-level scales, one per group
    std::vector<uint16_t> code2;          // fp16[256] raw bits: dequant LUT for absmax_q
    float offset;                         // added back when dequantizing absmax

    // Derived in read_nf4_data(), not stored explicitly in the file header.
    int64_t n_elements;    // num_rows * num_cols
    int32_t num_blocks;    // ceil(n_elements / blocksize)
    int32_t num_groups;    // number of second-level groups (inferred from file size)
    int32_t s2_blocksize;  // ceil(num_blocks / num_groups)
};
| 46 | + |
// Parse an NF4 weight dump into `data`.
// Layout: int64 rows, int64 cols, int32 blocksize, then
//   packed_weights[ceil(n/2)]  (two 4-bit codes per byte),
//   absmax_q[num_blocks]       (8-bit quantized per-block scales),
//   absmax2[num_groups]        (fp16 raw bits; count inferred from file size),
//   code2[256]                 (fp16 raw-bit LUT used to dequantize absmax_q),
//   float offset.
// Returns false (with a message on stderr) on open failure, truncated reads,
// or an implausible header / file size. Previously every fread return value
// was ignored, so a truncated file produced silently garbage fields.
bool read_nf4_data(const char* filepath, NF4Data& data) {
    FILE* f = fopen(filepath, "rb");
    if (!f) {
        fprintf(stderr, "[ERROR] Cannot open file: %s\n", filepath);
        return false;
    }

    // Header — check every read so truncation is reported, not ignored.
    if (fread(&data.num_rows, sizeof(int64_t), 1, f) != 1 ||
        fread(&data.num_cols, sizeof(int64_t), 1, f) != 1 ||
        fread(&data.blocksize, sizeof(int32_t), 1, f) != 1) {
        fprintf(stderr, "[ERROR] Truncated header in %s\n", filepath);
        fclose(f);
        return false;
    }
    if (data.num_rows <= 0 || data.num_cols <= 0 || data.blocksize <= 0) {
        fprintf(stderr, "[ERROR] Invalid header in %s: rows=%lld cols=%lld blocksize=%d\n",
                filepath, (long long)data.num_rows, (long long)data.num_cols,
                data.blocksize);
        fclose(f);
        return false;
    }

    data.n_elements = data.num_rows * data.num_cols;
    data.num_blocks = (int32_t)((data.n_elements + data.blocksize - 1) / data.blocksize);

    // ceil(n/2): two codes per byte, rounded up so odd element counts are
    // handled consistently with the (n_elements + 1) / 2 indexing in main().
    int64_t packed_size = (data.n_elements + 1) / 2;
    data.packed_weights.resize((size_t)packed_size);
    if (fread(data.packed_weights.data(), 1, (size_t)packed_size, f) != (size_t)packed_size) {
        fprintf(stderr, "[ERROR] Truncated packed weights in %s\n", filepath);
        fclose(f);
        return false;
    }

    data.absmax_q.resize((size_t)data.num_blocks);
    if (fread(data.absmax_q.data(), 1, (size_t)data.num_blocks, f) != (size_t)data.num_blocks) {
        fprintf(stderr, "[ERROR] Truncated absmax_q in %s\n", filepath);
        fclose(f);
        return false;
    }

    // num_groups is not stored explicitly; infer it from the bytes remaining
    // after the fixed-size tail (code2 LUT + offset).
    long current_pos = ftell(f);
    fseek(f, 0, SEEK_END);
    long file_size = ftell(f);
    fseek(f, current_pos, SEEK_SET);

    long remaining = file_size - current_pos;
    const long fixed_tail = 256 * 2 + 4;  // code2 (512 B) + offset (4 B)
    long absmax2_bytes = remaining - fixed_tail;
    // Must be a positive even byte count (fp16 entries); otherwise the layout
    // does not match and num_groups would be <= 0 (division by zero below).
    if (absmax2_bytes <= 0 || (absmax2_bytes % 2) != 0) {
        fprintf(stderr, "[ERROR] Unexpected file size for %s (absmax2 bytes = %ld)\n",
                filepath, absmax2_bytes);
        fclose(f);
        return false;
    }
    data.num_groups = (int32_t)(absmax2_bytes / 2);
    data.s2_blocksize = (data.num_blocks + data.num_groups - 1) / data.num_groups;

    data.absmax2.resize((size_t)data.num_groups);
    if (fread(data.absmax2.data(), 2, (size_t)data.num_groups, f) != (size_t)data.num_groups) {
        fprintf(stderr, "[ERROR] Truncated absmax2 in %s\n", filepath);
        fclose(f);
        return false;
    }

    data.code2.resize(256);
    if (fread(data.code2.data(), 2, 256, f) != 256 ||
        fread(&data.offset, sizeof(float), 1, f) != 1) {
        fprintf(stderr, "[ERROR] Truncated code2/offset in %s\n", filepath);
        fclose(f);
        return false;
    }

    fclose(f);
    return true;
}
| 92 | + |
// Benchmark driver: load an NF4 dump, dequantize on the GPU to bf16/fp16,
// time the kernel with CUDA events, report bandwidth, and dump the output.
int main(int argc, char* argv[]) {
    if (argc < 3) {
        fprintf(stderr, "用法: %s <weight_file> <output_file> [bf16|fp16] [warmup] [repeats]\n", argv[0]);
        return 1;
    }

    const char* weight_file = argv[1];
    const char* output_file = argv[2];
    std::string compute_type = (argc > 3) ? argv[3] : "bf16";
    int warmup = (argc > 4) ? atoi(argv[4]) : 10;
    int repeats = (argc > 5) ? atoi(argv[5]) : 100;
    // Clamp CLI values: the statistics below assume at least one sample
    // (repeats == 0 previously divided by zero and called .front() on an
    // empty vector).
    if (warmup < 0) warmup = 0;
    if (repeats < 1) repeats = 1;

    bool use_bf16 = (compute_type == "bf16");

    // Load the weight dump from disk.
    printf("[INFO] 读取权重文件: %s\n", weight_file);
    NF4Data data;
    if (!read_nf4_data(weight_file, data)) return 1;

    printf("  num_rows     = %ld\n", (long)data.num_rows);
    printf("  num_cols     = %ld\n", (long)data.num_cols);
    printf("  blocksize    = %d\n", data.blocksize);
    printf("  n_elements   = %ld\n", (long)data.n_elements);
    printf("  num_blocks   = %d\n", data.num_blocks);
    printf("  num_groups   = %d\n", data.num_groups);
    printf("  s2_blocksize = %d\n", data.s2_blocksize);
    printf("  offset       = %f\n", data.offset);
    printf("  compute_type = %s\n", compute_type.c_str());

    // Device buffers.
    uint8_t* d_packed_weights;
    uint8_t* d_absmax_q;
    half* d_absmax2;
    half* d_code2;
    void* d_output;

    // ceil(n/2): the kernel indexes (n_elements + 1) / 2 packed bytes, so the
    // device buffer must cover the rounded-up size. The previous n/2
    // allocation was one byte short for odd n_elements (out-of-bounds read).
    int64_t packed_size = (data.n_elements + 1) / 2;
    int64_t output_bytes = data.n_elements * 2;  // bf16/fp16 = 2 bytes each

    CUDA_CHECK(cudaMalloc(&d_packed_weights, packed_size));
    CUDA_CHECK(cudaMalloc(&d_absmax_q, data.num_blocks));
    CUDA_CHECK(cudaMalloc(&d_absmax2, data.num_groups * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&d_code2, 256 * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&d_output, output_bytes));

    // H2D transfers. Zero-fill the packed buffer first so a trailing pad byte
    // (odd n_elements) is deterministic, and copy only what the loader
    // actually holds — never past the end of the host vector.
    CUDA_CHECK(cudaMemset(d_packed_weights, 0, packed_size));
    CUDA_CHECK(cudaMemcpy(d_packed_weights, data.packed_weights.data(),
                          (int64_t)data.packed_weights.size(), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_absmax_q, data.absmax_q.data(),
                          data.num_blocks, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_absmax2, data.absmax2.data(),
                          data.num_groups * sizeof(half), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_code2, data.code2.data(),
                          256 * sizeof(half), cudaMemcpyHostToDevice));

    // Launch configuration: one thread handles a 4-byte chunk of packed data.
    int n_packed = (int)packed_size;
    int n_packed_vec = (n_packed + 3) / 4;
    int threads_per_block = 256;
    int num_blocks_kernel = (n_packed_vec + threads_per_block - 1) / threads_per_block;

    // Precomputed log2 values let the kernel use shifts instead of divides
    // (assumes blocksize and s2_blocksize are powers of two — enforced by
    // log2_pow2 in the kernel header; TODO confirm its failure behavior).
    int log2_bs = log2_pow2(data.blocksize);
    int log2_s2 = log2_pow2(data.s2_blocksize);

    printf("\n[INFO] Kernel 配置:\n");
    printf("  n_packed          = %d\n", n_packed);
    printf("  n_packed_vec      = %d (向量化后)\n", n_packed_vec);
    printf("  threads_per_block = %d\n", threads_per_block);
    printf("  grid_size         = %d\n", num_blocks_kernel);
    printf("  log2_blocksize    = %d\n", log2_bs);
    printf("  log2_s2_blocksize = %d\n", log2_s2);

    // Single launch point shared by warmup and timed runs. Kernel launches do
    // not return errors directly, so check cudaGetLastError() after each one
    // to catch bad launch configurations.
    auto launch = [&]() {
        if (use_bf16) {
            nf4_dequantize_kernel<__nv_bfloat16><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (__nv_bfloat16*)d_output);
        } else {
            nf4_dequantize_kernel<half><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (half*)d_output);
        }
        CUDA_CHECK(cudaGetLastError());
    };

    // Warmup runs (JIT, clocks, caches).
    printf("\n[INFO] 预热 %d 次...\n", warmup);
    for (int i = 0; i < warmup; i++) launch();
    CUDA_CHECK(cudaDeviceSynchronize());

    // Timed runs: CUDA events, with a sync before each iteration to isolate
    // the measurement from previously queued work.
    printf("[INFO] 计时 %d 次...\n", repeats);

    cudaEvent_t ev_start, ev_end;
    CUDA_CHECK(cudaEventCreate(&ev_start));
    CUDA_CHECK(cudaEventCreate(&ev_end));

    std::vector<float> times(repeats);

    for (int i = 0; i < repeats; i++) {
        CUDA_CHECK(cudaDeviceSynchronize());
        CUDA_CHECK(cudaEventRecord(ev_start));
        launch();
        CUDA_CHECK(cudaEventRecord(ev_end));
        CUDA_CHECK(cudaEventSynchronize(ev_end));
        // Fixed: this argument was mojibake ("×[i]") in the original —
        // corrupted from "&times[i]".
        CUDA_CHECK(cudaEventElapsedTime(&times[i], ev_start, ev_end));
    }

    // Median over sorted samples resists outliers (clock ramps, interrupts).
    std::vector<float> sorted_times = times;
    std::sort(sorted_times.begin(), sorted_times.end());

    float total_ms = 0.0f;
    float min_ms = sorted_times.front();
    float max_ms = sorted_times.back();
    for (int i = 0; i < repeats; i++) total_ms += times[i];
    float avg_ms = total_ms / repeats;
    float median_ms = sorted_times[repeats / 2];

    // Effective bandwidth from the median: all bytes read (packed weights +
    // both scale levels + LUT) plus all bytes written (dequantized output).
    double read_bytes = (double)packed_size + data.num_blocks + data.num_groups * 2 + 256 * 2;
    double write_bytes = (double)output_bytes;
    double total_bytes = read_bytes + write_bytes;
    double bandwidth_gbps = total_bytes / (median_ms * 1e-3) / 1e9;

    printf("\n========================================\n");
    printf("  NF4 反量化 Kernel 性能\n");
    printf("========================================\n");
    printf("  矩阵大小     : (%ld, %ld)\n", (long)data.num_rows, (long)data.num_cols);
    printf("  块大小       : %d\n", data.blocksize);
    printf("  输出类型     : %s\n", compute_type.c_str());
    printf("  平均耗时     : %.4f ms\n", avg_ms);
    printf("  中位数耗时   : %.4f ms\n", median_ms);
    printf("  最小耗时     : %.4f ms\n", min_ms);
    printf("  最大耗时     : %.4f ms\n", max_ms);
    printf("  有效带宽     : %.2f GB/s (基于中位数)\n", bandwidth_gbps);
    printf("========================================\n");

    // Copy the result back (blocking cudaMemcpy implies synchronization) and
    // dump it to disk for offline verification.
    std::vector<uint8_t> h_output(output_bytes);
    CUDA_CHECK(cudaMemcpy(h_output.data(), d_output, output_bytes, cudaMemcpyDeviceToHost));

    FILE* fout = fopen(output_file, "wb");
    if (!fout) {
        fprintf(stderr, "[ERROR] Cannot open output file: %s\n", output_file);
        return 1;
    }
    // Check the write: a full disk previously went unnoticed.
    size_t written = fwrite(h_output.data(), 1, output_bytes, fout);
    fclose(fout);
    if (written != (size_t)output_bytes) {
        fprintf(stderr, "[ERROR] Short write to %s (%zu of %ld bytes)\n",
                output_file, written, (long)output_bytes);
        return 1;
    }
    printf("\n[INFO] 已写入解量化输出: %s (%ld bytes)\n", output_file, (long)output_bytes);

    // Cleanup.
    CUDA_CHECK(cudaEventDestroy(ev_start));
    CUDA_CHECK(cudaEventDestroy(ev_end));
    CUDA_CHECK(cudaFree(d_packed_weights));
    CUDA_CHECK(cudaFree(d_absmax_q));
    CUDA_CHECK(cudaFree(d_absmax2));
    CUDA_CHECK(cudaFree(d_code2));
    CUDA_CHECK(cudaFree(d_output));

    printf("[DONE] 完成\n");
    return 0;
}
0 commit comments