Skip to content

Commit 5eb0b93

Browse files
authored
Update co oc kernel (#347)
1 parent 9128298 commit 5eb0b93

2 files changed

Lines changed: 218 additions & 58 deletions

File tree

src/rapids_singlecell/squidpy_gpu/_co_oc.py

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import math
43
from typing import TYPE_CHECKING
54

65
import cupy as cp
@@ -9,9 +8,10 @@
98

109
from ._utils import _assert_categorical_obs, _assert_spatial_basis
1110
from .kernels._co_oc import (
12-
occur_count_kernel,
13-
occur_count_kernel2,
14-
occur_count_kernel_fast,
11+
occur_count_kernel_pairwise,
12+
occur_count_kernel_pairwise_fast,
13+
occur_reduction_kernel_global,
14+
occur_reduction_kernel_shared,
1515
)
1616

1717
if TYPE_CHECKING:
@@ -99,6 +99,26 @@ def _find_min_max(spatial: NDArrayC) -> tuple[float, float]:
9999
return thres_min.astype(np.float32), thres_max.astype(np.float32)
100100

101101

102+
def calculate_optimal_k(target_occupancy=0.4):
103+
props = cp.cuda.runtime.getDeviceProperties(0)
104+
105+
# Get key SM properties
106+
shared_mem_per_sm = props["sharedMemPerMultiprocessor"] # bytes
107+
max_warps_per_sm = props["maxThreadsPerMultiProcessor"] // 32
108+
block_size = 128 # Your current block size
109+
warps_per_block = block_size // 32 # 4 warps per block
110+
111+
# Target blocks per SM based on desired occupancy
112+
target_blocks = int(max_warps_per_sm * target_occupancy) // warps_per_block
113+
114+
# Maximum shared memory per block to achieve target occupancy
115+
max_shared_per_block = shared_mem_per_sm // target_blocks
116+
117+
# Calculate max k value
118+
max_k = max_shared_per_block // (block_size * cp.dtype("float32").itemsize)
119+
return max_k
120+
121+
102122
def _co_occurrence_helper(
103123
spatial: NDArrayC, v_radium: NDArrayC, labs: NDArrayC, fast: bool = True
104124
) -> NDArrayC:
@@ -125,35 +145,57 @@ def _co_occurrence_helper(
125145
k = len(labs_unique)
126146
l_val = len(v_radium) - 1
127147
thresholds = (v_radium[1:]) ** 2
128-
shared_mem_size_fast = (k * k * 32) * cp.dtype("float32").itemsize
129-
props = cp.cuda.runtime.getDeviceProperties(0)
130-
# if the shared memory is enough, use the fast kernel
131-
if props["sharedMemPerBlock"] > shared_mem_size_fast and fast:
132-
counts = cp.zeros((l_val, k, k), dtype=cp.int32)
133-
grid = (int(math.ceil(n / 32)), l_val)
134-
block = (32, 1)
135-
occur_count_kernel_fast(
136-
grid,
137-
block,
138-
(spatial, thresholds, labs, counts, n, k, l_val),
139-
shared_mem=shared_mem_size_fast,
140-
)
141-
reader = 1
142-
# if the shared memory is not enough, use the slow kernel
143-
else:
148+
use_fast_kernel = False # Flag to track which kernel path was taken
149+
if fast:
150+
# Optimize occupancy vs speed
151+
can_use_fast_kernel = calculate_optimal_k(0.4) > k
152+
# If shared memory is sufficient, use the fast kernel
153+
if can_use_fast_kernel:
154+
shared_mem_size_fast = (k * 128) * cp.dtype("float32").itemsize
155+
counts = cp.zeros((l_val, k, k), dtype=cp.int32)
156+
grid = (n, l_val)
157+
block = (128, 1)
158+
occur_count_kernel_pairwise_fast(
159+
grid,
160+
block,
161+
(spatial, thresholds, labs, counts, n, k, l_val),
162+
shared_mem=shared_mem_size_fast,
163+
)
164+
reader = 1
165+
use_fast_kernel = True
166+
167+
# Fallback to the standard kernel if fast=False or shared memory was insufficient
168+
if not use_fast_kernel:
144169
counts = cp.zeros((k, k, l_val * 2), dtype=cp.int32)
145170
grid = (n,)
146171
block = (32,)
147-
occur_count_kernel(
172+
occur_count_kernel_pairwise(
148173
grid, block, (spatial, thresholds, labs, counts, n, k, l_val)
149174
)
150175
reader = 0
151176

152177
occ_prob = cp.empty((k, k, l_val), dtype=np.float32)
153178
shared_mem_size = (k * k + k) * cp.dtype("float32").itemsize
154-
grid = (l_val,)
155-
block = (32,)
156-
occur_count_kernel2(
157-
grid, block, (counts, occ_prob, k, l_val, reader), shared_mem=shared_mem_size
158-
)
179+
props = cp.cuda.runtime.getDeviceProperties(0)
180+
if fast and shared_mem_size < props["sharedMemPerBlock"]:
181+
grid2 = (l_val,)
182+
block2 = (32,)
183+
occur_reduction_kernel_shared(
184+
grid2,
185+
block2,
186+
(counts, occ_prob, k, l_val, reader),
187+
shared_mem=shared_mem_size,
188+
)
189+
else:
190+
shared_mem_size = (k) * cp.dtype("float32").itemsize
191+
grid2 = (l_val,)
192+
block2 = (32,)
193+
inter_out = cp.zeros((l_val, k, k), dtype=np.float32)
194+
occur_reduction_kernel_global(
195+
grid2,
196+
block2,
197+
(counts, inter_out, occ_prob, k, l_val, reader),
198+
shared_mem=shared_mem_size,
199+
)
200+
159201
return occ_prob

src/rapids_singlecell/squidpy_gpu/kernels/_co_oc.py

Lines changed: 150 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import cupy as cp
44

5-
kernel_code = r"""
5+
kernel_code_pairwise = r"""
66
extern "C" __global__
7-
void occur_count_kernel(const float* __restrict__ spatial,
7+
void occur_count_kernel_pairwise(const float* __restrict__ spatial,
88
const float* __restrict__ thresholds,
99
const int* __restrict__ label_idx,
1010
int* __restrict__ result,
@@ -54,11 +54,13 @@
5454
}
5555
}
5656
"""
57-
occur_count_kernel = cp.RawKernel(kernel_code, "occur_count_kernel")
57+
occur_count_kernel_pairwise = cp.RawKernel(
58+
kernel_code_pairwise, "occur_count_kernel_pairwise"
59+
)
5860

59-
kernel_code_fast = r"""
61+
kernel_code_pairwise_fast = r"""
6062
extern "C" __global__
61-
void occur_count_kernel_fast(const float* __restrict__ spatial,
63+
void occur_count_kernel_pairwise_fast(const float* __restrict__ spatial,
6264
const float* __restrict__ thresholds,
6365
const int* __restrict__ label_idx,
6466
int* __restrict__ result,
@@ -68,52 +70,52 @@
6870
{
6971
extern __shared__ float shared[];
7072
float* Y = shared;
71-
int i = blockIdx.x * blockDim.x + threadIdx.x;
73+
int i = blockIdx.x;
7274
int r = blockIdx.y;
73-
for (int j = 0; j < k * k ; j++){
74-
Y[k * k * threadIdx.x + j] = 0;
75+
for (int j = threadIdx.x; j < k * blockDim.x ; j+= blockDim.x){
76+
Y[j] = 0;
7577
}
7678
__syncthreads();
77-
for(int j = i+1; j< n; j++){
78-
float spx = spatial[i * 2];
79-
float spy = spatial[i * 2 + 1];
79+
80+
float spx = spatial[i * 2];
81+
int low = label_idx[i];
82+
float spy = spatial[i * 2 + 1];
83+
for(int j = i+1+ threadIdx.x; j< n; j+= blockDim.x){
8084
float dx = spx - spatial[j * 2];
8185
float dy = spy - spatial[j * 2 + 1];
8286
float dist_sq = dx * dx + dy * dy;
83-
int low = label_idx[i];
8487
int high = label_idx[j];
85-
if (high < low) {
86-
int tmp = low;
87-
low = high;
88-
high = tmp;
89-
}
90-
9188
if (dist_sq <= thresholds[r]) {
92-
int index = k * k * threadIdx.x + low * k + high;
89+
int index = k * threadIdx.x + high;
9390
Y[index] += 1;
9491
}
9592
}
96-
9793
__syncthreads();
98-
for (int bin = threadIdx.x; bin < k * k; bin += blockDim.x) {
99-
int blockSum = 0;
100-
for (int t = 0; t < blockDim.x; t++) {
101-
blockSum += Y[t * (k * k) + bin];
102-
}
103-
if (blockSum>0) atomicAdd(&result[r*(k*k)+bin], blockSum);
10494
95+
for (int j = threadIdx.x; j < k; j+= blockDim.x){
96+
float sum = 0;
97+
for (int t = 0; t < blockDim.x; t++){
98+
int index = k * t + j;
99+
sum += Y[index];
100+
}
101+
if (low < j){
102+
if (sum>0) atomicAdd(&result[r*(k*k)+low*k+j], sum);
103+
}
104+
else{
105+
if (sum>0) atomicAdd(&result[r*(k*k)+j*k+low], sum);
106+
}
105107
}
106108
__syncthreads();
107109
}
108110
"""
109-
110-
# Compile the kernel.
111-
occur_count_kernel_fast = cp.RawKernel(kernel_code_fast, "occur_count_kernel_fast")
111+
occur_count_kernel_pairwise_fast = cp.RawKernel(
112+
kernel_code_pairwise_fast, "occur_count_kernel_pairwise_fast"
113+
)
112114

113115

114-
kernel_code2 = r"""
116+
occur_reduction_kernel_code_shared = r"""
115117
extern "C" __global__
116-
void occur_reduction_kernel(const int* __restrict__ result,
118+
void occur_reduction_kernel_shared(const int* __restrict__ result,
117119
float* __restrict__ out,
118120
int k,
119121
int l_val,
@@ -221,6 +223,122 @@
221223
out[i * (k * l_val) + j * l_val + r_th] = final_val;
222224
}
223225
}
226+
__syncthreads();
227+
}
228+
"""
229+
occur_reduction_kernel_shared = cp.RawKernel(
230+
occur_reduction_kernel_code_shared, "occur_reduction_kernel_shared"
231+
)
232+
233+
occur_reduction_kernel_code_global = r"""
234+
extern "C" __global__
235+
void occur_reduction_kernel_global(const int* __restrict__ result,
236+
float* __restrict__ inter_out,
237+
float* __restrict__ out,
238+
int k,
239+
int l_val,
240+
int format)
241+
{
242+
// Each block handles one threshold index.
243+
int r_th = blockIdx.x; // threshold index
244+
if (r_th >= l_val)
245+
return;
246+
// Shared memory allocation
247+
extern __shared__ float shared[];
248+
float* Y = inter_out + r_th*k*k;
249+
float* col_sum = shared;
250+
251+
int total_elements = k * k;
252+
253+
// --- Load counts for this threshold and convert to float---
254+
if (format == 0){
255+
for (int i = threadIdx.x; i < k; i += blockDim.x){
256+
for (int j = 0; j<k;j++){
257+
Y[i * k + j] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th]);
258+
Y[j * k + i] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th]);
259+
Y[i * k + j] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th+l_val]);
260+
Y[j * k + i] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th+l_val]);
261+
}
262+
}
263+
}
264+
else{
265+
for (int i = threadIdx.x; i < k; i += blockDim.x){
266+
for (int j = 0; j<k;j++){
267+
Y[i * k + j] += float(result[r_th * (k * k) + i * k + j]);
268+
Y[j * k + i] += float(result[r_th * (k * k) + i * k + j]);
269+
}
270+
}
271+
}
272+
__syncthreads();
273+
274+
// Compute total sum of the counts
275+
__shared__ float total;
276+
float sum_val = 0.0f;
277+
for (int idx = threadIdx.x; idx < total_elements; idx += blockDim.x) {
278+
sum_val += Y[idx];
279+
}
280+
__syncthreads();
281+
// Warp-level reduction
282+
unsigned int mask = 0xFFFFFFFF; // full warp mask
283+
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
284+
sum_val += __shfl_down_sync(mask, sum_val, offset);
285+
}
286+
__syncthreads();
287+
if (threadIdx.x == 0) {
288+
total = sum_val;
289+
}
290+
__syncthreads();
291+
292+
// Normalize the matrix Y = Y / total (if total > 0)
293+
if (total > 0.0f) {
294+
for (int idx = threadIdx.x; idx < total_elements; idx += blockDim.x) {
295+
Y[idx] = Y[idx] / total;
296+
}
297+
} else {
298+
for (int i = threadIdx.x; i < k; i += blockDim.x) {
299+
for (int j = 0; j < k; j++) {
300+
out[i * (k * l_val) + j * l_val + r_th] = 0.0f;
301+
}
302+
}
303+
return;
304+
}
305+
__syncthreads();
306+
307+
// Compute column sums of the normalized matrix
308+
for (int j = threadIdx.x; j < k; j += blockDim.x) {
309+
float sum_col = 0.0f;
310+
for (int i = 0; i < k; i++) {
311+
sum_col += Y[i * k + j];
312+
}
313+
col_sum[j] = sum_col;
314+
}
315+
__syncthreads();
316+
317+
// Compute conditional probabilities
318+
for (int i = threadIdx.x; i < k; i += blockDim.x) {
319+
float row_sum = 0.0f;
320+
for (int j = 0; j < k; j++) {
321+
row_sum += Y[i * k + j];
322+
}
323+
324+
for (int j = 0; j < k; j++) {
325+
float cond = 0.0f;
326+
if (row_sum != 0.0f) {
327+
cond = Y[i * k + j] / row_sum;
328+
}
329+
330+
float final_val = 0.0f;
331+
if (col_sum[j] != 0.0f) {
332+
final_val = cond / col_sum[j];
333+
}
334+
335+
// Write to output with (row, column, threshold) ordering
336+
out[i * (k * l_val) + j * l_val + r_th] = final_val;
337+
}
338+
}
339+
__syncthreads();
224340
}
225341
"""
226-
occur_count_kernel2 = cp.RawKernel(kernel_code2, "occur_reduction_kernel")
342+
occur_reduction_kernel_global = cp.RawKernel(
343+
occur_reduction_kernel_code_global, "occur_reduction_kernel_global"
344+
)

0 commit comments

Comments
 (0)