Skip to content

Commit a629ab1

Browse files
committed
fix: Add stream parameter to cgemm_nvfp4 for CUDA graph support
The kernel launch now uses the caller's stream via <<<grid, threads, 0, stream>>>. The Python dispatch passes _get_tensor_stream(A_packed). This enables CUDA graph capture for accurate benchmarking.
1 parent 9c96365 commit a629ab1

File tree

2 files changed: +3 −2 lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,7 @@ def _(
928928
ct.c_int(M),
929929
ct.c_int(N),
930930
ct.c_int(K),
931+
_get_tensor_stream(A_packed),
931932
)
932933

933934
# Apply tensor scales (the GEMM kernel operates on raw quantized values)

csrc/kernels_nvfp4_sm120.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,13 +485,13 @@ __global__ void kGemmNVFP4_simple(
485485
// ============================================================================
486486
extern "C" void cgemm_nvfp4(
487487
const unsigned char* A, const unsigned char* B, const unsigned char* SFA, const unsigned char* SFB, float* D, int M,
488-
int N, int K
488+
int N, int K, cudaStream_t stream
489489
) {
490490
int num_m_blocks = (M + BLOCK_M_DIM - 1) / BLOCK_M_DIM;
491491
int num_n_blocks = (N + BLOCK_N_DIM - 1) / BLOCK_N_DIM;
492492

493493
dim3 grid(num_n_blocks, num_m_blocks);
494494
int threads_per_block = WARPS_PER_BLOCK * 32; // 256
495495

496-
kGemmNVFP4_smem<<<grid, threads_per_block>>>(A, B, SFA, SFB, D, M, N, K);
496+
kGemmNVFP4_smem<<<grid, threads_per_block, 0, stream>>>(A, B, SFA, SFB, D, M, N, K);
497497
}

0 commit comments

Comments (0)