|
| 1 | +#include "ops_common.h" |
| 2 | +#include "reduce/sm70.cuh" |
| 3 | + |
| 4 | + |
| 5 | +namespace lightllm { |
| 6 | +namespace ops { |
| 7 | + |
| 8 | +using namespace lightllm; |
| 9 | + |
| 10 | +// CUDA kernel for per token quantization from FP16 to FP8 |
| 11 | +template<int32_t TPB> |
| 12 | +__global__ void device_per_token_quant_fp16_to_fp8_general( |
| 13 | + const fp16_t* __restrict__ input, // Input tensor in FP16 format |
| 14 | + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format |
| 15 | + fp32_t* __restrict__ scales, // Output scales for each token |
| 16 | + const int64_t N |
| 17 | +) { |
| 18 | + const int32_t bid = blockIdx.x; |
| 19 | + const int32_t tid = threadIdx.x; |
| 20 | + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format |
| 21 | + |
| 22 | + const fp16_t* _input = input + bid * N; // Input pointer for the token |
| 23 | + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token |
| 24 | + |
| 25 | + fp32_t* _scales; |
| 26 | + _scales = scales + bid; |
| 27 | + |
| 28 | + // Local arrays for intermediate storage |
| 29 | + fp8_e4m3_t local_f8; |
| 30 | + fp16_t local_fp16; |
| 31 | + |
| 32 | + extern __shared__ fp16_t workspace1[]; |
| 33 | + |
| 34 | + fp32_t local_max = -FLT_MAX; |
| 35 | + for (int32_t i = tid; i < N; i += TPB) { |
| 36 | + local_fp16 = _input[i]; |
| 37 | + workspace1[i] = local_fp16; |
| 38 | + |
| 39 | + fp32_t tmp = cvt_f16_f32(local_fp16); |
| 40 | + local_max = fmaxf(local_max, fabsf(tmp)); |
| 41 | + } |
| 42 | + |
| 43 | + // Reduce the maximum value across the block |
| 44 | + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max); |
| 45 | + |
| 46 | + // Compute the scale factor with epsilon to avoid division by zero |
| 47 | + constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f); |
| 48 | + const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX); |
| 49 | + |
| 50 | + for (int32_t i = tid; i < N; i += TPB) { |
| 51 | + local_fp16 = workspace1[i]; |
| 52 | + |
| 53 | + fp32_t tmp = cvt_f16_f32(local_fp16); |
| 54 | + fp32_t x = tmp / scale; |
| 55 | + local_f8 = fp8_e4m3_t(x); |
| 56 | + |
| 57 | + _output[i] = local_f8; |
| 58 | + } |
| 59 | + |
| 60 | + if (tid == 0) { |
| 61 | + *_scales = scale; |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +// CUDA kernel for per token quantization from FP16 to FP8 |
| 66 | +template<int32_t TPB> |
| 67 | +__global__ void device_per_token_quant_fp16_to_fp8_vpt( |
| 68 | + const fp16_t* __restrict__ input, // Input tensor in FP16 format |
| 69 | + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format |
| 70 | + fp32_t* __restrict__ scales, // Output scales for each token |
| 71 | + const int32_t N |
| 72 | +) { |
| 73 | + constexpr int32_t VPT = 8; |
| 74 | + |
| 75 | + const int32_t bid = blockIdx.x; |
| 76 | + const int32_t tid = threadIdx.x; |
| 77 | + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format |
| 78 | + |
| 79 | + const fp16_t* _input = input + bid * N; // Input pointer for the token |
| 80 | + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token |
| 81 | + |
| 82 | + fp32_t* _scales; |
| 83 | + _scales = scales + bid; |
| 84 | + |
| 85 | + // Local arrays for intermediate storage |
| 86 | + fp8x4_e4m3_t local_f8[VPT / 4]; |
| 87 | + fp16x2_t local_fp16[VPT / 2]; |
| 88 | + |
| 89 | + extern __shared__ fp16x2_t workspace2[]; |
| 90 | + |
| 91 | + fp32_t local_max = -FLT_MAX; |
| 92 | + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { |
| 93 | + // Load VPT FP16 elements from global memory (_X) into local vector (local_x). |
| 94 | + vec_copy<sizeof(fp16_t) * VPT>(_input + i, local_fp16); |
| 95 | + |
| 96 | + vec_copy<sizeof(fp16_t) * VPT>(local_fp16, workspace2 + (i >> 1)); |
| 97 | + |
| 98 | + // Compute the max for the VPT elements. |
| 99 | + #pragma unroll |
| 100 | + for (int32_t j = 0; j < VPT / 2; j++) { |
| 101 | + fp32x2_t tmp = fp16x2_to_fp32x2(local_fp16[j]); |
| 102 | + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); |
| 103 | + local_max = fmaxf(local_max, max); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + // Reduce the maximum value across the block |
| 108 | + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max); |
| 109 | + |
| 110 | + // Compute the scale factor with epsilon to avoid division by zero |
| 111 | + constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f); |
| 112 | + const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX); |
| 113 | + |
| 114 | + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { |
| 115 | + vec_copy<sizeof(fp16_t) * VPT>(workspace2 + (i >> 1), local_fp16); |
| 116 | + |
| 117 | + #pragma unroll |
| 118 | + for (int32_t j = 0; j < VPT / 4; j++) { |
| 119 | + fp32x2_t x = fp16x2_to_fp32x2(local_fp16[2 * j + 0]); |
| 120 | + fp32x2_t y = fp16x2_to_fp32x2(local_fp16[2 * j + 1]); |
| 121 | + fp32x4_t ret = make_float4( |
| 122 | + x.x / scale, |
| 123 | + x.y / scale, |
| 124 | + y.x / scale, |
| 125 | + y.y / scale |
| 126 | + ); |
| 127 | + local_f8[j] = fp8x4_e4m3_t(ret); |
| 128 | + } |
| 129 | + |
| 130 | + vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i); |
| 131 | + } |
| 132 | + |
| 133 | + if (tid == 0) { |
| 134 | + *_scales = scale; |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +// CUDA kernel for per token quantization from FP16 to FP8 |
| 139 | +template<int32_t TPB, int32_t N> |
| 140 | +__global__ void device_per_token_quant_fp16_to_fp8( |
| 141 | + const fp16_t* __restrict__ input, // Input tensor in FP16 format |
| 142 | + fp8_e4m3_t* __restrict__ output, // Output tensor in FP8 format |
| 143 | + fp32_t* __restrict__ scales // Output scales for each token |
| 144 | +) { |
| 145 | + constexpr int32_t VPT = 8; |
| 146 | + |
| 147 | + static_assert(N % 2 == 0, "N must be even."); |
| 148 | + static_assert(N % VPT == 0, "N must be a multiple of VPT."); |
| 149 | + |
| 150 | + const int32_t bid = blockIdx.x; |
| 151 | + const int32_t tid = threadIdx.x; |
| 152 | + constexpr fp32_t FP8_E4M3_MAX = 448.0f; // Maximum value representable in FP8 E4M3 format |
| 153 | + |
| 154 | + const fp16_t* _input = input + bid * N; // Input pointer for the token |
| 155 | + fp8_e4m3_t* _output = output + bid * N; // Output pointer for the token |
| 156 | + |
| 157 | + fp32_t* _scales; |
| 158 | + _scales = scales + bid; |
| 159 | + |
| 160 | + // Local arrays for intermediate storage |
| 161 | + fp8x4_e4m3_t local_f8[VPT / 4]; |
| 162 | + fp16x2_t local_fp16[VPT / 2]; |
| 163 | + |
| 164 | + __shared__ fp16x2_t workspace[N / 2]; |
| 165 | + |
| 166 | + fp32_t local_max = -FLT_MAX; |
| 167 | + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { |
| 168 | + // Load VPT FP16 elements from global memory (_X) into local vector (local_x). |
| 169 | + vec_copy<sizeof(fp16_t) * VPT>(_input + i, local_fp16); |
| 170 | + |
| 171 | + vec_copy<sizeof(fp16_t) * VPT>(local_fp16, workspace + (i >> 1)); |
| 172 | + |
| 173 | + // Compute the max for the VPT elements. |
| 174 | + #pragma unroll |
| 175 | + for (int32_t j = 0; j < VPT / 2; j++) { |
| 176 | + fp32x2_t tmp = fp16x2_to_fp32x2(local_fp16[j]); |
| 177 | + fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y)); |
| 178 | + local_max = fmaxf(local_max, max); |
| 179 | + } |
| 180 | + } |
| 181 | + |
| 182 | + // Reduce the maximum value across the block |
| 183 | + const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max); |
| 184 | + |
| 185 | + // Compute the scale factor with epsilon to avoid division by zero |
| 186 | + constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f); |
| 187 | + const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX); |
| 188 | + |
| 189 | + for (int32_t i = tid * VPT; i < N; i += TPB * VPT) { |
| 190 | + vec_copy<sizeof(fp16_t) * VPT>(workspace + (i >> 1), local_fp16); |
| 191 | + |
| 192 | + #pragma unroll |
| 193 | + for (int32_t j = 0; j < VPT / 4; j++) { |
| 194 | + fp32x2_t x = fp16x2_to_fp32x2(local_fp16[2 * j + 0]); |
| 195 | + fp32x2_t y = fp16x2_to_fp32x2(local_fp16[2 * j + 1]); |
| 196 | + fp32x4_t ret = make_float4( |
| 197 | + x.x / scale, |
| 198 | + x.y / scale, |
| 199 | + y.x / scale, |
| 200 | + y.y / scale |
| 201 | + ); |
| 202 | + local_f8[j] = fp8x4_e4m3_t(ret); |
| 203 | + } |
| 204 | + |
| 205 | + vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i); |
| 206 | + } |
| 207 | + |
| 208 | + if (tid == 0) { |
| 209 | + *_scales = scale; |
| 210 | + } |
| 211 | +} |
| 212 | + |
| 213 | +void per_token_quant_fp16_fp8( |
| 214 | + Tensor& output, |
| 215 | + const Tensor& input, |
| 216 | + Tensor& scales |
| 217 | +) { |
| 218 | + TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor"); |
| 219 | + TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional"); |
| 220 | + TORCH_CHECK(input.scalar_type() == c10::kHalf, "Input must be FP16 type"); |
| 221 | + |
| 222 | + Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous(); |
| 223 | + Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous(); |
| 224 | + |
| 225 | + const int64_t M = input.size(0); |
| 226 | + const int64_t N = input.size(1); |
| 227 | + |
| 228 | + const int32_t blocks = M; |
| 229 | + |
| 230 | + switch (N) { |
| 231 | + case 16: |
| 232 | + device_per_token_quant_fp16_to_fp8<128, 16> |
| 233 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 234 | + PTR<fp16_t>(contiguous_input), |
| 235 | + PTR<fp8_e4m3_t>(output), |
| 236 | + PTR<fp32_t>(contiguous_scales) |
| 237 | + ); |
| 238 | + break; |
| 239 | + case 32: |
| 240 | + device_per_token_quant_fp16_to_fp8<128, 32> |
| 241 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 242 | + PTR<fp16_t>(contiguous_input), |
| 243 | + PTR<fp8_e4m3_t>(output), |
| 244 | + PTR<fp32_t>(contiguous_scales) |
| 245 | + ); |
| 246 | + break; |
| 247 | + case 64: |
| 248 | + device_per_token_quant_fp16_to_fp8<128, 64> |
| 249 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 250 | + PTR<fp16_t>(contiguous_input), |
| 251 | + PTR<fp8_e4m3_t>(output), |
| 252 | + PTR<fp32_t>(contiguous_scales) |
| 253 | + ); |
| 254 | + break; |
| 255 | + case 512: |
| 256 | + device_per_token_quant_fp16_to_fp8<128, 512> |
| 257 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 258 | + PTR<fp16_t>(contiguous_input), |
| 259 | + PTR<fp8_e4m3_t>(output), |
| 260 | + PTR<fp32_t>(contiguous_scales) |
| 261 | + ); |
| 262 | + break; |
| 263 | + case 1024: |
| 264 | + device_per_token_quant_fp16_to_fp8<128, 1024> |
| 265 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 266 | + PTR<fp16_t>(contiguous_input), |
| 267 | + PTR<fp8_e4m3_t>(output), |
| 268 | + PTR<fp32_t>(contiguous_scales) |
| 269 | + ); |
| 270 | + break; |
| 271 | + case 3200: |
| 272 | + device_per_token_quant_fp16_to_fp8<128, 3200> |
| 273 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 274 | + PTR<fp16_t>(contiguous_input), |
| 275 | + PTR<fp8_e4m3_t>(output), |
| 276 | + PTR<fp32_t>(contiguous_scales) |
| 277 | + ); |
| 278 | + break; |
| 279 | + case 4096: |
| 280 | + device_per_token_quant_fp16_to_fp8<128, 4096> |
| 281 | + <<<blocks, 128, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 282 | + PTR<fp16_t>(contiguous_input), |
| 283 | + PTR<fp8_e4m3_t>(output), |
| 284 | + PTR<fp32_t>(contiguous_scales) |
| 285 | + ); |
| 286 | + break; |
| 287 | + case 12800: |
| 288 | + device_per_token_quant_fp16_to_fp8<256, 12800> |
| 289 | + <<<blocks, 256, 0, at::cuda::getCurrentCUDAStream()>>>( |
| 290 | + PTR<fp16_t>(contiguous_input), |
| 291 | + PTR<fp8_e4m3_t>(output), |
| 292 | + PTR<fp32_t>(contiguous_scales) |
| 293 | + ); |
| 294 | + break; |
| 295 | + default: { |
| 296 | + static constexpr int TPB = 128; |
| 297 | + const int64_t shared_mem_size = N * sizeof(fp16_t); |
| 298 | + if (N % 8 == 0) { |
| 299 | + device_per_token_quant_fp16_to_fp8_vpt<TPB> |
| 300 | + <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>( |
| 301 | + PTR<fp16_t>(contiguous_input), |
| 302 | + PTR<fp8_e4m3_t>(output), |
| 303 | + PTR<fp32_t>(contiguous_scales), |
| 304 | + N |
| 305 | + ); |
| 306 | + } else { |
| 307 | + device_per_token_quant_fp16_to_fp8_general<TPB> |
| 308 | + <<<blocks, TPB, shared_mem_size, at::cuda::getCurrentCUDAStream()>>>( |
| 309 | + PTR<fp16_t>(contiguous_input), |
| 310 | + PTR<fp8_e4m3_t>(output), |
| 311 | + PTR<fp32_t>(contiguous_scales), |
| 312 | + N |
| 313 | + ); |
| 314 | + } |
| 315 | + } |
| 316 | + } |
| 317 | + |
| 318 | + return; |
| 319 | +} |
| 320 | + |
| 321 | +} // namespace ops |
| 322 | +} // namespace lightllm |
0 commit comments