@@ -172,6 +172,32 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
172172
173173 opcheck (torch .ops .bitsandbytes .quantize_4bit .default , (A , blocksize , quant_type , storage_dtype ))
174174
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_type, blocksize):
    """Exercise 4-bit quantize -> dequantize on an input with a partial final block.

    The input has ``7 * (blocksize - 1)`` elements. Every parametrized
    blocksize is even, so that count is odd and can never be a multiple of
    the blocksize. The test checks that both ops run without raising, that
    the quantized tensors stay on the input's device, and that the
    dequantized tensor has the requested shape and dtype and holds only
    finite values. It does not compare the dequantized values against the
    input.
    """
    # Odd element count => numel % blocksize != 0 for all tested blocksizes.
    ragged_shape = (7, blocksize - 1)
    source = torch.randn(ragged_shape, dtype=dtype, device=device)

    # Quantizing must succeed (i.e. not raise) despite the partial block.
    quantized, scales = torch.ops.bitsandbytes.quantize_4bit(
        source,
        blocksize,
        quant_type,
        torch.uint8,
    )

    # Both results are expected on the same device as the input.
    for produced in (quantized, scales):
        assert produced.device == source.device

    # Dequantize with the original shape and dtype requested explicitly.
    restored = torch.ops.bitsandbytes.dequantize_4bit(
        quantized,
        scales,
        blocksize,
        quant_type,
        ragged_shape,
        dtype,
    )

    assert restored.shape == ragged_shape
    assert restored.dtype == dtype

    # No NaN/Inf may appear anywhere in the dequantized result.
    all_finite = torch.isfinite(restored).all()
    assert all_finite, "Dequantized output contains NaN or Inf"
200+
175201 @pytest .mark .parametrize ("device" , get_available_devices ())
176202 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
177203 @pytest .mark .parametrize ("storage_dtype" , [torch .uint8 , torch .bfloat16 ], ids = id_formatter ("storage_dtype" ))
0 commit comments