Skip to content

Commit 1895945

Browse files
TimDettmers and claude
committed
Cache split-K workspace buffers to avoid per-call torch.zeros allocation
Add _WorkspaceCache that reuses C_workspace (float32) and tile_counters (int32) across calls to kbit/vq gemm_prod and grouped_gemm ops. Saves ~6-13 us per call on GLM-4.7 shapes (5-13% speedup on RTX 4090) by eliminating redundant allocation + double-zeroing overhead. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6e76b52 commit 1895945

File tree

1 file changed

+57
-15
lines changed
  • bitsandbytes/backends/cuda

1 file changed

+57
-15
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,6 +2106,49 @@ def _(data: torch.Tensor, dim: int, signs: Optional[torch.Tensor]) -> torch.Tens
21062106
return data
21072107

21082108

2109+
class _WorkspaceCache:
2110+
"""Per-device cache for split-K workspace buffers (C_workspace + tile_counters).
2111+
2112+
Avoids repeated torch.zeros allocations in the default (non-workspace) path.
2113+
Buffers are allocated at the max size seen per device and reused via views.
2114+
The _impl functions call .zero_() on the views, so only used elements are zeroed.
2115+
2116+
Memory cost is modest: at M=16 with N=5120, C_workspace is 320 KB (float32)
2117+
and tile_counters is <1 KB. For MoE with 8 experts × max_M=16, C_workspace
2118+
is ~2.5 MB. Buffers are never freed until process exit.
2119+
2120+
Not thread-safe — assumes single-threaded inference (typical for LLM serving).
2121+
"""
2122+
2123+
def __init__(self):
2124+
# {device_index: (flat_ws_tensor, flat_tc_tensor)}
2125+
self._cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
2126+
2127+
def get(self, device: torch.device, ws_numel: int, tc_numel: int):
2128+
"""Return (C_workspace_flat, tile_counters_flat) views of cached buffers.
2129+
2130+
Grows the cache if needed, never shrinks.
2131+
"""
2132+
idx = device.index if device.index is not None else 0
2133+
if idx in self._cache:
2134+
ws_buf, tc_buf = self._cache[idx]
2135+
if ws_buf.numel() >= ws_numel and tc_buf.numel() >= tc_numel:
2136+
return ws_buf[:ws_numel], tc_buf[:tc_numel]
2137+
2138+
# Allocate with 2x headroom to reduce re-allocations
2139+
ws_buf = torch.empty(max(ws_numel * 2, 1), device=device, dtype=torch.float32)
2140+
tc_buf = torch.empty(max(tc_numel * 2, 1024), device=device, dtype=torch.int32)
2141+
self._cache[idx] = (ws_buf, tc_buf)
2142+
return ws_buf[:ws_numel], tc_buf[:tc_numel]
2143+
2144+
def clear(self):
2145+
"""Free all cached buffers."""
2146+
self._cache.clear()
2147+
2148+
2149+
_workspace_cache = _WorkspaceCache()
2150+
2151+
21092152
def _kbit_gemm_prod_check(A, B_packed, B_absmax, codebook, N, k, k_chunks):
21102153
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
21112154
torch._check(
@@ -2164,15 +2207,14 @@ def _(
21642207
M = A.shape[0]
21652208
C = torch.empty(M, N, device=A.device, dtype=A.dtype)
21662209

2167-
# The persistent kernel auto-selects k_splits and TILE_N internally.
2168-
# TILE_N=64 for M<=16 gives more tiles; allocate for worst case.
21692210
TILE_M = 16
21702211
TILE_N = 64 # worst case (most tiles)
21712212
m_tiles = (M + TILE_M - 1) // TILE_M
21722213
n_tiles = N // TILE_N
21732214

2174-
C_workspace = torch.zeros(M, N, device=A.device, dtype=torch.float32)
2175-
tile_counters = torch.zeros(m_tiles * n_tiles, device=A.device, dtype=torch.int32)
2215+
ws_flat, tc_flat = _workspace_cache.get(A.device, M * N, m_tiles * n_tiles)
2216+
C_workspace = ws_flat.view(M, N)
2217+
tile_counters = tc_flat
21762218

21772219
_kbit_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, k, k_chunks, C, C_workspace, tile_counters)
21782220
return C
@@ -2247,8 +2289,9 @@ def _(
22472289
m_tiles = (M + TILE_M - 1) // TILE_M
22482290
n_tiles = N // TILE_N
22492291

2250-
C_workspace = torch.zeros(M, N, device=A.device, dtype=torch.float32)
2251-
tile_counters = torch.zeros(m_tiles * n_tiles, device=A.device, dtype=torch.int32)
2292+
ws_flat, tc_flat = _workspace_cache.get(A.device, M * N, m_tiles * n_tiles)
2293+
C_workspace = ws_flat.view(M, N)
2294+
tile_counters = tc_flat
22522295

22532296
_vq_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks, C, C_workspace, tile_counters, index_bits)
22542297
return C
@@ -2354,10 +2397,6 @@ def _(
23542397
total_M = A_concat.shape[0]
23552398
C_concat = torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)
23562399

2357-
# Workspace for split-K atomicAdd reduction
2358-
C_workspace = torch.zeros(total_M, N, device=A_concat.device, dtype=torch.float32)
2359-
# Tile counters for split-K last-block detection
2360-
# Upper bound: num_experts * max_m_tiles * max_n_tiles
23612400
m_blocks = 1
23622401
if max_M > 48:
23632402
m_blocks = 4
@@ -2369,7 +2408,10 @@ def _(
23692408
n_tiles = N // tile_n
23702409
m_tiles = (max_M + m_blocks * 16 - 1) // (m_blocks * 16)
23712410
mn_tiles = num_experts * m_tiles * n_tiles
2372-
tile_counters = torch.zeros(mn_tiles, device=A_concat.device, dtype=torch.int32)
2411+
2412+
ws_flat, tc_flat = _workspace_cache.get(A_concat.device, total_M * N, mn_tiles)
2413+
C_workspace = ws_flat.view(total_M, N)
2414+
tile_counters = tc_flat
23732415

23742416
_kbit_grouped_gemm_impl(
23752417
A_concat,
@@ -2506,9 +2548,6 @@ def _(
25062548
total_M = A_concat.shape[0]
25072549
C_concat = torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)
25082550

2509-
# Workspace for split-K atomicAdd reduction
2510-
C_workspace = torch.zeros(total_M, N, device=A_concat.device, dtype=torch.float32)
2511-
# Tile counters for split-K last-block detection
25122551
m_blocks = 1
25132552
if max_M > 48:
25142553
m_blocks = 4
@@ -2520,7 +2559,10 @@ def _(
25202559
n_tiles = N // tile_n
25212560
m_tiles = (max_M + m_blocks * 16 - 1) // (m_blocks * 16)
25222561
mn_tiles = num_experts * m_tiles * n_tiles
2523-
tile_counters = torch.zeros(mn_tiles, device=A_concat.device, dtype=torch.int32)
2562+
2563+
ws_flat, tc_flat = _workspace_cache.get(A_concat.device, total_M * N, mn_tiles)
2564+
C_workspace = ws_flat.view(total_M, N)
2565+
tile_counters = tc_flat
25242566

25252567
_vq_grouped_gemm_impl(
25262568
A_concat,

0 commit comments

Comments
 (0)