11import math
2+ import platform
23import random
34import time
45
56import einops
6- import numpy as np
7+ from packaging import version
78import pytest
89import torch
910
@@ -101,16 +102,16 @@ class Test8BitBlockwiseQuantizeFunctional:
101102 def test_dynamic_blockwise_quantization (self , device , dtype , nested , blocksize , signed ):
102103 iters = 100
103104
104- if device == "cpu " :
105+ if device != "cuda " :
105106 iters = 10
106107
107- # This test is slow on CPU , so avoid atypical use cases.
108+ # This test is slow in our non-CUDA implementations , so avoid atypical use cases.
108109 if nested :
109110 pytest .skip ("Not a typical use case." )
110111 if blocksize != 256 :
111- pytest .skip ("Only blocksize 256 is used in CPU/XPU" )
112+ pytest .skip ("Only blocksize 256 is used in CPU/MPS/ XPU" )
112113 if dtype != torch .float32 :
113- pytest .skip ("Only float32 is used in CPU/XPU" )
114+ pytest .skip ("Only float32 is used in CPU/MPS/ XPU" )
114115
115116 diffs = []
116117 reldiffs = []
@@ -239,7 +240,7 @@ def test_fp8_quant(self, device):
239240
240241 abserr = []
241242 relerr = []
242- for i in range (100 ):
243+ for i in range (10 ):
243244 A1 = torch .randn (1024 , 1024 , device = device )
244245 C , SC = F .quantize_blockwise (A1 , code = code )
245246 A2 = F .dequantize_blockwise (C , SC )
@@ -253,7 +254,7 @@ def test_fp8_quant(self, device):
253254
254255 abserr = []
255256 relerr = []
256- for i in range (100 ):
257+ for i in range (10 ):
257258 A1 = torch .rand (1024 , 1024 , device = device )
258259 C , SC = F .quantize_blockwise (A1 , code = code )
259260 A2 = F .dequantize_blockwise (C , SC )
@@ -267,7 +268,7 @@ def test_fp8_quant(self, device):
267268
268269 abserr = []
269270 relerr = []
270- for i in range (100 ):
271+ for i in range (10 ):
271272 A1 = torch .randn (1024 , 1024 , device = device )
272273 C , SC = F .quantize_blockwise (A1 )
273274 A2 = F .dequantize_blockwise (C , SC )
@@ -1406,28 +1407,29 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
14061407 @pytest .mark .parametrize ("device" , get_available_devices ())
14071408 @pytest .mark .parametrize ("storage_type" , ["nf4" , "fp4" ], ids = ["nf4" , "fp4" ])
14081409 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = describe_dtype )
1409- @pytest .mark .parametrize ("double_quant" , [False ], ids = ["DQ_True" ])
14101410 @pytest .mark .skipif (
14111411 HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a" ,
14121412 reason = "this test is not supported on ROCm with gfx90a architecture yet" ,
14131413 )
1414- def test_gemv_eye_4bit (self , device , storage_type , dtype , double_quant ):
1414+ def test_gemv_eye_4bit (self , device , storage_type , dtype ):
14151415 if device == "cpu" and dtype == torch .bfloat16 and torch .__version__ < (2 , 3 ):
14161416 pytest .skip ("eye doe not support bfloat16 on CPU in torch < 2.3" )
14171417
14181418 if device == "hpu" and not is_supported_on_hpu (storage_type , dtype ):
14191419 pytest .skip ("This configuration is not supported on HPU." )
14201420
1421- dims = 10
1422- torch .random .manual_seed (np .random .randint (0 , 412424242 ))
1421+ if device == "cpu" and platform .system () == "Windows" and version .parse (torch .__version__ ).release == (2 , 8 , 0 ):
1422+ pytest .skip ("Regression: CPU crash on Windows with torch 2.8.0" )
1423+
1424+ dims = 4
14231425 dims = get_test_dims (0 , 8192 , n = dims )
14241426 dims = [dim + (64 - (dim % 64 )) for dim in dims ]
14251427 # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
14261428 for dim in dims :
14271429 A = torch .normal (0 , 0.1 , size = (1 , 1 , dim ), dtype = dtype , device = device )
14281430 B = torch .eye (dim , dtype = dtype , device = device )
14291431
1430- qB , state = F .quantize_4bit (B , quant_type = storage_type , compress_statistics = double_quant )
1432+ qB , state = F .quantize_4bit (B , quant_type = storage_type , compress_statistics = False )
14311433 C3 = torch .matmul (A , B .t ())
14321434 C2 = bnb .matmul_4bit (A , qB .t (), state )
14331435 A .requires_grad = True
0 commit comments