Skip to content

Commit 71c4874

Browse files
TimDettmers and claude
committed
Add datacenter GPU macro, L2 prefetch hints, and generalized reduction
- Add BNB_DATACENTER_GPU compile-time macro for Hopper (sm_90) and Blackwell
  datacenter (sm_100). Consumer GPUs (sm_89, sm_120) are explicitly excluded.
- Add prefetch_l2() helper: issues prefetch.global.L2 on datacenter GPUs,
  compiles to a no-op on consumer GPUs.
- Add L2 prefetch hints in the MMA pipeline loop (prefetch tile kt+2).
- Add L2 prefetch hints in the grouped MMA pipeline loop (same pattern).
- Add L2 prefetch hints in the scalar GEMV inner loop (next iteration's B).
- Generalize the scalar GEMV warp reduction to loop over NUM_WARPS instead of
  hardcoding a 2-warp sum (cleaner; same behavior at NUM_WARPS=2).
- Update bench_tiled_vs_flat.py to benchmark the v2 kernel alongside the flat
  and tiled kernels.

Zero consumer regression: 174/174 tests pass; RTX 4090 benchmarks unchanged.
The L2 prefetch effect on H100 is neutral for scalar GEMV (expected for the
single-iteration-per-thread case).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 72a374c commit 71c4874

File tree

2 files changed

+111
-33
lines changed

2 files changed

+111
-33
lines changed

benchmarks/bench_tiled_vs_flat.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Benchmark tiled vs flat scalar GEMV with pre-allocated output buffers.
22
33
Measures kernel-only time by pre-allocating all buffers before the timing loop.
4-
No allocations inside the measured region — fair comparison between flat and tiled.
4+
No allocations inside the measured region — fair comparison between flat, tiled, and tiled v2.
55
66
Usage:
77
python benchmarks/bench_tiled_vs_flat.py
@@ -40,11 +40,20 @@
4040
M_VALUES = [1, 2, 4]
4141

4242
if args.graph:
43-
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'±flat':>6} {'tiled_us':>8} {'±tiled':>6} {'diff%':>7}")
44-
print("-" * 76)
43+
print(
44+
f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2}"
45+
f" {'flat_us':>8} {'±flat':>6}"
46+
f" {'tiled_us':>8} {'±tl':>4}"
47+
f" {'v2_us':>8} {'±v2':>4}"
48+
f" {'tl/fl%':>7} {'v2/fl%':>7}"
49+
)
50+
print("-" * 100)
4551
else:
46-
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}")
47-
print("-" * 60)
52+
print(
53+
f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2}"
54+
f" {'flat_us':>8} {'tiled_us':>8} {'v2_us':>8} {'tl/fl%':>7} {'v2/fl%':>7}"
55+
)
56+
print("-" * 75)
4857

4958
for name, K_dim, N in SHAPES:
5059
for k in K_VALUES:
@@ -63,16 +72,27 @@
6372
# Pre-allocate output buffers
6473
out_flat = torch.empty(M, N, dtype=torch.float16, device="cuda")
6574
out_tiled = torch.empty(M, N, dtype=torch.float16, device="cuda")
75+
out_v2 = torch.empty(M, N, dtype=torch.float16, device="cuda")
76+
77+
# v2 workspace
78+
n_tiles = N // 128
79+
C_workspace = torch.zeros(M, N, dtype=torch.float32, device="cuda")
80+
tile_counters = torch.zeros(n_tiles, dtype=torch.int32, device="cuda")
6681

6782
if args.ncu:
68-
# NCU mode: single call each, profiler captures kernel time
6983
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
7084
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
7185
)
7286
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
7387
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
7488
)
75-
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {'ncu':>8} {'ncu':>8} {'ncu':>7}")
89+
torch.ops.bitsandbytes.kbit_scalar_gemv_v2_(
90+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_v2, C_workspace, tile_counters
91+
)
92+
print(
93+
f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2}"
94+
f" {'ncu':>8} {'ncu':>8} {'ncu':>8} {'ncu':>7} {'ncu':>7}"
95+
)
7696
continue
7797

7898
start = torch.cuda.Event(enable_timing=True)
@@ -88,11 +108,16 @@ def call_tiled():
88108
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
89109
)
90110

111+
def call_v2():
112+
torch.ops.bitsandbytes.kbit_scalar_gemv_v2_(
113+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_v2, C_workspace, tile_counters
114+
)
115+
91116
if args.graph:
92117
import statistics
93118

94119
# CUDA graph replay — measures kernel-only time
95-
for fn in (call_flat, call_tiled):
120+
for fn in (call_flat, call_tiled, call_v2):
96121
for _ in range(3):
97122
fn()
98123
torch.cuda.synchronize()
@@ -117,38 +142,44 @@ def bench_graph(fn, trials, iters):
117142

118143
flat_us, flat_std = bench_graph(call_flat, args.trials, args.iters)
119144
tiled_us, tiled_std = bench_graph(call_tiled, args.trials, args.iters)
145+
v2_us, v2_std = bench_graph(call_v2, args.trials, args.iters)
120146
else:
121-
# CUDA events timing (includes Python dispatch overhead)
122-
for _ in range(args.warmup):
123-
call_flat()
124-
torch.cuda.synchronize()
125-
start.record()
126-
for _ in range(args.iters):
127-
call_flat()
128-
end.record()
129-
torch.cuda.synchronize()
130-
flat_us = start.elapsed_time(end) * 1000 / args.iters
147+
def bench_events(fn):
148+
for _ in range(args.warmup):
149+
fn()
150+
torch.cuda.synchronize()
151+
start.record()
152+
for _ in range(args.iters):
153+
fn()
154+
end.record()
155+
torch.cuda.synchronize()
156+
return start.elapsed_time(end) * 1000 / args.iters
131157

132-
for _ in range(args.warmup):
133-
call_tiled()
134-
torch.cuda.synchronize()
135-
start.record()
136-
for _ in range(args.iters):
137-
call_tiled()
138-
end.record()
139-
torch.cuda.synchronize()
140-
tiled_us = start.elapsed_time(end) * 1000 / args.iters
158+
flat_us = bench_events(call_flat)
159+
tiled_us = bench_events(call_tiled)
160+
v2_us = bench_events(call_v2)
141161

142-
diff_pct = (tiled_us - flat_us) / flat_us * 100
162+
tl_pct = (tiled_us - flat_us) / flat_us * 100
163+
v2_pct = (v2_us - flat_us) / flat_us * 100
143164
if args.graph:
144165
print(
145166
f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2}"
146-
f" {flat_us:>8.1f} {flat_std:>5.1f}σ {tiled_us:>8.1f} {tiled_std:>5.1f}σ {diff_pct:>+7.1f}%"
167+
f" {flat_us:>8.1f} {flat_std:>5.1f}σ"
168+
f" {tiled_us:>8.1f} {tiled_std:>3.1f}σ"
169+
f" {v2_us:>8.1f} {v2_std:>3.1f}σ"
170+
f" {tl_pct:>+7.1f}% {v2_pct:>+7.1f}%"
147171
)
148172
else:
149-
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%")
173+
print(
174+
f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2}"
175+
f" {flat_us:>8.1f} {tiled_us:>8.1f} {v2_us:>8.1f} {tl_pct:>+7.1f}% {v2_pct:>+7.1f}%"
176+
)
150177

151178
# Correctness check (once per shape/k)
152179
assert torch.equal(out_flat, out_tiled) or torch.allclose(out_flat, out_tiled, rtol=0.05, atol=0.1), (
153-
f"MISMATCH {name} k={k}: max diff = {(out_flat - out_tiled).abs().max().item()}"
180+
f"MISMATCH flat vs tiled {name} k={k}: max diff = {(out_flat - out_tiled).abs().max().item()}"
181+
)
182+
# v2 uses split-K so small FP diffs are expected
183+
assert torch.allclose(out_flat.float(), out_v2.float(), rtol=0.1, atol=1.0), (
184+
f"MISMATCH flat vs v2 {name} k={k}: max diff = {(out_flat.float() - out_v2.float()).abs().max().item()}"
154185
)

csrc/ops.cu

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,23 @@ void repackKbit(
10111011
CUDA_CHECK_RETURN(cudaPeekAtLastError());
10121012
}
10131013

1014+
// Datacenter GPU detection: Hopper (sm_90) and Blackwell datacenter (sm_100).
1015+
// NOTE: sm_120 (RTX 5090, Blackwell consumer) lacks TMA/wgmma — must NOT match.
1016+
#if defined(__CUDA_ARCH__)
1017+
#define BNB_DATACENTER_GPU (__CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000)
1018+
#else
1019+
#define BNB_DATACENTER_GPU 0
1020+
#endif
1021+
1022+
// L2 prefetch hint (datacenter GPUs only — consumer GPUs ignore it)
1023+
__device__ __forceinline__ void prefetch_l2(const void* ptr) {
1024+
#if BNB_DATACENTER_GPU
1025+
asm volatile("prefetch.global.L2 [%0];" ::"l"(ptr));
1026+
#else
1027+
(void)ptr;
1028+
#endif
1029+
}
1030+
10141031
// cp.async helpers (sm_80+) — used by production MMA and grouped MMA kernels
10151032
__device__ __forceinline__ void cp_async_cg_16(void* __restrict__ smem, const void* __restrict__ gmem) {
10161033
uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem));
@@ -1299,6 +1316,12 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
12991316
if (kt + 1 < kt_end) {
13001317
fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
13011318
cp_async_fence();
1319+
// L2 prefetch for tile kt+2 (warms L2 before next fetch_tile issues cp.async)
1320+
if (kt + 2 < kt_end) {
1321+
const int pf_tile = (kt + 2) * n_tiles + n_tile;
1322+
prefetch_l2(B_packed + pf_tile * B_STAGE_WORDS);
1323+
prefetch_l2(B_absmax + pf_tile * ABS_STAGE_ELEMS);
1324+
}
13021325
cp_async_wait<1>();
13031326
} else {
13041327
cp_async_wait<0>();
@@ -1736,6 +1759,12 @@ __global__ void kbit_grouped_gemm_prod(
17361759
if (kt + 1 < kt_end) {
17371760
fetch_tile((kt - kt_start + 1) % 2, kt + 1);
17381761
cp_async_fence();
1762+
// L2 prefetch for tile kt+2
1763+
if (kt + 2 < kt_end) {
1764+
const int pf_tile = (kt + 2) * n_tiles + n_tile;
1765+
prefetch_l2(B_packed + pf_tile * B_STAGE_WORDS);
1766+
prefetch_l2(B_absmax + pf_tile * ABS_STAGE_ELEMS);
1767+
}
17391768
cp_async_wait<1>();
17401769
} else {
17411770
cp_async_wait<0>();
@@ -2009,6 +2038,21 @@ __global__ void __launch_bounds__(64, M_VAL <= 2 ? 24 : 16) kbit_scalar_gemv(
20092038
abs_idx = tile_base * ABS_PER_TILE + col_in_tile * KB_PER_TILE + kb;
20102039
}
20112040

2041+
// L2 prefetch for next iteration's B data
2042+
{
2043+
const int next_block_idx = block_idx + BLOCK_SIZE;
2044+
if (next_block_idx < num_k_blocks) {
2045+
if constexpr (!TILED) {
2046+
prefetch_l2(&B_col[next_block_idx * K_BITS]);
2047+
} else {
2048+
const int nk_tile = next_block_idx / KB_PER_TILE;
2049+
const int nkb = next_block_idx % KB_PER_TILE;
2050+
const int ntb = nk_tile * n_tiles + n_tile;
2051+
prefetch_l2(&B_packed[ntb * WORDS_PER_TILE + (col_in_tile * KB_PER_TILE + nkb) * K_BITS]);
2052+
}
2053+
}
2054+
}
2055+
20122056
// Load k bit-plane words (guarded; invalid threads get 0)
20132057
// Vector loads for power-of-2 K_BITS, scalar for others.
20142058
const unsigned int* B_src = TILED ? B_packed : B_col;
@@ -2092,12 +2136,15 @@ __global__ void __launch_bounds__(64, M_VAL <= 2 ? 24 : 16) kbit_scalar_gemv(
20922136
}
20932137
__syncthreads();
20942138

2095-
// Thread 0 sums both warps and writes output
2139+
// Thread 0 sums all warps and writes output
20962140
if (threadIdx.x == 0) {
20972141
#pragma unroll
20982142
for (int m = 0; m < M_VAL; m++) {
20992143
if (m < M) {
2100-
float sum = s_partial[0 * M_MAX + m] + s_partial[1 * M_MAX + m];
2144+
float sum = 0.0f;
2145+
#pragma unroll
2146+
for (int w = 0; w < NUM_WARPS; w++)
2147+
sum += s_partial[w * M_MAX + m];
21012148
C[m * N + col] = ScalarOps<scalar_t>::from_float(sum);
21022149
}
21032150
}

0 commit comments

Comments (0)