Skip to content

Commit 4a6d7bc

Browse files
committed
per token quant fp8 and int8 support fp16 input
1 parent 07f2f62 commit 4a6d7bc

9 files changed

Lines changed: 798 additions & 64 deletions

File tree

csrc/ops_bindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ PYBIND11_MODULE(_C, m) {
1111
m.def("pre_tp_norm_bf16", &pre_tp_norm_bf16, "PRE TP NORM (CUDA)");
1212
m.def("post_tp_norm_bf16", &post_tp_norm_bf16, "POST TP NORM (CUDA)");
1313
m.def("per_token_quant_bf16_fp8", &per_token_quant_bf16_fp8, "PER TOKEN QUANT FP8 (CUDA)");
14+
m.def("per_token_quant_fp16_fp8", &per_token_quant_fp16_fp8, "PER TOKEN QUANT FP8 (CUDA)");
1415
m.def("per_token_quant_bf16_int8", &per_token_quant_bf16_int8, "PER TOKEN QUANT INT8 (CUDA)");
16+
m.def("per_token_quant_fp16_int8", &per_token_quant_fp16_int8, "PER TOKEN QUANT INT8 (CUDA)");
1517
m.def("add_norm_quant_bf16_fp8", &add_norm_quant_bf16_fp8, "ADD NORM QUANT FUSED (CUDA)");
1618
m.def("gelu_per_token_quant_bf16_fp8", &gelu_per_token_quant_bf16_fp8, "GELU QUANT FUSED (CUDA)");
1719
m.def("cutlass_scaled_mm", &cutlass_scaled_mm, "CUTLASS SCALED MM (CUDA)");
Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
#include "ops_common.h"
2+
#include "reduce/sm70.cuh"
3+
4+
5+
namespace lightllm {
6+
namespace ops {
7+
8+
using namespace lightllm;
9+
10+
// Per-token (per-row) quantization of an FP16 matrix to FP8 E4M3, scalar variant.
//
// Launch layout : one block per row (blockIdx.x selects the row); the column
//                 loops stride by TPB, so blockDim.x is expected to equal TPB.
// Shared memory : dynamic, N * sizeof(fp16_t) bytes (one staged copy of the row).
// Outputs       : output[row, :] = input[row, :] / scale, converted to FP8;
//                 scales[row]    = scale, where
//                 scale = max(absmax(row) / 448, 1 / (448 * 512)).
// No divisibility requirement on N; elements are processed one at a time.
template<int32_t TPB>
__global__ void device_per_token_quant_fp16_to_fp8_general(
    const fp16_t* __restrict__ input,   // [rows, N] FP16 source
    fp8_e4m3_t* __restrict__ output,    // [rows, N] FP8 destination
    fp32_t* __restrict__ scales,        // one scale per row
    const int64_t N
) {
    constexpr fp32_t FP8_E4M3_MAX = 448.0f;  // Largest finite FP8 E4M3 magnitude
    // Lower bound on the scale so that an all-zero row never divides by zero.
    constexpr fp32_t kMinScale = 1.0f / (FP8_E4M3_MAX * 512.0f);

    extern __shared__ fp16_t workspace1[];

    const int32_t row = blockIdx.x;
    const int32_t lane = threadIdx.x;

    // N is 64-bit, so the row offset is computed in 64-bit arithmetic.
    const int64_t row_offset = row * N;
    const fp16_t* row_in = input + row_offset;
    fp8_e4m3_t* row_out = output + row_offset;

    // Pass 1: stage the row in shared memory and track this thread's |x| maximum.
    fp32_t thread_absmax = -FLT_MAX;
    for (int32_t col = lane; col < N; col += TPB) {
        const fp16_t value = row_in[col];
        workspace1[col] = value;
        thread_absmax = fmaxf(thread_absmax, fabsf(cvt_f16_f32(value)));
    }

    // Block-wide maximum. NOTE(review): the helper presumably synchronizes the
    // block and returns the same value to every thread -- confirm in reduce/sm70.cuh.
    const fp32_t row_absmax =
        lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(thread_absmax);

    const fp32_t scale = fmaxf(kMinScale, row_absmax / FP8_E4M3_MAX);

    // Pass 2: each thread re-reads exactly the elements it staged and quantizes them.
    for (int32_t col = lane; col < N; col += TPB) {
        const fp32_t scaled = cvt_f16_f32(workspace1[col]) / scale;
        row_out[col] = fp8_e4m3_t(scaled);
    }

    // A single thread publishes the row's scale.
    if (lane == 0) {
        scales[row] = scale;
    }
}
64+
65+
// Per-token (per-row) quantization of an FP16 matrix to FP8 E4M3, vectorized
// variant for a row length known only at run time.
//
// Launch layout : one block per row (blockIdx.x selects the row); the column
//                 loops stride by TPB * 8, so blockDim.x is expected to equal TPB.
// Shared memory : dynamic, N * sizeof(fp16_t) bytes (one staged copy of the row).
// Preconditions : N must be a multiple of 8 -- every thread moves 8 elements per
//                 iteration and there is no tail handling.
//                 NOTE(review): vec_copy presumably needs the row base addresses
//                 to be aligned to the transfer size (16 B in, 8 B out) -- confirm.
// Outputs       : output[row, :] = input[row, :] / scale, converted to FP8;
//                 scales[row]    = scale, where
//                 scale = max(absmax(row) / 448, 1 / (448 * 512)).
template<int32_t TPB>
__global__ void device_per_token_quant_fp16_to_fp8_vpt(
    const fp16_t* __restrict__ input,   // Input tensor in FP16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales,        // Output scales for each token
    const int32_t N
) {
    constexpr int32_t VPT = 8;  // Values processed per thread per iteration

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f;  // Maximum value representable in FP8 E4M3 format

    // Both bid and N are 32-bit here; widen before multiplying so the row offset
    // does not overflow once the tensor holds 2^31 or more elements.
    const int64_t row_offset = static_cast<int64_t>(bid) * N;
    const fp16_t* _input = input + row_offset;    // Input pointer for the token
    fp8_e4m3_t* _output = output + row_offset;    // Output pointer for the token

    fp32_t* _scales = scales + bid;

    // Per-thread staging registers: 8 FP16 values in, 8 FP8 values out.
    fp8x4_e4m3_t local_f8[VPT / 4];
    fp16x2_t local_fp16[VPT / 2];

    extern __shared__ fp16x2_t workspace2[];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT FP16 elements of this row from global memory into local_fp16.
        vec_copy<sizeof(fp16_t) * VPT>(_input + i, local_fp16);

        // Stage them in shared memory; workspace2 is indexed in fp16x2 units, hence i / 2.
        vec_copy<sizeof(fp16_t) * VPT>(local_fp16, workspace2 + (i >> 1));

        // Track the largest magnitude among the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = fp16x2_to_fp32x2(local_fp16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block.
    // NOTE(review): the helper presumably synchronizes the block and returns the
    // same value to every thread -- confirm in reduce/sm70.cuh.
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f);
    const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX);

    // Second pass: each thread re-reads exactly the chunks it staged above.
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(fp16_t) * VPT>(workspace2 + (i >> 1), local_fp16);

        // Convert four values at a time into one packed FP8x4 word.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = fp16x2_to_fp32x2(local_fp16[2 * j + 0]);
            fp32x2_t y = fp16x2_to_fp32x2(local_fp16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x / scale,
                x.y / scale,
                y.x / scale,
                y.y / scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    // A single thread publishes the row's scale.
    if (tid == 0) {
        *_scales = scale;
    }
}
137+
138+
// Per-token (per-row) quantization of an FP16 matrix to FP8 E4M3, vectorized
// variant specialized on a compile-time row length N.
//
// Launch layout : one block per row (blockIdx.x selects the row); the column
//                 loops stride by TPB * 8, so blockDim.x is expected to equal TPB.
// Shared memory : static, N * sizeof(fp16_t) bytes; no dynamic shared memory.
// Preconditions : N must be a multiple of 8 (enforced by static_assert).
//                 NOTE(review): vec_copy presumably needs the row base addresses
//                 to be aligned to the transfer size (16 B in, 8 B out) -- confirm.
// Outputs       : output[row, :] = input[row, :] / scale, converted to FP8;
//                 scales[row]    = scale, where
//                 scale = max(absmax(row) / 448, 1 / (448 * 512)).
template<int32_t TPB, int32_t N>
__global__ void device_per_token_quant_fp16_to_fp8(
    const fp16_t* __restrict__ input,   // Input tensor in FP16 format
    fp8_e4m3_t* __restrict__ output,    // Output tensor in FP8 format
    fp32_t* __restrict__ scales         // Output scales for each token
) {
    constexpr int32_t VPT = 8;  // Values processed per thread per iteration

    static_assert(N % 2 == 0, "N must be even.");
    static_assert(N % VPT == 0, "N must be a multiple of VPT.");

    const int32_t bid = blockIdx.x;
    const int32_t tid = threadIdx.x;
    constexpr fp32_t FP8_E4M3_MAX = 448.0f;  // Maximum value representable in FP8 E4M3 format

    // Both bid and N are 32-bit here; widen before multiplying so the row offset
    // does not overflow once the tensor holds 2^31 or more elements.
    const int64_t row_offset = static_cast<int64_t>(bid) * N;
    const fp16_t* _input = input + row_offset;    // Input pointer for the token
    fp8_e4m3_t* _output = output + row_offset;    // Output pointer for the token

    fp32_t* _scales = scales + bid;

    // Per-thread staging registers: 8 FP16 values in, 8 FP8 values out.
    fp8x4_e4m3_t local_f8[VPT / 4];
    fp16x2_t local_fp16[VPT / 2];

    // Staged copy of the row. It is the target of 16-byte vec_copy transfers, so
    // request 16-byte alignment explicitly instead of relying on where the
    // compiler happens to place a 4-byte-aligned element type.
    __shared__ __align__(16) fp16x2_t workspace[N / 2];

    fp32_t local_max = -FLT_MAX;
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        // Load VPT FP16 elements of this row from global memory into local_fp16.
        vec_copy<sizeof(fp16_t) * VPT>(_input + i, local_fp16);

        // Stage them in shared memory; workspace is indexed in fp16x2 units, hence i / 2.
        vec_copy<sizeof(fp16_t) * VPT>(local_fp16, workspace + (i >> 1));

        // Track the largest magnitude among the VPT elements.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 2; j++) {
            fp32x2_t tmp = fp16x2_to_fp32x2(local_fp16[j]);
            fp32_t max = fmaxf(fabsf(tmp.x), fabsf(tmp.y));
            local_max = fmaxf(local_max, max);
        }
    }

    // Reduce the maximum value across the block.
    // NOTE(review): the helper presumably synchronizes the block and returns the
    // same value to every thread -- confirm in reduce/sm70.cuh.
    const fp32_t reduced_max = lightllm::reduce::sm70::sync_block_reduce_max_f32<TPB>(local_max);

    // Compute the scale factor with epsilon to avoid division by zero
    constexpr fp32_t epsilon = 1.0f / (FP8_E4M3_MAX * 512.0f);
    const fp32_t scale = fmaxf(epsilon, reduced_max / FP8_E4M3_MAX);

    // Second pass: each thread re-reads exactly the chunks it staged above.
    for (int32_t i = tid * VPT; i < N; i += TPB * VPT) {
        vec_copy<sizeof(fp16_t) * VPT>(workspace + (i >> 1), local_fp16);

        // Convert four values at a time into one packed FP8x4 word.
        #pragma unroll
        for (int32_t j = 0; j < VPT / 4; j++) {
            fp32x2_t x = fp16x2_to_fp32x2(local_fp16[2 * j + 0]);
            fp32x2_t y = fp16x2_to_fp32x2(local_fp16[2 * j + 1]);
            fp32x4_t ret = make_float4(
                x.x / scale,
                x.y / scale,
                y.x / scale,
                y.y / scale
            );
            local_f8[j] = fp8x4_e4m3_t(ret);
        }

        vec_copy<sizeof(fp8_e4m3_t) * VPT>(local_f8, _output + i);
    }

    // A single thread publishes the row's scale.
    if (tid == 0) {
        *_scales = scale;
    }
}
212+
213+
// Launches the compile-time-N kernel for one of the specialized row lengths.
template<int32_t TPB, int32_t N>
static void launch_fixed_n_fp16_fp8(
    const fp16_t* input,
    fp8_e4m3_t* output,
    fp32_t* scales,
    const int32_t blocks,
    const cudaStream_t stream
) {
    device_per_token_quant_fp16_to_fp8<TPB, N>
        <<<blocks, TPB, 0, stream>>>(input, output, scales);
}

// Makes sure `kernel` may be launched with `bytes` of dynamic shared memory,
// opting in to a larger per-block limit when the current one is too small.
template<typename KernelT>
static void reserve_dynamic_smem_fp16_fp8(KernelT kernel, const int64_t bytes) {
    // Requests this small fit the default per-block limit unless the kernel
    // also uses an unusually large amount of static shared memory, so the
    // attribute query is skipped on the common path.
    constexpr int64_t kAlwaysSafeBytes = 32 * 1024;
    if (bytes <= kAlwaysSafeBytes) {
        return;
    }
    TORCH_CHECK(bytes <= static_cast<int64_t>(INT32_MAX),
                "Row is too long to be staged in shared memory");

    cudaFuncAttributes attrs;
    cudaError_t err = cudaFuncGetAttributes(&attrs, kernel);
    TORCH_CHECK(err == cudaSuccess,
                "cudaFuncGetAttributes failed: ", cudaGetErrorString(err));

    if (bytes > static_cast<int64_t>(attrs.maxDynamicSharedSizeBytes)) {
        err = cudaFuncSetAttribute(
            kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize,
            static_cast<int>(bytes));
        TORCH_CHECK(err == cudaSuccess,
                    "Row of ", bytes, " bytes does not fit in shared memory on this device: ",
                    cudaGetErrorString(err));
    }
}

// Quantizes a 2-D FP16 CUDA tensor to FP8 E4M3 with one scale per row (token).
//
// Parameters:
//   output : destination; must be a contiguous CUDA tensor with 1-byte elements
//            and exactly input.numel() elements. Written in place.
//   input  : [M, N] FP16 CUDA tensor; a contiguous copy is made if needed.
//   scales : FP32 CUDA tensor with at least M elements; receives one scale per
//            row. A non-contiguous `scales` is supported via a temporary that
//            is copied back after the kernel.
//
// Dispatch: a handful of common row lengths use the kernel specialized on N;
// any other N uses the 8-wide vectorized kernel when N % 8 == 0 and the scalar
// kernel otherwise. The work is enqueued on the current CUDA stream; the
// function does not synchronize. Raises (via TORCH_CHECK) on invalid arguments
// or when the kernel launch is rejected.
void per_token_quant_fp16_fp8(
    Tensor& output,
    const Tensor& input,
    Tensor& scales
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(input.scalar_type() == c10::kHalf, "Input must be FP16 type");

    const int64_t M = input.size(0);
    const int64_t N = input.size(1);

    // The kernels write `output` with a dense row stride of N, so anything
    // other than a contiguous, correctly sized buffer would be corrupted.
    TORCH_CHECK(output.is_cuda(), "Output must be a CUDA tensor");
    TORCH_CHECK(output.is_contiguous(), "Output must be contiguous");
    TORCH_CHECK(output.element_size() == 1, "Output must have 1-byte (FP8) elements");
    TORCH_CHECK(output.numel() == M * N, "Output must have the same number of elements as input");

    TORCH_CHECK(scales.is_cuda(), "Scales must be a CUDA tensor");
    TORCH_CHECK(scales.scalar_type() == c10::kFloat, "Scales must be FP32 type");
    TORCH_CHECK(scales.numel() >= M, "Scales must hold one element per input row");

    // One block per row; the grid's x dimension is limited to 2^31 - 1 blocks.
    TORCH_CHECK(M <= static_cast<int64_t>(INT32_MAX), "Too many rows for a single launch");

    // A grid of zero blocks is an invalid launch configuration; nothing to do.
    if (M == 0) {
        return;
    }

    Tensor contiguous_input = input.is_contiguous() ? input : input.contiguous();
    Tensor contiguous_scales = scales.is_contiguous() ? scales : scales.contiguous();

    const fp16_t* in_ptr = PTR<fp16_t>(contiguous_input);
    fp8_e4m3_t* out_ptr = PTR<fp8_e4m3_t>(output);
    fp32_t* scale_ptr = PTR<fp32_t>(contiguous_scales);

    const int32_t blocks = static_cast<int32_t>(M);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    switch (N) {
        case 16:
            launch_fixed_n_fp16_fp8<128, 16>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 32:
            launch_fixed_n_fp16_fp8<128, 32>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 64:
            launch_fixed_n_fp16_fp8<128, 64>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 512:
            launch_fixed_n_fp16_fp8<128, 512>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 1024:
            launch_fixed_n_fp16_fp8<128, 1024>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 3200:
            launch_fixed_n_fp16_fp8<128, 3200>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 4096:
            launch_fixed_n_fp16_fp8<128, 4096>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        case 12800:
            launch_fixed_n_fp16_fp8<256, 12800>(in_ptr, out_ptr, scale_ptr, blocks, stream);
            break;
        default: {
            static constexpr int TPB = 128;
            // Both runtime-N kernels stage one full FP16 row in dynamic shared memory.
            const int64_t shared_mem_size = N * static_cast<int64_t>(sizeof(fp16_t));
            if (N % 8 == 0) {
                reserve_dynamic_smem_fp16_fp8(
                    device_per_token_quant_fp16_to_fp8_vpt<TPB>, shared_mem_size);
                device_per_token_quant_fp16_to_fp8_vpt<TPB>
                    <<<blocks, TPB, shared_mem_size, stream>>>(
                        in_ptr,
                        out_ptr,
                        scale_ptr,
                        static_cast<int32_t>(N)
                    );
            } else {
                reserve_dynamic_smem_fp16_fp8(
                    device_per_token_quant_fp16_to_fp8_general<TPB>, shared_mem_size);
                device_per_token_quant_fp16_to_fp8_general<TPB>
                    <<<blocks, TPB, shared_mem_size, stream>>>(
                        in_ptr,
                        out_ptr,
                        scale_ptr,
                        N
                    );
            }
        }
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "per_token_quant_fp16_fp8 kernel launch failed: ",
                cudaGetErrorString(launch_err));

    // If a temporary was used for `scales`, propagate the results to the
    // caller's tensor; the copy is ordered after the kernel on the same stream.
    if (!scales.is_contiguous()) {
        scales.copy_(contiguous_scales);
    }

    return;
}
320+
321+
} // namespace ops
322+
} // namespace lightllm

0 commit comments

Comments
 (0)