Skip to content

Commit 1e2dc09

Browse files
TimDettmers and claude
committed
style: Fix pre-commit lint issues (ruff, clang-format, typos)
- Remove unused variables M, K in LinearNVFP4.forward() - Prefix unused unpacked variables with _ in GEMM tests - Add UE4M3, IST to typos ignore config (valid technical terms) - Apply clang-format to all CUDA source files - Apply ruff format to Python files Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 27cae0b commit 1e2dc09

File tree

12 files changed

+290
-291
lines changed

12 files changed

+290
-291
lines changed

_typos.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ extend-exclude = [
1111
[default]
1212
extend-ignore-re = [
1313
"@Ther-nul", # valid Github user
14+
"UE4M3", # unsigned E4M3 floating point format (NVFP4 block scale type)
15+
"ue4m3", # unsigned E4M3 lowercase
16+
"IST[ -]", # IST Austria / IST-DASLab (Institute of Science and Technology)
17+
"ist-", # ist-daslab lowercase in anchor links
1418
]
1519
extend-ignore-identifiers-re = [
1620
".*arange.*",
@@ -24,3 +28,4 @@ extend-ignore-identifiers-re = [
2428
"subtile" = "subtile"
2529
"subtiles" = "subtiles"
2630
"transation" = "transation" # TODO: is this transition, transaction, translation..?
31+
"ue" = "ue" # UE4M3: unsigned E4M3 floating point format (NVFP4 block scale type)

bitsandbytes/backends/cuda/ops.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -793,11 +793,17 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
793793

794794
with _cuda_device_of(A):
795795
if A.dtype == torch.float16:
796-
lib.cquantize_nvfp4_fp16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
796+
lib.cquantize_nvfp4_fp16(
797+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
798+
)
797799
elif A.dtype == torch.bfloat16:
798-
lib.cquantize_nvfp4_bf16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
800+
lib.cquantize_nvfp4_bf16(
801+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
802+
)
799803
else:
800-
lib.cquantize_nvfp4_fp32(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
804+
lib.cquantize_nvfp4_fp32(
805+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
806+
)
801807

802808
ts_out = torch.tensor([tensor_scale], dtype=torch.float32, device=A.device)
803809
return packed, block_scales, ts_out
@@ -814,11 +820,32 @@ def _(
814820

815821
with _cuda_device_of(packed):
816822
if dtype == torch.float16:
817-
lib.cdequantize_nvfp4_fp16(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))
823+
lib.cdequantize_nvfp4_fp16(
824+
get_ptr(packed),
825+
get_ptr(block_scales),
826+
ct.c_float(tensor_scale),
827+
get_ptr(output),
828+
ct.c_int(numel),
829+
ct.c_void_p(0),
830+
)
818831
elif dtype == torch.bfloat16:
819-
lib.cdequantize_nvfp4_bf16(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))
832+
lib.cdequantize_nvfp4_bf16(
833+
get_ptr(packed),
834+
get_ptr(block_scales),
835+
ct.c_float(tensor_scale),
836+
get_ptr(output),
837+
ct.c_int(numel),
838+
ct.c_void_p(0),
839+
)
820840
else:
821-
lib.cdequantize_nvfp4_fp32(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))
841+
lib.cdequantize_nvfp4_fp32(
842+
get_ptr(packed),
843+
get_ptr(block_scales),
844+
ct.c_float(tensor_scale),
845+
get_ptr(output),
846+
ct.c_int(numel),
847+
ct.c_void_p(0),
848+
)
822849

823850
return output
824851

@@ -860,11 +887,17 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
860887

861888
with _cuda_device_of(A):
862889
if A.dtype == torch.float16:
863-
lib.cfused_hadamard_quantize_nvfp4_fp16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
890+
lib.cfused_hadamard_quantize_nvfp4_fp16(
891+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
892+
)
864893
elif A.dtype == torch.bfloat16:
865-
lib.cfused_hadamard_quantize_nvfp4_bf16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
894+
lib.cfused_hadamard_quantize_nvfp4_bf16(
895+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
896+
)
866897
else:
867-
lib.cfused_hadamard_quantize_nvfp4_fp32(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
898+
lib.cfused_hadamard_quantize_nvfp4_fp32(
899+
get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n)
900+
)
868901

869902
ts_out = torch.tensor([tensor_scale], dtype=torch.float32, device=A.device)
870903
return packed, block_scales, ts_out
@@ -887,10 +920,14 @@ def _(
887920

888921
with _cuda_device_of(A_packed):
889922
lib.cgemm_nvfp4(
890-
get_ptr(A_packed), get_ptr(B_packed),
891-
get_ptr(A_scales), get_ptr(B_scales),
923+
get_ptr(A_packed),
924+
get_ptr(B_packed),
925+
get_ptr(A_scales),
926+
get_ptr(B_scales),
892927
get_ptr(D_out),
893-
ct.c_int(M), ct.c_int(N), ct.c_int(K),
928+
ct.c_int(M),
929+
ct.c_int(N),
930+
ct.c_int(K),
894931
)
895932

896933
# Apply tensor scales (the GEMM kernel operates on raw quantized values)

bitsandbytes/functional.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,8 +1232,15 @@ def gemm_nvfp4(
12321232
N = B_state.shape[0]
12331233

12341234
return torch.ops.bitsandbytes.gemm_nvfp4(
1235-
A_data, B_data, A_state.block_scales, B_state.block_scales,
1236-
A_state.tensor_scale, B_state.tensor_scale, M, N, K,
1235+
A_data,
1236+
B_data,
1237+
A_state.block_scales,
1238+
B_state.block_scales,
1239+
A_state.tensor_scale,
1240+
B_state.tensor_scale,
1241+
M,
1242+
N,
1243+
K,
12371244
)
12381245

12391246

bitsandbytes/nn/modules.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
725725

726726
# Reshape input: (*, K) -> (M, K)
727727
x_2d = x.reshape(-1, input_shape[-1]).float().contiguous()
728-
M = x_2d.shape[0]
729-
K = x_2d.shape[1]
730728
N = self.weight_state.shape[0] # out_features
731729

732730
# Quantize activations to NVFP4

csrc/kernels.cu

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_
128128
// ============================================================================
129129

130130
// E2M1 dequantization LUT - maps 3-bit unsigned magnitude code to float
131-
__device__ static float nvfp4_dequant_lut[8] = {
132-
0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f
133-
};
131+
__device__ static float nvfp4_dequant_lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
134132

135133
// Dequantize a 4-bit E2M1 code to float
136134
// Bit layout: [sign(1) | exponent(2) | mantissa(1)]
@@ -170,8 +168,10 @@ __device__ unsigned char dQuantizeNVFP4(float x) {
170168
// Convert positive float to unsigned E4M3 (8-bit: 4 exponent bits, bias=7, 3 mantissa bits)
171169
// Range: [0, 448]. Used for NVFP4 block scale factors.
172170
__device__ unsigned char dFloatToE4M3(float x) {
173-
if (x <= 0.0f) return 0;
174-
if (x >= 448.0f) return 0x7E; // Max normal (exp=14, mant=6). exp=15 mant=7 is NaN.
171+
if (x <= 0.0f)
172+
return 0;
173+
if (x >= 448.0f)
174+
return 0x7E; // Max normal (exp=14, mant=6). exp=15 mant=7 is NaN.
175175

176176
unsigned int bits = __float_as_uint(x);
177177
int fp32_exp = ((bits >> 23) & 0xFF) - 127; // Unbiased FP32 exponent
@@ -180,8 +180,10 @@ __device__ unsigned char dFloatToE4M3(float x) {
180180
if (e4m3_exp <= 0) {
181181
// Subnormal in E4M3: value = mantissa/8 * 2^(-6)
182182
int mant = __float2int_rn(x * 512.0f); // 512 = 8 * 2^6
183-
if (mant <= 0) return 0;
184-
if (mant > 7) mant = 7;
183+
if (mant <= 0)
184+
return 0;
185+
if (mant > 7)
186+
mant = 7;
185187
return (unsigned char)mant;
186188
}
187189

@@ -194,15 +196,18 @@ __device__ unsigned char dFloatToE4M3(float x) {
194196
e4m3_exp++;
195197
}
196198

197-
if (e4m3_exp > 15) return 0x7E;
198-
if (e4m3_exp == 15 && mant_3bit >= 7) return 0x7E; // Clamp, don't produce NaN
199+
if (e4m3_exp > 15)
200+
return 0x7E;
201+
if (e4m3_exp == 15 && mant_3bit >= 7)
202+
return 0x7E; // Clamp, don't produce NaN
199203

200204
return (unsigned char)((e4m3_exp << 3) | mant_3bit);
201205
}
202206

203207
// Convert unsigned E4M3 byte to float
204208
__device__ float dE4M3ToFloat(unsigned char val) {
205-
if (val == 0) return 0.0f;
209+
if (val == 0)
210+
return 0.0f;
206211

207212
int exp = (val >> 3) & 0x0F;
208213
int mant = val & 0x07;
@@ -225,17 +230,17 @@ __device__ float dE4M3ToFloat(unsigned char val) {
225230
template <typename T>
226231
__global__ void kQuantizeNVFP4(
227232
const T* __restrict__ input,
228-
unsigned char* __restrict__ output, // Packed FP4: n/2 bytes
233+
unsigned char* __restrict__ output, // Packed FP4: n/2 bytes
229234
unsigned char* __restrict__ block_scales, // E4M3 scales: n/16 bytes
230-
const float tensor_scale,
231-
const int n
235+
const float tensor_scale, const int n
232236
) {
233237
// Each thread handles 2 consecutive elements (packs into 1 byte)
234238
// 8 threads per 16-element quantization block
235239
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
236240
const int element_idx = tid * 2;
237241

238-
if (element_idx >= n) return;
242+
if (element_idx >= n)
243+
return;
239244

240245
const float inv_tensor_scale = (tensor_scale > 0.0f) ? (1.0f / tensor_scale) : 0.0f;
241246

@@ -246,10 +251,10 @@ __global__ void kQuantizeNVFP4(
246251
// Compute per-thread absmax
247252
float local_max = fmaxf(fabsf(val0), fabsf(val1));
248253

249-
// Warp-shuffle reduction within 8-thread quantization block
250-
// Threads 0-7 handle block 0, 8-15 handle block 1, etc.
251-
// XOR offsets 4, 2, 1 stay within each 8-thread group
252-
#pragma unroll
254+
// Warp-shuffle reduction within 8-thread quantization block
255+
// Threads 0-7 handle block 0, 8-15 handle block 1, etc.
256+
// XOR offsets 4, 2, 1 stay within each 8-thread group
257+
#pragma unroll
253258
for (int offset = 4; offset >= 1; offset >>= 1) {
254259
float other = __shfl_xor_sync(0xFFFFFFFF, local_max, offset);
255260
local_max = fmaxf(local_max, other);
@@ -287,16 +292,15 @@ __global__ void kQuantizeNVFP4(
287292
// ============================================================================
288293
template <typename T>
289294
__global__ void kDequantizeNVFP4(
290-
const unsigned char* __restrict__ input, // Packed FP4: n/2 bytes
295+
const unsigned char* __restrict__ input, // Packed FP4: n/2 bytes
291296
const unsigned char* __restrict__ block_scales, // E4M3 scales: n/16 bytes
292-
const float tensor_scale,
293-
T* __restrict__ output,
294-
const int n
297+
const float tensor_scale, T* __restrict__ output, const int n
295298
) {
296299
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
297300
const int element_idx = tid * 2;
298301

299-
if (element_idx >= n) return;
302+
if (element_idx >= n)
303+
return;
300304

301305
// Load and unpack
302306
unsigned char packed = input[element_idx / 2];
@@ -327,22 +331,19 @@ __global__ void kDequantizeNVFP4(
327331
// 4 butterfly stages: stride 8, 4, 2, 1. Normalization by 1/4 = 1/sqrt(16).
328332
// In-place operation on FP16/BF16/FP32 tensors.
329333
// ============================================================================
330-
template <typename T>
331-
__global__ void kHadamardRotate16(
332-
T* __restrict__ data,
333-
const int n
334-
) {
334+
template <typename T> __global__ void kHadamardRotate16(T* __restrict__ data, const int n) {
335335
// Each thread handles one element.
336336
// 16 threads form one Hadamard block.
337337
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
338-
if (tid >= n) return;
338+
if (tid >= n)
339+
return;
339340

340341
float val = (float)data[tid];
341342

342-
// Fast Walsh-Hadamard Transform: 4 butterfly stages
343-
// Threads within the same 16-element group exchange via warp shuffles
344-
// lane_in_block: position 0-15 within the 16-element Hadamard block
345-
#pragma unroll
343+
// Fast Walsh-Hadamard Transform: 4 butterfly stages
344+
// Threads within the same 16-element group exchange via warp shuffles
345+
// lane_in_block: position 0-15 within the 16-element Hadamard block
346+
#pragma unroll
346347
for (int stride = 8; stride >= 1; stride >>= 1) {
347348
float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
348349
// Butterfly: if bit is 0, add; if bit is 1, subtract
@@ -365,20 +366,20 @@ template <typename T>
365366
__global__ void kFusedHadamardQuantizeNVFP4(
366367
const T* __restrict__ input,
367368
unsigned char* __restrict__ output, // Packed FP4: n/2 bytes
368-
unsigned char* __restrict__ block_scales, // E4M3 scales: n/16 bytes
369-
const float tensor_scale,
370-
const int n
369+
unsigned char* __restrict__ block_scales, // E4M3 scales: n/16 bytes
370+
const float tensor_scale, const int n
371371
) {
372372
// Each thread handles 1 element for the Hadamard transform,
373373
// then pairs of threads pack 2 elements into 1 byte.
374374
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
375-
if (tid >= n) return;
375+
if (tid >= n)
376+
return;
376377

377378
// Load and convert to float
378379
float val = (float)input[tid];
379380

380-
// Apply Hadamard rotation (FWHT, 4 butterfly stages)
381-
#pragma unroll
381+
// Apply Hadamard rotation (FWHT, 4 butterfly stages)
382+
#pragma unroll
382383
for (int stride = 8; stride >= 1; stride >>= 1) {
383384
float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
384385
int bit = tid & stride;
@@ -392,7 +393,7 @@ __global__ void kFusedHadamardQuantizeNVFP4(
392393

393394
// Compute block absmax via warp shuffle (16 threads per Hadamard block)
394395
float local_max = fabsf(scaled_val);
395-
#pragma unroll
396+
#pragma unroll
396397
for (int offset = 8; offset >= 1; offset >>= 1) {
397398
float other = __shfl_xor_sync(0xFFFFFFFF, local_max, offset);
398399
local_max = fmaxf(local_max, other);
@@ -2872,28 +2873,28 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
28722873

28732874
// NVFP4 kernel template instantiations
28742875
template __global__ void kQuantizeNVFP4<half>(
2875-
const half* __restrict__ input, unsigned char* __restrict__ output,
2876-
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
2876+
const half* __restrict__ input, unsigned char* __restrict__ output, unsigned char* __restrict__ block_scales,
2877+
const float tensor_scale, const int n
28772878
);
28782879
template __global__ void kQuantizeNVFP4<__nv_bfloat16>(
28792880
const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
28802881
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
28812882
);
28822883
template __global__ void kQuantizeNVFP4<float>(
2883-
const float* __restrict__ input, unsigned char* __restrict__ output,
2884-
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
2884+
const float* __restrict__ input, unsigned char* __restrict__ output, unsigned char* __restrict__ block_scales,
2885+
const float tensor_scale, const int n
28852886
);
28862887
template __global__ void kDequantizeNVFP4<half>(
2887-
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
2888-
const float tensor_scale, half* __restrict__ output, const int n
2888+
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales, const float tensor_scale,
2889+
half* __restrict__ output, const int n
28892890
);
28902891
template __global__ void kDequantizeNVFP4<__nv_bfloat16>(
2891-
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
2892-
const float tensor_scale, __nv_bfloat16* __restrict__ output, const int n
2892+
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales, const float tensor_scale,
2893+
__nv_bfloat16* __restrict__ output, const int n
28932894
);
28942895
template __global__ void kDequantizeNVFP4<float>(
2895-
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
2896-
const float tensor_scale, float* __restrict__ output, const int n
2896+
const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales, const float tensor_scale,
2897+
float* __restrict__ output, const int n
28972898
);
28982899

28992900
// Hadamard rotation kernel instantiations
@@ -2903,16 +2904,16 @@ template __global__ void kHadamardRotate16<float>(float* __restrict__ data, cons
29032904

29042905
// Fused Hadamard + NVFP4 quantize kernel instantiations
29052906
template __global__ void kFusedHadamardQuantizeNVFP4<half>(
2906-
const half* __restrict__ input, unsigned char* __restrict__ output,
2907-
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
2907+
const half* __restrict__ input, unsigned char* __restrict__ output, unsigned char* __restrict__ block_scales,
2908+
const float tensor_scale, const int n
29082909
);
29092910
template __global__ void kFusedHadamardQuantizeNVFP4<__nv_bfloat16>(
29102911
const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
29112912
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
29122913
);
29132914
template __global__ void kFusedHadamardQuantizeNVFP4<float>(
2914-
const float* __restrict__ input, unsigned char* __restrict__ output,
2915-
unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
2915+
const float* __restrict__ input, unsigned char* __restrict__ output, unsigned char* __restrict__ block_scales,
2916+
const float tensor_scale, const int n
29162917
);
29172918

29182919
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \

0 commit comments

Comments (0)