@@ -2106,6 +2106,49 @@ def _(data: torch.Tensor, dim: int, signs: Optional[torch.Tensor]) -> torch.Tens
21062106 return data
21072107
21082108
2109+ class _WorkspaceCache :
2110+ """Per-device cache for split-K workspace buffers (C_workspace + tile_counters).
2111+
2112+ Avoids repeated torch.zeros allocations in the default (non-workspace) path.
2113+ Buffers are allocated at the max size seen per device and reused via views.
2114+ The _impl functions call .zero_() on the views, so only used elements are zeroed.
2115+
2116+ Memory cost is modest: at M=16 with N=5120, C_workspace is 320 KB (float32)
2117+ and tile_counters is <1 KB. For MoE with 8 experts × max_M=16, C_workspace
2118+ is ~2.5 MB. Buffers are never freed until process exit.
2119+
2120+ Not thread-safe — assumes single-threaded inference (typical for LLM serving).
2121+ """
2122+
2123+ def __init__ (self ):
2124+ # {device_index: (flat_ws_tensor, flat_tc_tensor)}
2125+ self ._cache : dict [int , tuple [torch .Tensor , torch .Tensor ]] = {}
2126+
2127+ def get (self , device : torch .device , ws_numel : int , tc_numel : int ):
2128+ """Return (C_workspace_flat, tile_counters_flat) views of cached buffers.
2129+
2130+ Grows the cache if needed, never shrinks.
2131+ """
2132+ idx = device .index if device .index is not None else 0
2133+ if idx in self ._cache :
2134+ ws_buf , tc_buf = self ._cache [idx ]
2135+ if ws_buf .numel () >= ws_numel and tc_buf .numel () >= tc_numel :
2136+ return ws_buf [:ws_numel ], tc_buf [:tc_numel ]
2137+
2138+ # Allocate with 2x headroom to reduce re-allocations
2139+ ws_buf = torch .empty (max (ws_numel * 2 , 1 ), device = device , dtype = torch .float32 )
2140+ tc_buf = torch .empty (max (tc_numel * 2 , 1024 ), device = device , dtype = torch .int32 )
2141+ self ._cache [idx ] = (ws_buf , tc_buf )
2142+ return ws_buf [:ws_numel ], tc_buf [:tc_numel ]
2143+
2144+ def clear (self ):
2145+ """Free all cached buffers."""
2146+ self ._cache .clear ()
2147+
2148+
2149+ _workspace_cache = _WorkspaceCache ()
2150+
2151+
21092152def _kbit_gemm_prod_check (A , B_packed , B_absmax , codebook , N , k , k_chunks ):
21102153 torch ._check (k >= 2 and k <= 5 , lambda : f"k must be 2-5, got { k } " )
21112154 torch ._check (
@@ -2164,15 +2207,14 @@ def _(
21642207 M = A .shape [0 ]
21652208 C = torch .empty (M , N , device = A .device , dtype = A .dtype )
21662209
2167- # The persistent kernel auto-selects k_splits and TILE_N internally.
2168- # TILE_N=64 for M<=16 gives more tiles; allocate for worst case.
21692210 TILE_M = 16
21702211 TILE_N = 64 # worst case (most tiles)
21712212 m_tiles = (M + TILE_M - 1 ) // TILE_M
21722213 n_tiles = N // TILE_N
21732214
2174- C_workspace = torch .zeros (M , N , device = A .device , dtype = torch .float32 )
2175- tile_counters = torch .zeros (m_tiles * n_tiles , device = A .device , dtype = torch .int32 )
2215+ ws_flat , tc_flat = _workspace_cache .get (A .device , M * N , m_tiles * n_tiles )
2216+ C_workspace = ws_flat .view (M , N )
2217+ tile_counters = tc_flat
21762218
21772219 _kbit_gemm_prod_impl (A , B_packed , B_absmax , codebook , K_dim , N , k , k_chunks , C , C_workspace , tile_counters )
21782220 return C
@@ -2247,8 +2289,9 @@ def _(
22472289 m_tiles = (M + TILE_M - 1 ) // TILE_M
22482290 n_tiles = N // TILE_N
22492291
2250- C_workspace = torch .zeros (M , N , device = A .device , dtype = torch .float32 )
2251- tile_counters = torch .zeros (m_tiles * n_tiles , device = A .device , dtype = torch .int32 )
2292+ ws_flat , tc_flat = _workspace_cache .get (A .device , M * N , m_tiles * n_tiles )
2293+ C_workspace = ws_flat .view (M , N )
2294+ tile_counters = tc_flat
22522295
22532296 _vq_gemm_prod_impl (A , B_packed , B_absmax , codebook , K_dim , N , p , k_chunks , C , C_workspace , tile_counters , index_bits )
22542297 return C
@@ -2354,10 +2397,6 @@ def _(
23542397 total_M = A_concat .shape [0 ]
23552398 C_concat = torch .empty (total_M , N , device = A_concat .device , dtype = A_concat .dtype )
23562399
2357- # Workspace for split-K atomicAdd reduction
2358- C_workspace = torch .zeros (total_M , N , device = A_concat .device , dtype = torch .float32 )
2359- # Tile counters for split-K last-block detection
2360- # Upper bound: num_experts * max_m_tiles * max_n_tiles
23612400 m_blocks = 1
23622401 if max_M > 48 :
23632402 m_blocks = 4
@@ -2369,7 +2408,10 @@ def _(
23692408 n_tiles = N // tile_n
23702409 m_tiles = (max_M + m_blocks * 16 - 1 ) // (m_blocks * 16 )
23712410 mn_tiles = num_experts * m_tiles * n_tiles
2372- tile_counters = torch .zeros (mn_tiles , device = A_concat .device , dtype = torch .int32 )
2411+
2412+ ws_flat , tc_flat = _workspace_cache .get (A_concat .device , total_M * N , mn_tiles )
2413+ C_workspace = ws_flat .view (total_M , N )
2414+ tile_counters = tc_flat
23732415
23742416 _kbit_grouped_gemm_impl (
23752417 A_concat ,
@@ -2506,9 +2548,6 @@ def _(
25062548 total_M = A_concat .shape [0 ]
25072549 C_concat = torch .empty (total_M , N , device = A_concat .device , dtype = A_concat .dtype )
25082550
2509- # Workspace for split-K atomicAdd reduction
2510- C_workspace = torch .zeros (total_M , N , device = A_concat .device , dtype = torch .float32 )
2511- # Tile counters for split-K last-block detection
25122551 m_blocks = 1
25132552 if max_M > 48 :
25142553 m_blocks = 4
@@ -2520,7 +2559,10 @@ def _(
25202559 n_tiles = N // tile_n
25212560 m_tiles = (max_M + m_blocks * 16 - 1 ) // (m_blocks * 16 )
25222561 mn_tiles = num_experts * m_tiles * n_tiles
2523- tile_counters = torch .zeros (mn_tiles , device = A_concat .device , dtype = torch .int32 )
2562+
2563+ ws_flat , tc_flat = _workspace_cache .get (A_concat .device , total_M * N , mn_tiles )
2564+ C_workspace = ws_flat .view (total_M , N )
2565+ tile_counters = tc_flat
25242566
25252567 _vq_grouped_gemm_impl (
25262568 A_concat ,
0 commit comments