Skip to content

Commit 9931b25

Browse files
TimDettmers and claude
committed
perf: Cache num_sms for CUDA graph safety in kbit launchers
Replace per-call cudaGetDevice()/cudaDeviceGetAttribute() with a cached static function cachedNumSMs(). This removes CUDA runtime API calls from the kernel launch path, making kbitGemmProd, kbitGroupedGemmProd, and kbitScalarGemvTiledV2 safe for CUDA graph capture. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3a2cf58 commit 9931b25

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

csrc/ops.cu

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,17 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
15241524
} // end persistent work loop
15251525
}
15261526

// Cached SM count for persistent-kernel grid sizing.
//
// The CUDA runtime is queried exactly once; subsequent calls return the
// cached value with no runtime API traffic, keeping the kernel-launch path
// safe for CUDA graph capture.
//
// Returns: the device's multiprocessor count, or 1 as a conservative
// fallback if the runtime queries fail (the previous code silently left a
// negative value in that case, which callers would feed into grid sizing).
//
// Thread-safety: the C++11 magic-static guarantees the initializer runs
// once even under concurrent first calls (the old check-then-write on a
// plain static was a benign-looking but formally undefined data race).
//
// NOTE(review): the count is cached per process, not per device — in a
// multi-GPU process the first caller's current device wins. Confirm all
// kbit launchers target a single device, or extend this to a per-device
// table indexed by cudaGetDevice().
static int cachedNumSMs() {
    static const int num_sms = [] {
        int dev = 0;
        if (cudaGetDevice(&dev) != cudaSuccess)
            return 1; // safest positive grid size if the runtime is unusable
        int count = 0;
        if (cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, dev) != cudaSuccess || count <= 0)
            return 1;
        return count;
    }();
    return num_sms;
}
15271538
// Pipeline stage count: 4 on datacenter GPUs (more shmem), 2 on consumer.
15281539
static int pipelineNumStages() {
15291540
static int cached = -1;
@@ -1608,11 +1619,7 @@ void kbitGemmProd(
16081619
const scalar_t* A, const unsigned int* B_packed, const ABSMAX_T* B_absmax, const float* codebook, scalar_t* C,
16091620
float* C_workspace, int* tile_counters, int M, int K_dim, int N, int k_chunks, cudaStream_t stream
16101621
) {
1611-
// Query SM count for persistent kernel grid sizing and M_BLOCKS dispatch
1612-
int dev;
1613-
cudaGetDevice(&dev);
1614-
int num_sms;
1615-
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev);
1622+
const int num_sms = cachedNumSMs();
16161623

16171624
// Choose M_BLOCKS. With the persistent kernel, the grid always has
16181625
// num_SMs blocks, so the SM utilization concern is gone. Choose the
@@ -2089,10 +2096,7 @@ void kbitGroupedGemmProd(
20892096
if (max_M == 0 || N == 0)
20902097
return;
20912098

2092-
int dev;
2093-
cudaGetDevice(&dev);
2094-
int num_sms;
2095-
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev);
2099+
const int num_sms = cachedNumSMs();
20962100

20972101
int m_blocks = 1;
20982102
if (max_M > 48)
@@ -2648,10 +2652,7 @@ void kbitScalarGemvTiledV2(
26482652
const float* codebook, scalar_t* C, float* C_workspace, int* tile_counters,
26492653
int M, int K_dim, int N, cudaStream_t stream
26502654
) {
2651-
int dev;
2652-
cudaGetDevice(&dev);
2653-
int num_sms;
2654-
cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev);
2655+
const int num_sms = cachedNumSMs();
26552656

26562657
#define LAUNCH_GEMV_V2(MV) \
26572658
kbitScalarGemvTiledV2Launch<K, MV, scalar_t, ABSMAX_T>( \

0 commit comments

Comments
 (0)