@@ -246,3 +246,108 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
246246 assert out .isreal ().all ()
247247
248248 opcheck (torch .ops .bitsandbytes .gemv_4bit .default , (A , B_q , B .shape , absmax , code , blocksize ))
249+
250+
class TestNonContiguousInputs:
    """Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly.

    Each test builds a non-contiguous input, runs the op on both it and its
    contiguous copy, and asserts the results match. All tests are CUDA-only,
    since that is the backend the non-contiguity fix targets.
    """

    @staticmethod
    def _noncontig_4d(device, dtype):
        """Return (non-contiguous tensor, its contiguous copy).

        Non-contiguity is produced by a strided slice along dim 1, reproducing
        the setup from issue #1342.
        """
        full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
        noncontig = full[:, ::2, :, :]
        assert not noncontig.is_contiguous()
        return noncontig, noncontig.contiguous()

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_blockwise_non_contiguous(self, device, dtype, blocksize):
        # NOTE: guard made consistent with the 4-bit tests below. The original
        # `device == "cpu"` check let non-CUDA accelerators (e.g. XPU) through,
        # contradicting the skip message.
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        code = bitsandbytes.functional.create_dynamic_map().to(device)

        A_noncontig, A_contig = self._noncontig_4d(device, dtype)

        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_blockwise(A_noncontig, code, blocksize)
        out_c, absmax_c = torch.ops.bitsandbytes.quantize_blockwise(A_contig, code, blocksize)

        torch.testing.assert_close(absmax_nc, absmax_c)
        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_dequantize_blockwise_non_contiguous(self, device, dtype, blocksize):
        # Guard made consistent with the 4-bit tests below (see note there).
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        # Quantize a contiguous tensor first, then build a non-contiguous uint8
        # view of the quantized data to feed to dequantize.
        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

        # `quantized.t().t()` (the original attempt) always yields a contiguous
        # tensor — double transpose restores the original strides — so that
        # fallback branch ran unconditionally. Build the non-contiguous view
        # directly: rows of the padded buffer have stride 1025, so the
        # [:, :1024] slice cannot be contiguous.
        q_padded = torch.zeros(1024, 1025, dtype=torch.uint8, device=device)
        q_padded[:, :1024] = quantized
        q_noncontig = q_padded[:, :1024]
        assert not q_noncontig.is_contiguous()

        q_contig = q_noncontig.contiguous()

        out_nc = torch.ops.bitsandbytes.dequantize_blockwise(q_noncontig, absmax, code, blocksize, dtype)
        out_c = torch.ops.bitsandbytes.dequantize_blockwise(q_contig, absmax, code, blocksize, dtype)

        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_4bit_non_contiguous(self, device, dtype, quant_type, blocksize):
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        # Reproduce issue #1342: non-contiguous tensor from slicing.
        A_noncontig, A_contig = self._noncontig_4d(device, dtype)
        storage_dtype = torch.uint8

        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
        out_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

        torch.testing.assert_close(absmax_nc, absmax_c)
        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_4bit_roundtrip_non_contiguous(self, device, dtype, quant_type, blocksize):
        """End-to-end test: quantize non-contiguous, dequantize, compare with contiguous path."""
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        A_noncontig, A_contig = self._noncontig_4d(device, dtype)
        storage_dtype = torch.uint8

        # Quantize both paths.
        q_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
        q_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

        # Dequantize both back to the original shape and compare.
        shape = A_contig.shape
        deq_nc = torch.ops.bitsandbytes.dequantize_4bit(q_nc, absmax_nc, blocksize, quant_type, shape, dtype)
        deq_c = torch.ops.bitsandbytes.dequantize_4bit(q_c, absmax_c, blocksize, quant_type, shape, dtype)

        torch.testing.assert_close(deq_nc, deq_c)
0 commit comments