Skip to content

Commit af8cf6d

Browse files
committed
issue/1118: bit=4 error
1 parent d194d47 commit af8cf6d

File tree

2 files changed

+228
-43
lines changed

2 files changed

+228
-43
lines changed

src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,7 @@ infiniStatus_t Descriptor::calculate(void *workspace,
4343
int64_t quant_type,
4444
int64_t bit,
4545
void *stream) const {
46-
47-
int64_t M = static_cast<int64_t>(_info.M);
4846
int64_t K = static_cast<int64_t>(_info.K);
49-
int64_t N = static_cast<int64_t>(_info.N);
50-
int64_t scales_size_0 = static_cast<int64_t>(_info.scales_size_0);
51-
int64_t scales_size_1 = static_cast<int64_t>(_info.scales_size_1);
52-
int64_t lda = static_cast<int64_t>(_info.lda);
53-
int64_t ldb = static_cast<int64_t>(_info.ldb);
54-
int64_t result_ld = static_cast<int64_t>(_info.result_ld);
55-
bool transpose_mat_1 = _info.transpose_mat_1;
56-
bool transpose_mat_2 = _info.transpose_mat_2;
5747

5848
cudaDataType_t computeType_ = (cudaDataType_t)CUDA_R_32F;
5949
cudaDataType_t kernel_Atype_, kernel_Btype_, kernel_Ctype_, kernel_Stype_, kernel_Ztype_;
@@ -76,7 +66,6 @@ infiniStatus_t Descriptor::calculate(void *workspace,
7666

7767
if (4 == bit) {
7868
kernel_Atype_ = (cudaDataType_t)CUDA_R_4U;
79-
K = K * 2;
8069
}
8170
}
8271

@@ -127,11 +116,42 @@ infiniStatus_t Descriptor::calculate(void *workspace,
127116
float alpha = 1.0f;
128117
float beta = 0.0f;
129118

119+
bool transpose_mat_1 = _info.transpose_mat_1;
120+
bool transpose_mat_2 = _info.transpose_mat_2;
121+
int64_t M;
122+
int64_t N;
123+
int64_t lda;
124+
int64_t ldb;
125+
cublasOperation_t transa;
126+
cublasOperation_t transb;
127+
128+
if (transpose_mat_2) {
129+
M = static_cast<int64_t>(_info.N);
130+
N = static_cast<int64_t>(_info.M);
131+
lda = (bit == 4 ? static_cast<int64_t>(_info.ldb) * 2 : static_cast<int64_t>(_info.ldb));
132+
ldb = static_cast<int64_t>(_info.lda);
133+
std::swap(a, b);
134+
std::swap(kernel_Atype_, kernel_Btype_);
135+
transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
136+
transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
137+
} else {
138+
M = static_cast<int64_t>(_info.M);
139+
N = static_cast<int64_t>(_info.N);
140+
lda = static_cast<int64_t>(_info.lda);
141+
ldb = static_cast<int64_t>(_info.ldb);
142+
transa = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
143+
transb = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
144+
}
145+
int64_t scales_size_0 = static_cast<int64_t>(_info.scales_size_0);
146+
int64_t scales_size_1 = static_cast<int64_t>(_info.scales_size_1);
147+
148+
int64_t result_ld = static_cast<int64_t>(_info.result_ld);
149+
130150
dlblasExtQuantParametersV2_t extParameters;
131151

132152
if (quant_type == 0) {
133-
extParameters.a_group_size_m = M / scales_size_0;
134-
extParameters.a_group_size_k = K / scales_size_1;
153+
extParameters.a_group_size_m = M / scales_size_1;
154+
extParameters.a_group_size_k = K / scales_size_0;
135155
extParameters.a_zeropoints_type = kernel_Ztype_;
136156
extParameters.a_zeropoints = b_zeros;
137157
extParameters.a_scales_type = kernel_Stype_;
@@ -146,13 +166,13 @@ infiniStatus_t Descriptor::calculate(void *workspace,
146166
} else if (quant_type == 2 || quant_type == 3) {
147167
// calculate block_shape according weight/scales shape
148168
int block_shape = 128;
149-
while ((N + block_shape - 1) / block_shape < scales_size_0) {
169+
while ((M + block_shape - 1) / block_shape < scales_size_0) {
150170
block_shape /= 2;
151171
if (block_shape < 32) {
152172
fprintf(stderr,
153173
"INTERNAL ASSERT FAILED: block_shape >= 32\n"
154174
"Invalid fp blockwise linear arguments. Weight: [%d, %d]. Scales: [%d, %d].\n",
155-
(int)N, (int)K, (int)scales_size_0, (int)scales_size_1);
175+
(int)M, (int)K, (int)scales_size_0, (int)scales_size_1);
156176
abort();
157177
}
158178
}
@@ -168,9 +188,6 @@ infiniStatus_t Descriptor::calculate(void *workspace,
168188
extParameters.a_scales = b_scales;
169189
}
170190

171-
cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
172-
cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
173-
174191
if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) {
175192
CHECK_STATUS(_opaque->internal->useCublas(
176193
(cudaStream_t)stream,
@@ -179,16 +196,16 @@ infiniStatus_t Descriptor::calculate(void *workspace,
179196
dlblasGemmExV2(handle,
180197
transa,
181198
transb,
182-
N,
183199
M,
200+
N,
184201
K,
185202
&alpha,
186-
b,
187-
kernel_Btype_,
188-
ldb,
189203
a,
190204
kernel_Atype_,
191205
lda,
206+
b,
207+
kernel_Btype_,
208+
ldb,
192209
&beta,
193210
out,
194211
kernel_Ctype_,

test/infiniop/gptq_qyblas_gemm.py

Lines changed: 189 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,14 @@ def to_iter(x):
5151
)
5252

5353

54+
_TEST_CASES_W4 = [(32768, 3584, 4608, [128, 128], InfiniDtype.U8),]
55+
56+
5457
# Data types used for testing
5558
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
5659

60+
_TENSOR_DTYPES_W4 = [InfiniDtype.F16]
61+
5762

5863
DEBUG = False
5964
PROFILE = False
@@ -129,9 +134,6 @@ def test(
129134
quant_type = 3
130135
bit = 8
131136

132-
int8_info = torch.iinfo(torch.int8)
133-
int8_max, int8_min = int8_info.max, int8_info.min
134-
135137
block_n, block_k = block_size[0], block_size[1]
136138
n_tiles = (N + block_n - 1) // block_n
137139
k_tiles = (K + block_k - 1) // block_k
@@ -143,23 +145,28 @@ def test(
143145
device,
144146
)
145147
if weight_dtype == InfiniDtype.I8:
146-
B_orig = TestTensor(
147-
(N, K),
148-
None,
149-
weight_dtype,
150-
device,
151-
randint_low=int8_min,
152-
randint_high=int8_max,
153-
)
154-
B_torch = B_orig.torch_tensor().t()
155-
B = TestTensor(
156-
(K, N),
157-
B_torch.stride(),
158-
weight_dtype,
159-
device,
160-
mode="manual",
161-
set_tensor=B_torch,
162-
)
148+
_info = torch.iinfo(torch.int8)
149+
elif weight_dtype == InfiniDtype.U8:
150+
_info = torch.iinfo(torch.uint8)
151+
elif weight_dtype == InfiniDtype.F8:
152+
_info = torch.iinfo(float8_e4m3fn)
153+
B_orig = TestTensor(
154+
(N, K),
155+
None,
156+
weight_dtype,
157+
device,
158+
randint_low=_info.min,
159+
randint_high=_info.max,
160+
)
161+
B_torch = B_orig.torch_tensor().t()
162+
B = TestTensor(
163+
(K, N),
164+
B_torch.stride(),
165+
weight_dtype,
166+
device,
167+
mode="manual",
168+
set_tensor=B_torch,
169+
)
163170

164171
b_scales = TestTensor(
165172
(n_tiles, k_tiles),
@@ -254,6 +261,165 @@ def lib_gptq_qyblas_gemm():
254261
check_error(LIBINFINIOP.infiniopDestroyGptqQyblasGemmDescriptor(descriptor))
255262

256263

264+
def test_w4(
265+
handle,
266+
device,
267+
M,
268+
K,
269+
N,
270+
block_size,
271+
weight_dtype=InfiniDtype.I8,
272+
dtype=InfiniDtype.BF16,
273+
sync=None,
274+
):
275+
print(
276+
f"Testing w4 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, weight dtype:{InfiniDtypeNames[weight_dtype]}, dtype:{InfiniDtypeNames[dtype]}"
277+
)
278+
quant_type = 0
279+
bit = 4
280+
281+
block_n, block_k = block_size[0], block_size[1]
282+
n_tiles = (N + block_n - 1) // block_n
283+
k_tiles = (K + block_k - 1) // block_k
284+
285+
A = TestTensor(
286+
(M, K),
287+
None,
288+
dtype,
289+
device,
290+
)
291+
if weight_dtype == InfiniDtype.I8:
292+
_info = torch.iinfo(torch.int8)
293+
elif weight_dtype == InfiniDtype.U8:
294+
_info = torch.iinfo(torch.uint8)
295+
elif weight_dtype == InfiniDtype.F8:
296+
_info = torch.iinfo(float8_e4m3fn)
297+
# B_orig = TestTensor(
298+
# (N, K // 2),
299+
# None,
300+
# weight_dtype,
301+
# device,
302+
# randint_low=_info.min,
303+
# randint_high=_info.max,
304+
# )
305+
# B_torch = B_orig.torch_tensor().t()
306+
# B = TestTensor(
307+
# (K // 2, N),
308+
# B_torch.stride(),
309+
# weight_dtype,
310+
# device,
311+
# mode="manual",
312+
# set_tensor=B_torch,
313+
# )
314+
315+
B = TestTensor(
316+
(K // 2, N),
317+
None,
318+
weight_dtype,
319+
device,
320+
randint_low=_info.min,
321+
randint_high=_info.max,
322+
)
323+
324+
b_scales = TestTensor(
325+
(k_tiles, N),
326+
None,
327+
dtype,
328+
device,
329+
)
330+
331+
b_zeros = TestTensor(
332+
(k_tiles, N),
333+
None,
334+
dtype,
335+
device,
336+
mode="zeros",
337+
)
338+
339+
out = TestTensor(
340+
(M, N),
341+
None,
342+
dtype,
343+
device,
344+
mode="zeros",
345+
)
346+
347+
print("A", A.torch_tensor().shape, A.torch_tensor().dtype, A.torch_tensor().stride())
348+
print("B", B.torch_tensor().shape, B.torch_tensor().dtype, B.torch_tensor().stride())
349+
print("scales", b_scales.torch_tensor().shape, b_scales.torch_tensor().dtype, b_scales.torch_tensor().stride())
350+
print("zeros", b_zeros.torch_tensor().shape, b_zeros.torch_tensor().dtype, b_zeros.torch_tensor().stride())
351+
print("out", out.torch_tensor().shape, out.torch_tensor().dtype, out.torch_tensor().stride())
352+
353+
if sync is not None:
354+
sync()
355+
356+
descriptor = infiniopOperatorDescriptor_t()
357+
check_error(
358+
LIBINFINIOP.infiniopCreateGptqQyblasGemmDescriptor(
359+
handle,
360+
ctypes.byref(descriptor),
361+
out.descriptor,
362+
A.descriptor,
363+
B.descriptor,
364+
b_scales.descriptor,
365+
b_zeros.descriptor,
366+
)
367+
)
368+
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
369+
370+
for tensor in [out, A, B, b_scales, b_zeros]:
371+
tensor.destroy_desc()
372+
373+
workspace_size = c_uint64(0)
374+
check_error(
375+
LIBINFINIOP.infiniopGetGptqQyblasGemmWorkspaceSize(
376+
descriptor, ctypes.byref(workspace_size)
377+
)
378+
)
379+
workspace = TestWorkspace(workspace_size.value, A.device)
380+
381+
def lib_gptq_qyblas_gemm():
382+
check_error(
383+
LIBINFINIOP.infiniopGptqQyblasGemm(
384+
descriptor,
385+
workspace.data(),
386+
workspace_size.value,
387+
out.data(),
388+
A.data(),
389+
B.data(),
390+
b_scales.data(),
391+
b_zeros.data(),
392+
quant_type,
393+
bit,
394+
None,
395+
)
396+
)
397+
398+
lib_gptq_qyblas_gemm()
399+
400+
if sync is not None:
401+
sync()
402+
403+
out_dtype = to_torch_dtype(dtype)
404+
ans = native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype)
405+
406+
rel_diff = (torch.mean(
407+
torch.abs(out.actual_tensor().to(torch.float32) - ans.to(torch.float32))) /
408+
torch.mean(torch.abs(ans.to(torch.float32))))
409+
410+
assert rel_diff < 0.05
411+
412+
413+
# Profiling workflow
414+
if PROFILE:
415+
# fmt: off
416+
profile_operation("PyTorch", lambda: native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS)
417+
profile_operation(" lib", lambda: lib_gptq_qyblas_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
418+
# fmt: on
419+
420+
check_error(LIBINFINIOP.infiniopDestroyGptqQyblasGemmDescriptor(descriptor))
421+
422+
257423
if __name__ == "__main__":
258424
args = get_args()
259425

@@ -263,7 +429,9 @@ def lib_gptq_qyblas_gemm():
263429
NUM_PRERUN = args.num_prerun
264430
NUM_ITERATIONS = args.num_iterations
265431

432+
# for device in get_test_devices(args):
433+
# test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
266434
for device in get_test_devices(args):
267-
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
435+
test_operator(device, test_w4, _TEST_CASES_W4, _TENSOR_DTYPES_W4)
268436

269437
print("\033[92mTest passed!\033[0m")

0 commit comments

Comments
 (0)