1+ from collections .abc import Sequence
12import ctypes as ct
23from typing import Optional
34
@@ -119,6 +120,10 @@ def _(
119120) -> tuple [torch .Tensor , torch .Tensor ]:
120121 torch ._check_is_size (blocksize )
121122 torch ._check (quant_type == "nf4" , lambda : f"quant_type must be nf4 on CPU, got { quant_type } " )
123+ torch ._check (
124+ A .dtype in [torch .bfloat16 , torch .float16 , torch .float32 ],
125+ lambda : f"Blockwise 4bit quantization only supports 16/32-bit floats, but got { A .dtype } " ,
126+ )
122127
123128 n = A .numel ()
124129
@@ -140,3 +145,39 @@ def _(
140145 packed = packed .squeeze ().view (quant_storage ).unsqueeze (1 )
141146
142147 return packed , absmax .float ()
148+
149+
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize NF4-packed uint8 data back to a floating-point tensor on CPU.

    Each byte of ``A`` holds two 4-bit NF4 codes (high nibble first). Codes are
    looked up in the NF4 value table and rescaled by the per-block ``absmax``,
    then reshaped to ``shape`` and cast to ``dtype``.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    torch._check(
        A.dtype == torch.uint8,
        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
    )

    # Split each byte into its two 4-bit codes. int64 so they can serve as
    # indices into the lookup table.
    high = (A >> 4).to(torch.int64)
    low = (A & 0x0F).to(torch.int64)

    # Interleave high/low codes and group one row per quantization block.
    # NOTE(review): assumes A is a column vector of packed bytes, so cat along
    # dim=1 followed by a row-major reshape restores the original code order.
    codes = torch.cat((high, low), dim=1).reshape(-1, blocksize)

    # LUT lookup, then rescale every block by its absolute-max scale factor.
    values = _NF4_QUANT_TABLE[codes] * absmax.unsqueeze(-1)

    # Restore the caller-requested shape and output dtype.
    return values.reshape(-1, *shape[1:]).to(dtype)
0 commit comments