|
1 | 1 | import math |
2 | 2 | import platform |
3 | | -import random |
4 | 3 | import time |
5 | 4 |
|
6 | | -import einops |
7 | 5 | from packaging import version |
8 | 6 | import pytest |
9 | 7 | import torch |
10 | 8 |
|
11 | 9 | import bitsandbytes as bnb |
12 | 10 | from bitsandbytes import functional as F |
13 | 11 | from tests.helpers import ( |
14 | | - BOOLEAN_TUPLES, |
15 | 12 | TRUE_FALSE, |
16 | 13 | describe_dtype, |
17 | 14 | get_available_devices, |
@@ -339,280 +336,6 @@ def test_stable_embedding(): |
339 | 336 | layer.reset_parameters() |
340 | 337 |
|
341 | 338 |
|
342 | | -def quant(x): |
343 | | - max1 = torch.abs(x).max() |
344 | | - x = torch.round(x / max1 * 127) |
345 | | - return max1, x.to(torch.int8) |
346 | | - |
347 | | - |
348 | | -def dequant(c, maxC): |
349 | | - return c.float() * (maxC / 127) |
350 | | - |
351 | | - |
352 | | -def mm_dequant(maxA, maxB, C): |
353 | | - return C.float() * (maxA / 127) * (maxB / 127) |
354 | | - |
355 | | - |
356 | | -def quant_multi(x, dim): |
357 | | - max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) |
358 | | - max1[max1 == 0] = 1.0 |
359 | | - x = torch.round(x / max1 * 127) |
360 | | - return max1, x.to(torch.int8) |
361 | | - |
362 | | - |
363 | | -def quant_multi_chunk(x, dim, chunk_size=32): |
364 | | - if dim == 1: |
365 | | - x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size) |
366 | | - max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True) |
367 | | - max1 = torch.tile(max1, (1, 1, x.shape[1])) |
368 | | - max1 = max1.view(x.shape) |
369 | | - elif dim == 0: |
370 | | - x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size) |
371 | | - max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) |
372 | | - max1 = torch.tile(max1, (x.shape[0], 1, 1)) |
373 | | - max1 = max1.view(x.shape) |
374 | | - max1[max1 == 0] = 1.0 |
375 | | - x = torch.round(x / max1 * 127) |
376 | | - return max1, x.to(torch.int8) |
377 | | - |
378 | | - |
379 | | -def mean(xx): |
380 | | - return sum(xx) / float(len(xx)) |
381 | | - |
382 | | - |
383 | | -methods = { |
384 | | - "linear": ( |
385 | | - lambda x, dim: quant(x), |
386 | | - lambda x, dim: quant(x), |
387 | | - dequant, |
388 | | - dequant, |
389 | | - mm_dequant, |
390 | | - ), |
391 | | - "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant), |
392 | | -} |
393 | | - |
394 | | - |
395 | | -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") |
396 | | -class TestIGEMMFunctional: |
397 | | - @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) |
398 | | - @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) |
399 | | - @pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) |
400 | | - @pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) |
401 | | - def test_approx_igemm(self, dim1, dim2, quant_methods, batched): |
402 | | - dim1 = dim1 - (dim1 % 32) |
403 | | - dim2 = dim2 - (dim2 % 32) |
404 | | - errors = [] |
405 | | - relerrors = [] |
406 | | - # print("") |
407 | | - for i in range(5): |
408 | | - if batched: |
409 | | - A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") |
410 | | - B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") |
411 | | - maxA, Ac = quant_methods[0](A, 2) |
412 | | - maxB, Bc = quant_methods[1](B, 1) |
413 | | - else: |
414 | | - A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") |
415 | | - B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") |
416 | | - maxA, Ac = quant_methods[0](A, 1) |
417 | | - maxB, Bc = quant_methods[1](B, 0) |
418 | | - torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) |
419 | | - if batched: |
420 | | - out2 = torch.bmm(A, B) |
421 | | - C = torch.bmm(Ac.float(), Bc.float()) |
422 | | - else: |
423 | | - out2 = torch.mm(A, B) |
424 | | - C = F.igemm(Ac, Bc) |
425 | | - out = quant_methods[4](maxA, maxB, C) |
426 | | - std = out2.std() |
427 | | - out /= std |
428 | | - out2 /= std |
429 | | - err = torch.abs(out - out2) |
430 | | - relerr = err / torch.abs(out2) |
431 | | - errors.append(err.mean().item()) |
432 | | - relerrors.append(relerr.mean().item()) |
433 | | - # print(mean(errors)) |
434 | | - # print(mean(relerrors)) |
435 | | - |
436 | | - @pytest.mark.parametrize("hidden_dim", [32, 256], ids=id_formatter("hidden_dim")) |
437 | | - @pytest.mark.parametrize("batch_dim", [16, 256], ids=id_formatter("batch_dim")) |
438 | | - @pytest.mark.parametrize("seq_dim", [16, 256], ids=id_formatter("seq_dim")) |
439 | | - @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) |
440 | | - def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim): |
441 | | - if ( |
442 | | - torch.version.cuda == "13.0" |
443 | | - and torch.__version__ >= (2, 10) |
444 | | - and not any(transpose) |
445 | | - and batch_dim == 256 |
446 | | - and seq_dim == 256 |
447 | | - ): |
448 | | - pytest.xfail("Failure due to regression in cuBLAS for CUDA Toolkit 13.0.2.") |
449 | | - |
450 | | - hidden_dim = hidden_dim - (hidden_dim % 32) |
451 | | - batch_dim = batch_dim - (batch_dim % 16) |
452 | | - seq_dim = seq_dim - (seq_dim % 16) |
453 | | - for i in range(k): |
454 | | - shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) |
455 | | - shapeB = ( |
456 | | - (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) |
457 | | - ) |
458 | | - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) |
459 | | - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) |
460 | | - if not transpose[0] and not transpose[1]: |
461 | | - out2 = torch.matmul(A.float(), B.float()) |
462 | | - out = F.igemm(A, B) |
463 | | - elif not transpose[0] and transpose[1]: |
464 | | - out2 = torch.matmul(A.float(), B.t().float()) |
465 | | - out = F.igemm(A, B.t()) |
466 | | - elif transpose[0] and not transpose[1]: |
467 | | - out2 = torch.matmul(A.t().float(), B.float()) |
468 | | - out = F.igemm(A.t(), B) |
469 | | - elif transpose[0] and transpose[1]: |
470 | | - out2 = torch.matmul(A.t().float(), B.t().float()) |
471 | | - out = F.igemm(A.t(), B.t()) |
472 | | - |
473 | | - torch.testing.assert_close(out.float(), out2) |
474 | | - |
475 | | - for i in range(k): |
476 | | - shapeA = (batch_dim, seq_dim, hidden_dim) |
477 | | - shapeB = ( |
478 | | - (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) |
479 | | - ) |
480 | | - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) |
481 | | - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) |
482 | | - if not transpose[0] and not transpose[1]: |
483 | | - out2 = torch.matmul(A.float(), B.float()) |
484 | | - out = F.igemm(A, B) |
485 | | - elif not transpose[0] and transpose[1]: |
486 | | - out2 = torch.matmul(A.float(), B.t().float()) |
487 | | - out = F.igemm(A, B.t()) |
488 | | - |
489 | | - torch.testing.assert_close(out.float(), out2) |
490 | | - |
491 | | - @pytest.mark.parametrize("seq_dim", [32, 256, 512], ids=id_formatter("seq_dim")) |
492 | | - @pytest.mark.parametrize("hidden_dim", [64, 1024, 4096], ids=id_formatter("hidden_dim")) |
493 | | - @pytest.mark.parametrize("batch_dim", [2, 8, 16], ids=id_formatter("batch_dim")) |
494 | | - def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): |
495 | | - seq_dim = seq_dim - (seq_dim % 32) |
496 | | - hidden_dim = hidden_dim - (hidden_dim % 32) |
497 | | - batch_dim = batch_dim - (batch_dim % 2) |
498 | | - for i in range(25): |
499 | | - A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) |
500 | | - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) |
501 | | - out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) |
502 | | - iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) |
503 | | - out = F.igemm(A, B, out=iout) |
504 | | - |
505 | | - torch.testing.assert_close(out.float(), out2) |
506 | | - |
507 | | - @pytest.mark.parametrize("seq_dim", [32, 512], ids=id_formatter("seq_dim")) |
508 | | - @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim")) |
509 | | - @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim")) |
510 | | - @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) |
511 | | - def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): |
512 | | - def min_max(x): |
513 | | - maxA = torch.amax(x, dim=2, keepdim=True) |
514 | | - minA = torch.amin(x, dim=2, keepdim=True) |
515 | | - scale = (maxA - minA) / 2.0 |
516 | | - return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale |
517 | | - |
518 | | - seq_dim = seq_dim - (seq_dim % 16) |
519 | | - hidden_dim = hidden_dim - (hidden_dim % 16) |
520 | | - batch_dim = batch_dim - (batch_dim % 2) |
521 | | - errs = [] |
522 | | - relerrs = [] |
523 | | - errs2 = [] |
524 | | - relerrs2 = [] |
525 | | - for i in range(k): |
526 | | - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") |
527 | | - if transpose: |
528 | | - B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") |
529 | | - else: |
530 | | - B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") |
531 | | - Ac, minA, scale = min_max(A) |
532 | | - if transpose: |
533 | | - maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) |
534 | | - out = F.igemm(Ac, Bc.t()) |
535 | | - out2 = torch.matmul(A, B.t()) |
536 | | - offset = B.t().sum(0) * (minA + scale) |
537 | | - out = out.float() |
538 | | - out = (out * maxB.t() * scale / (127 * 127)) + offset |
539 | | - |
540 | | - maxA, Ac = quant_multi(A, dim=2) |
541 | | - out3 = F.igemm(Ac, Bc.t()) |
542 | | - out3 = mm_dequant(maxA, maxB.t(), out3) |
543 | | - else: |
544 | | - maxB, Bc = quant_multi(B, dim=0) |
545 | | - offset = B.sum(0) * (minA + scale) |
546 | | - out = F.igemm(Ac, Bc) |
547 | | - out2 = torch.matmul(A, B) |
548 | | - out = out.float() |
549 | | - out = (out * maxB * scale / (127 * 127)) + offset |
550 | | - |
551 | | - maxA, Ac = quant_multi(A, dim=2) |
552 | | - out3 = F.igemm(Ac, Bc) |
553 | | - out3 = mm_dequant(maxA, maxB, out3) |
554 | | - |
555 | | - std = out2.std() |
556 | | - out2 /= std |
557 | | - out /= std |
558 | | - out3 /= std |
559 | | - |
560 | | - err = torch.abs(out - out2) |
561 | | - relerr = err / (torch.abs(out2) + 1e-7) |
562 | | - |
563 | | - err2 = torch.abs(out3 - out2) |
564 | | - relerr2 = err2 / (torch.abs(out2) + 1e-7) |
565 | | - |
566 | | - errs.append(err.mean().item()) |
567 | | - relerrs.append(relerr.mean().item()) |
568 | | - errs2.append(err2.mean().item()) |
569 | | - relerrs2.append(relerr2.mean().item()) |
570 | | - # print(mean(errs)) |
571 | | - # print(mean(relerrs)) |
572 | | - # print(mean(errs2)) |
573 | | - # print(mean(relerrs2)) |
574 | | - assert mean(errs) < 0.015 |
575 | | - |
576 | | - # There's a higher relerr on L40S with torch 2.4+cu118. |
577 | | - is_sm89 = torch.cuda.get_device_capability() == (8, 9) |
578 | | - if torch.version.cuda == "11.8" and is_sm89 and torch.__version__ < (2, 5): |
579 | | - assert mean(relerrs) < 0.41 |
580 | | - else: |
581 | | - assert mean(relerrs) < 0.3 |
582 | | - |
583 | | - @pytest.mark.parametrize("dim1", [1, 64], ids=id_formatter("dim1")) |
584 | | - @pytest.mark.parametrize("dim2", [32, 128], ids=id_formatter("dim2")) |
585 | | - @pytest.mark.parametrize("dim3", [32, 256], ids=id_formatter("dim3")) |
586 | | - @pytest.mark.parametrize("dim4", [32, 256], ids=id_formatter("dim4")) |
587 | | - @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) |
588 | | - def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): |
589 | | - if torch.version.cuda == "13.0" and torch.__version__ >= (2, 10) and dim1 == 64: |
590 | | - pytest.xfail("Failure due to regression in cuBLAS for CUDA Toolkit 13.0.2.") |
591 | | - |
592 | | - dim2 = dim2 - (dim2 % 16) |
593 | | - dim3 = dim3 - (dim3 % 16) |
594 | | - dim4 = dim4 - (dim4 % 16) |
595 | | - for i in range(k): |
596 | | - shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) |
597 | | - shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) |
598 | | - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) |
599 | | - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) |
600 | | - |
601 | | - if not transpose[0] and not transpose[1]: |
602 | | - out2 = torch.bmm(A.float(), B.float()) |
603 | | - out = F.igemm(A, B) |
604 | | - elif not transpose[0] and transpose[1]: |
605 | | - out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) |
606 | | - out = F.igemm(A, B.permute([0, 2, 1])) |
607 | | - elif transpose[0] and not transpose[1]: |
608 | | - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) |
609 | | - out = F.igemm(A.permute([0, 2, 1]), B) |
610 | | - elif transpose[0] and transpose[1]: |
611 | | - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) |
612 | | - out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) |
613 | | - torch.testing.assert_close(out.float(), out2.float()) |
614 | | - |
615 | | - |
616 | 339 | class TestLLMInt8Functional: |
617 | 340 | @staticmethod |
618 | 341 | def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half): |
@@ -723,6 +446,8 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): |
723 | 446 | n = C5.numel() |
724 | 447 | assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) |
725 | 448 |
|
| 449 | + # Keep CUDA-only coverage for int8_double_quant during deprecation cycle. |
| 450 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") |
726 | 451 | @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) |
727 | 452 | @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) |
728 | 453 | def test_int8_double_quant(self, dim1, dim2): |
|
0 commit comments