Skip to content

Commit 28fa6c2

Browse files
TimDettmersclaude
andcommitted
style: Apply pre-commit formatting (ruff, clang-format)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7b400f4 commit 28fa6c2

File tree

6 files changed

+161
-128
lines changed

6 files changed

+161
-128
lines changed

benchmarks/bench_cuda_events.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1919

2020
import torch
21+
2122
from bitsandbytes.functional import create_normal_float_codebook
2223

2324
WARMUP = 20
@@ -85,9 +86,7 @@ def prepare_dense_data(device):
8586
codebook = create_normal_float_codebook(k, device=device)
8687
W = torch.randn(K_dim * N, device=device, dtype=torch.float32)
8788
packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
88-
packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
89-
packed_flat, absmax_flat, K_dim, N, k
90-
)
89+
packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(packed_flat, absmax_flat, K_dim, N, k)
9190
data[(name, k)] = (K_dim, N, packed_flat, absmax_flat, packed_tiled, absmax_tiled, codebook)
9291
return data
9392

@@ -134,8 +133,17 @@ def bench_mma(data, m_vals, device):
134133
tile_counters = torch.zeros(m_tiles * n_tiles, dtype=torch.int32, device=device)
135134

136135
fn = lambda: torch.ops.bitsandbytes.kbit_gemm_prod_(
137-
A, packed_tiled, absmax_tiled, codebook,
138-
K_dim, N, k, 1, out, C_workspace, tile_counters,
136+
A,
137+
packed_tiled,
138+
absmax_tiled,
139+
codebook,
140+
K_dim,
141+
N,
142+
k,
143+
1,
144+
out,
145+
C_workspace,
146+
tile_counters,
139147
)
140148
avg_us = bench_kernel(fn)
141149
print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")
@@ -160,8 +168,14 @@ def bench_scalar(data, m_vals, device):
160168
out = torch.empty(M, N, dtype=torch.float16, device=device)
161169

162170
fn = lambda: torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
163-
A, packed_tiled, absmax_tiled, codebook,
164-
K_dim, N, k, out,
171+
A,
172+
packed_tiled,
173+
absmax_tiled,
174+
codebook,
175+
K_dim,
176+
N,
177+
k,
178+
out,
165179
)
166180
avg_us = bench_kernel(fn)
167181
print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")
@@ -184,8 +198,16 @@ def bench_grouped(moe_data, m_vals, device):
184198

185199
# Grouped GEMM doesn't have an _ variant yet — use the allocating version
186200
fn = lambda: torch.ops.bitsandbytes.kbit_grouped_gemm(
187-
A_concat, B_packed_all, B_absmax_all, codebook,
188-
expert_offsets, K_dim, N, k, NUM_EXPERTS, M,
201+
A_concat,
202+
B_packed_all,
203+
B_absmax_all,
204+
codebook,
205+
expert_offsets,
206+
K_dim,
207+
N,
208+
k,
209+
NUM_EXPERTS,
210+
M,
189211
)
190212
avg_us = bench_kernel(fn)
191213
print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")

benchmarks/bench_tiled_vs_flat.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@
6262

6363
# Quantize and repack
6464
packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
65-
packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
66-
packed_flat, absmax_flat, K_dim, N, k
67-
)
65+
packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(packed_flat, absmax_flat, K_dim, N, k)
6866

6967
for M in M_VALUES:
7068
A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
@@ -144,6 +142,7 @@ def bench_graph(fn, trials, iters):
144142
tiled_us, tiled_std = bench_graph(call_tiled, args.trials, args.iters)
145143
v2_us, v2_std = bench_graph(call_v2, args.trials, args.iters)
146144
else:
145+
147146
def bench_events(fn):
148147
for _ in range(args.warmup):
149148
fn()

bitsandbytes/_ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,5 @@ def _(
871871
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
872872
torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
873873
torch._check(C_workspace.dtype == torch.float32, lambda: f"C_workspace must be float32, got {C_workspace.dtype}")
874-
torch._check(
875-
tile_counters.dtype == torch.int32, lambda: f"tile_counters must be int32, got {tile_counters.dtype}"
876-
)
874+
torch._check(tile_counters.dtype == torch.int32, lambda: f"tile_counters must be int32, got {tile_counters.dtype}")
877875
return out

0 commit comments

Comments
 (0)