Skip to content

Commit dc749a9

Browse files
TimDettmersclaude
andcommitted
Fix stale pointer bug in batched MoE GEMM cache
Include data_ptr() values in the init cache key, not just dimensions. CUTLASS initialize() bakes data pointers into kernel params. When different callers (module's _forward_batched vs torch op gemm_nvfp4_moe) use the same dimensions but different buffer addresses, the old cache incorrectly skipped re-init, causing run() to write to stale pointers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8bf8759 commit dc749a9

File tree

1 file changed

+6
-1
lines changed
  • bitsandbytes/backends/cuda

1 file changed

+6
-1
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,12 @@ def _batched_moe_sm100_init_if_needed(
13631363
global _moe_batched_sm100_cache
13641364
_ensure_moe_batched_restype()
13651365

1366-
cache_key = (N, K, max_M, num_experts)
1366+
cache_key = (
1367+
N, K, max_M, num_experts,
1368+
A_batched.data_ptr(), B_all.data_ptr(),
1369+
SFA_batched.data_ptr(), SFB_all.data_ptr(),
1370+
D_out.data_ptr(), alpha.data_ptr(),
1371+
)
13671372

13681373
if (_moe_batched_sm100_cache is not None
13691374
and _moe_batched_sm100_cache["key"] == cache_key):

0 commit comments

Comments
 (0)