Skip to content

Commit d194d47

Browse files
committed
issue/1118: qy int8 test now passes
1 parent 0aeee9e commit d194d47

File tree

3 files changed

+64
-168
lines changed

3 files changed

+64
-168
lines changed

src/infiniop/ops/gptq_qyblas_gemm/info.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class GptqQyblasGemmInfo {
1212
GptqQyblasGemmInfo() = default;
1313

1414
public:
15-
infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype;
15+
infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype;
1616
size_t M, K, N, scales_size_0, scales_size_1;
1717
ptrdiff_t lda, ldb, result_ld;
1818
bool transpose_mat_1, transpose_mat_2, transpose_result;
@@ -27,13 +27,13 @@ class GptqQyblasGemmInfo {
2727
auto dtype = a_desc->dtype();
2828

2929
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
30+
CHECK_DTYPE(dtype, out_desc->dtype());
3031

3132
const infiniDtype_t weight_dtype = b_desc->dtype();
3233
CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
3334

3435
const infiniDtype_t scales_dtype = b_scales_desc->dtype();
3536
const infiniDtype_t zeros_dtype = b_zeros_desc->dtype();
36-
const infiniDtype_t out_dtype = out_desc->dtype();
3737

3838
size_t M = out_desc->shape()[0];
3939
size_t N = out_desc->shape()[1];
@@ -80,7 +80,7 @@ class GptqQyblasGemmInfo {
8080
ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0];
8181

8282
return utils::Result<GptqQyblasGemmInfo>(GptqQyblasGemmInfo{
83-
dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype,
83+
dtype, weight_dtype, scales_dtype, zeros_dtype,
8484
M, K, N, scales_size_0, scales_size_1,
8585
lda, ldb, result_ld,
8686
transpose_mat_1, transpose_mat_2, transpose_result});

src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#if defined ENABLE_QY_API
12
#include "../../../devices/nvidia/nvidia_handle.cuh"
23
#include "dlblas_ext.h"
34
#include "gptq_qyblas_gemm_nvidia.cuh"
@@ -93,16 +94,7 @@ infiniStatus_t Descriptor::calculate(void *workspace,
9394
return INFINI_STATUS_BAD_TENSOR_DTYPE;
9495
}
9596

96-
switch (_info.out_dtype) {
97-
case INFINI_DTYPE_F16:
98-
kernel_Ctype_ = CUDA_R_16F;
99-
break;
100-
case INFINI_DTYPE_BF16:
101-
kernel_Ctype_ = CUDA_R_16BF;
102-
break;
103-
default:
104-
return INFINI_STATUS_BAD_TENSOR_DTYPE;
105-
}
97+
kernel_Ctype_ = kernel_Atype_;
10698

10799
switch (_info.scales_dtype) {
108100
case INFINI_DTYPE_F32:
@@ -178,12 +170,6 @@ infiniStatus_t Descriptor::calculate(void *workspace,
178170

179171
cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
180172
cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
181-
printf("a=%s, b=%s, c=%s\n",
182-
_info.transpose_mat_1 ? "true" : "false",
183-
_info.transpose_mat_2 ? "true" : "false",
184-
_info.transpose_result ? "true" : "false");
185-
printf("M-K-N:[%ld, %ld, %ld], lda-ldb-ldc:[%ld, %ld, %ld]\n", M, K, N, lda, ldb, result_ld);
186-
printf("quant type:%ld, bit:%ld, block_shape:%d\n", quant_type, bit, extParameters.a_group_size_m);
187173

188174
if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) {
189175
CHECK_STATUS(_opaque->internal->useCublas(
@@ -221,3 +207,4 @@ infiniStatus_t Descriptor::calculate(void *workspace,
221207
}
222208

223209
} // namespace op::gptq_qyblas_gemm::nvidia
210+
#endif

test/infiniop/gptq_qyblas_gemm.py

Lines changed: 58 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@
2929
# Test configurations
3030

3131
BLOCK_SIZE = [[128, 128]]
32-
M_list = [1, 7, 83, 512, 2048]
33-
N_list = [128, 512, 1024, 4096, 7748, 13824]
34-
K_list = [256, 4096, 5120, 3884, 13824]
32+
M_list = [1, 7]#, 83, 512, 2048]
33+
N_list = [128, 512]#, 1024, 4096, 7748, 13824]
34+
K_list = [256, 4096]#, 5120, 3884, 13824]
35+
_WEIGHT_DTYPES = [InfiniDtype.I8]
36+
3537
SEEDS = 0
3638

3739
def to_iter(x):
@@ -44,12 +46,13 @@ def to_iter(x):
4446
to_iter(K_list),
4547
to_iter(N_list),
4648
to_iter(BLOCK_SIZE),
49+
to_iter(_WEIGHT_DTYPES),
4750
)
4851
)
4952

5053

5154
# Data types used for testing
52-
_TENSOR_DTYPES = [InfiniDtype.F16]
55+
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
5356

5457

5558
DEBUG = False
@@ -108,164 +111,82 @@ def native_w8a16_block_int8_matmul(
108111
return C
109112

110113

111-
def native_w8a16_block_fp8_matmul(
112-
A,
113-
B,
114-
Bs,
115-
block_size,
116-
output_dtype: torch.float16,
117-
) -> torch.Tensor:
118-
return native_w8a16_block_int8_matmul(A, B, Bs, block_size, output_dtype)
119-
120-
121-
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
122-
torch.manual_seed(seed)
123-
factor_for_scale = 1e-2
124-
fp8_info = torch.finfo(torch.float8_e4m3fn)
125-
fp8_max, fp8_min = fp8_info.max, fp8_info.min
126-
127-
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
128-
#A_fp32 = A_fp32.fill_(1)
129-
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
130-
131-
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
132-
#B_fp32 = B_fp32.fill_(1)
133-
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
134-
135-
block_n, block_k = block_size[0], block_size[1]
136-
n_tiles = (N + block_n - 1) // block_n
137-
k_tiles = (K + block_k - 1) // block_k
138-
139-
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
140-
#As = As.fill_(1)
141-
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
142-
#Bs = Bs.fill_(1.5)
143-
#ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
144-
# out_dtype)
145-
ref_out = native_w8a16_block_fp8_matmul(A_fp32.to(torch.bfloat16), B_fp8, Bs, block_size, out_dtype)
146-
#out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
147-
148-
B_fp8_T = B_fp8.t()
149-
#print('B_fp8_T', B_fp8_T.size(), B_fp8_T)
150-
151-
Bs_T = Bs
152-
quant_type = 3
153-
bit = 8
154-
return ref_out, A_fp32.to(torch.bfloat16), B_fp8_T, Bs_T, Bs_T, quant_type, bit
155-
156-
157-
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
158-
torch.manual_seed(seed)
159-
factor_for_scale = 1e-2
160-
int8_info = torch.iinfo(torch.int8)
161-
int8_max, int8_min = int8_info.max, int8_info.min
162-
163-
A_fpb16 = torch.rand(M, K, dtype=torch.float32) / 10
164-
165-
166-
#A_fp32 = A_fp32.fill_(1)
167-
#A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
168-
169-
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
170-
#B_fp32 = B_fp32.fill_(1)
171-
B_int8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
172-
173-
block_n, block_k = block_size[0], block_size[1]
174-
n_tiles = (N + block_n - 1) // block_n
175-
k_tiles = (K + block_k - 1) // block_k
176-
177-
A_fpb16 =A_fpb16.to(torch.float16)
178-
179-
#As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
180-
#As = As.fill_(1)
181-
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
182-
#Bs = Bs.fill_(1.5)
183-
#ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
184-
185-
ref_out = native_w8a16_block_fp8_matmul(A_fpb16, B_int8, Bs, block_size, out_dtype)
186-
#a_q, a_s = native_per_token_group_quant_int8(A_fpb16, block_k)
187-
#ref_out = native_w8a8_block_int8_matmul(a_q, B_int8, a_s, Bs, block_size, output_dtype=A_fpb16.dtype)
188-
##out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
189-
#print('Bs', Bs.size(), Bs.dtype)
190-
quant_type = 3
191-
bit = 8
192-
return ref_out, A_fpb16, B_int8, Bs, Bs, quant_type, bit
193-
194-
195-
def test_int8(
114+
def test(
196115
handle,
197116
device,
198117
M,
199118
K,
200119
N,
201120
block_size,
121+
weight_dtype=InfiniDtype.I8,
202122
dtype=InfiniDtype.BF16,
203123
sync=None,
204124
):
205125

206126
print(
207-
f"Testing int8 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, dtype:{InfiniDtypeNames[dtype]}"
127+
f"Testing int8 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, weight dtype:{InfiniDtypeNames[weight_dtype]}, dtype:{InfiniDtypeNames[dtype]}"
208128
)
209-
out_dtype = to_torch_dtype(dtype)
210-
ans, a, b_orig, b_scales, b_zeros, quant_type, bit = test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, SEEDS)
211-
b = b_orig.t()
212-
129+
quant_type = 3
130+
bit = 8
131+
132+
int8_info = torch.iinfo(torch.int8)
133+
int8_max, int8_min = int8_info.max, int8_info.min
134+
135+
block_n, block_k = block_size[0], block_size[1]
136+
n_tiles = (N + block_n - 1) // block_n
137+
k_tiles = (K + block_k - 1) // block_k
138+
213139
A = TestTensor(
214-
a.shape,
215-
a.stride(),
216-
InfiniDtype.F16,
217-
device,
218-
mode="manual",
219-
set_tensor=a,
220-
)
221-
B_orig = TestTensor(
222-
b_orig.shape,
223-
b_orig.stride(),
224-
InfiniDtype.I8,
225-
device,
226-
mode="manual",
227-
set_tensor=b_orig,
228-
)
229-
B = TestTensor(
230-
b.shape,
231-
b.stride(),
232-
InfiniDtype.I8,
140+
(M, K),
141+
None,
142+
dtype,
233143
device,
234-
mode="manual",
235-
set_tensor=b,
236144
)
145+
if weight_dtype == InfiniDtype.I8:
146+
B_orig = TestTensor(
147+
(N, K),
148+
None,
149+
weight_dtype,
150+
device,
151+
randint_low=int8_min,
152+
randint_high=int8_max,
153+
)
154+
B_torch = B_orig.torch_tensor().t()
155+
B = TestTensor(
156+
(K, N),
157+
B_torch.stride(),
158+
weight_dtype,
159+
device,
160+
mode="manual",
161+
set_tensor=B_torch,
162+
)
163+
237164
b_scales = TestTensor(
238-
b_scales.shape,
239-
b_scales.stride(),
165+
(n_tiles, k_tiles),
166+
None,
240167
InfiniDtype.F32,
241168
device,
242-
mode="manual",
243-
set_tensor=b_scales,
244169
)
170+
245171
b_zeros = TestTensor(
246-
b_zeros.shape,
247-
b_zeros.stride(),
172+
(n_tiles, k_tiles),
173+
None,
248174
InfiniDtype.F32,
249175
device,
250-
mode="manual",
251-
set_tensor=b_zeros,
176+
mode="zeros",
252177
)
178+
253179
out = TestTensor(
254-
ans.shape,
180+
(M, N),
255181
None,
256182
dtype,
257183
device,
184+
mode="zeros",
258185
)
259-
260-
print("a: ", A.torch_tensor().shape, A.torch_tensor().stride(), A.torch_tensor().dtype)
261-
print("b: ", B.torch_tensor().shape, B.torch_tensor().stride(), B.torch_tensor().dtype)
262-
print("scales: ", b_scales.torch_tensor().shape, b_scales.torch_tensor().dtype)
263-
print("zeros: ", b_zeros.torch_tensor().shape, b_zeros.torch_tensor().dtype)
264-
print("out: ", out.torch_tensor().shape, out.torch_tensor().dtype)
186+
265187
if sync is not None:
266188
sync()
267189

268-
269190
descriptor = infiniopOperatorDescriptor_t()
270191
check_error(
271192
LIBINFINIOP.infiniopCreateGptqQyblasGemmDescriptor(
@@ -278,7 +199,6 @@ def test_int8(
278199
b_zeros.descriptor,
279200
)
280201
)
281-
282202
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
283203

284204
for tensor in [out, A, B, b_scales, b_zeros]:
@@ -314,31 +234,20 @@ def lib_gptq_qyblas_gemm():
314234
if sync is not None:
315235
sync()
316236

317-
tmpa = out.torch_tensor().to(torch.float32).detach().to('cpu').numpy().flatten()
318-
tmpb = ans.to(torch.float32).to('cpu').detach().numpy().flatten()
237+
out_dtype = to_torch_dtype(dtype)
238+
ans = native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype)
319239

320-
atol = max(abs(tmpa - tmpb))
321-
322-
rtol = atol / (max(abs(tmpb)) + 1e-8)
323-
324-
325-
print("absolute error:%.4e"%(atol))
326-
print("relative error:%.4e"%(rtol))
327-
print(out.torch_tensor().device, ans.device)
328-
# print(out.torch_tensor())
329-
# print(ans)
330-
ans = ans.to(out.torch_tensor().device)
331240
rel_diff = (torch.mean(
332-
torch.abs(out.torch_tensor().to(torch.float32) - ans.to(torch.float32))) /
241+
torch.abs(out.actual_tensor().to(torch.float32) - ans.to(torch.float32))) /
333242
torch.mean(torch.abs(ans.to(torch.float32))))
334-
print(rel_diff)
243+
335244
assert rel_diff < 0.05
336245

337246

338247
# Profiling workflow
339248
if PROFILE:
340249
# fmt: off
341-
profile_operation("PyTorch", lambda: native_w8a16_block_fp8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS)
250+
profile_operation("PyTorch", lambda: native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS)
342251
profile_operation(" lib", lambda: lib_gptq_qyblas_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
343252
# fmt: on
344253

@@ -355,6 +264,6 @@ def lib_gptq_qyblas_gemm():
355264
NUM_ITERATIONS = args.num_iterations
356265

357266
for device in get_test_devices(args):
358-
test_operator(device, test_int8, _TEST_CASES, _TENSOR_DTYPES)
267+
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
359268

360269
print("\033[92mTest passed!\033[0m")

0 commit comments

Comments
 (0)