|
2 | 2 |
|
3 | 3 | import cupy as cp |
4 | 4 |
|
5 | | -kernel_code = r""" |
| 5 | +kernel_code_pairwise = r""" |
6 | 6 | extern "C" __global__ |
7 | | -void occur_count_kernel(const float* __restrict__ spatial, |
| 7 | +void occur_count_kernel_pairwise(const float* __restrict__ spatial, |
8 | 8 | const float* __restrict__ thresholds, |
9 | 9 | const int* __restrict__ label_idx, |
10 | 10 | int* __restrict__ result, |
|
54 | 54 | } |
55 | 55 | } |
56 | 56 | """ |
57 | | -occur_count_kernel = cp.RawKernel(kernel_code, "occur_count_kernel") |
| 57 | +occur_count_kernel_pairwise = cp.RawKernel( |
| 58 | + kernel_code_pairwise, "occur_count_kernel_pairwise" |
| 59 | +) |
58 | 60 |
|
59 | | -kernel_code_fast = r""" |
| 61 | +kernel_code_pairwise_fast = r""" |
60 | 62 | extern "C" __global__ |
61 | | -void occur_count_kernel_fast(const float* __restrict__ spatial, |
| 63 | +void occur_count_kernel_pairwise_fast(const float* __restrict__ spatial, |
62 | 64 | const float* __restrict__ thresholds, |
63 | 65 | const int* __restrict__ label_idx, |
64 | 66 | int* __restrict__ result, |
|
68 | 70 | { |
69 | 71 | extern __shared__ float shared[]; |
70 | 72 | float* Y = shared; |
71 | | - int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 73 | + int i = blockIdx.x; |
72 | 74 | int r = blockIdx.y; |
73 | | - for (int j = 0; j < k * k ; j++){ |
74 | | - Y[k * k * threadIdx.x + j] = 0; |
| 75 | + for (int j = threadIdx.x; j < k * blockDim.x ; j+= blockDim.x){ |
| 76 | + Y[j] = 0; |
75 | 77 | } |
76 | 78 | __syncthreads(); |
77 | | - for(int j = i+1; j< n; j++){ |
78 | | - float spx = spatial[i * 2]; |
79 | | - float spy = spatial[i * 2 + 1]; |
| 79 | +
|
| 80 | + float spx = spatial[i * 2]; |
| 81 | + int low = label_idx[i]; |
| 82 | + float spy = spatial[i * 2 + 1]; |
| 83 | + for(int j = i+1+ threadIdx.x; j< n; j+= blockDim.x){ |
80 | 84 | float dx = spx - spatial[j * 2]; |
81 | 85 | float dy = spy - spatial[j * 2 + 1]; |
82 | 86 | float dist_sq = dx * dx + dy * dy; |
83 | | - int low = label_idx[i]; |
84 | 87 | int high = label_idx[j]; |
85 | | - if (high < low) { |
86 | | - int tmp = low; |
87 | | - low = high; |
88 | | - high = tmp; |
89 | | - } |
90 | | -
|
91 | 88 | if (dist_sq <= thresholds[r]) { |
92 | | - int index = k * k * threadIdx.x + low * k + high; |
| 89 | + int index = k * threadIdx.x + high; |
93 | 90 | Y[index] += 1; |
94 | 91 | } |
95 | 92 | } |
96 | | -
|
97 | 93 | __syncthreads(); |
98 | | - for (int bin = threadIdx.x; bin < k * k; bin += blockDim.x) { |
99 | | - int blockSum = 0; |
100 | | - for (int t = 0; t < blockDim.x; t++) { |
101 | | - blockSum += Y[t * (k * k) + bin]; |
102 | | - } |
103 | | - if (blockSum>0) atomicAdd(&result[r*(k*k)+bin], blockSum); |
104 | 94 |
|
| 95 | + for (int j = threadIdx.x; j < k; j+= blockDim.x){ |
| 96 | + float sum = 0; |
| 97 | + for (int t = 0; t < blockDim.x; t++){ |
| 98 | + int index = k * t + j; |
| 99 | + sum += Y[index]; |
| 100 | + } |
| 101 | + if (low < j){ |
| 102 | + if (sum>0) atomicAdd(&result[r*(k*k)+low*k+j], sum); |
| 103 | + } |
| 104 | + else{ |
| 105 | + if (sum>0) atomicAdd(&result[r*(k*k)+j*k+low], sum); |
| 106 | + } |
105 | 107 | } |
106 | 108 | __syncthreads(); |
107 | 109 | } |
108 | 110 | """ |
109 | | - |
110 | | -# Compile the kernel. |
111 | | -occur_count_kernel_fast = cp.RawKernel(kernel_code_fast, "occur_count_kernel_fast") |
| 111 | +occur_count_kernel_pairwise_fast = cp.RawKernel( |
| 112 | + kernel_code_pairwise_fast, "occur_count_kernel_pairwise_fast" |
| 113 | +) |
112 | 114 |
|
113 | 115 |
|
114 | | -kernel_code2 = r""" |
| 116 | +occur_reduction_kernel_code_shared = r""" |
115 | 117 | extern "C" __global__ |
116 | | -void occur_reduction_kernel(const int* __restrict__ result, |
| 118 | +void occur_reduction_kernel_shared(const int* __restrict__ result, |
117 | 119 | float* __restrict__ out, |
118 | 120 | int k, |
119 | 121 | int l_val, |
|
221 | 223 | out[i * (k * l_val) + j * l_val + r_th] = final_val; |
222 | 224 | } |
223 | 225 | } |
| 226 | + __syncthreads(); |
| 227 | +} |
| 228 | +""" |
| 229 | +occur_reduction_kernel_shared = cp.RawKernel( |
| 230 | + occur_reduction_kernel_code_shared, "occur_reduction_kernel_shared" |
| 231 | +) |
| 232 | + |
| 233 | +occur_reduction_kernel_code_global = r""" |
| 234 | +extern "C" __global__ |
| 235 | +void occur_reduction_kernel_global(const int* __restrict__ result, |
| 236 | + float* __restrict__ inter_out, |
| 237 | + float* __restrict__ out, |
| 238 | + int k, |
| 239 | + int l_val, |
| 240 | + int format) |
| 241 | +{ |
| 242 | + // Each block handles one threshold index. |
| 243 | + int r_th = blockIdx.x; // threshold index |
| 244 | + if (r_th >= l_val) |
| 245 | + return; |
| 246 | + // Shared memory allocation |
| 247 | + extern __shared__ float shared[]; |
| 248 | + float* Y = inter_out + r_th*k*k; |
| 249 | + float* col_sum = shared; |
| 250 | +
|
| 251 | + int total_elements = k * k; |
| 252 | +
|
| 253 | + // --- Load counts for this threshold and convert to float--- |
| 254 | + if (format == 0){ |
| 255 | + for (int i = threadIdx.x; i < k; i += blockDim.x){ |
| 256 | + for (int j = 0; j<k;j++){ |
| 257 | + Y[i * k + j] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th]); |
| 258 | + Y[j * k + i] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th]); |
| 259 | + Y[i * k + j] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th+l_val]); |
| 260 | + Y[j * k + i] += float(result[i * (k * l_val*2) + j * l_val*2 + r_th+l_val]); |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + else{ |
| 265 | + for (int i = threadIdx.x; i < k; i += blockDim.x){ |
| 266 | + for (int j = 0; j<k;j++){ |
| 267 | + Y[i * k + j] += float(result[r_th * (k * k) + i * k + j]); |
| 268 | + Y[j * k + i] += float(result[r_th * (k * k) + i * k + j]); |
| 269 | + } |
| 270 | + } |
| 271 | + } |
| 272 | + __syncthreads(); |
| 273 | +
|
| 274 | + // Compute total sum of the counts |
| 275 | + __shared__ float total; |
| 276 | + float sum_val = 0.0f; |
| 277 | + for (int idx = threadIdx.x; idx < total_elements; idx += blockDim.x) { |
| 278 | + sum_val += Y[idx]; |
| 279 | + } |
| 280 | + __syncthreads(); |
| 281 | + // Warp-level reduction |
| 282 | + unsigned int mask = 0xFFFFFFFF; // full warp mask |
| 283 | + for (int offset = warpSize / 2; offset > 0; offset /= 2) { |
| 284 | + sum_val += __shfl_down_sync(mask, sum_val, offset); |
| 285 | + } |
| 286 | + __syncthreads(); |
| 287 | + if (threadIdx.x == 0) { |
| 288 | + total = sum_val; |
| 289 | + } |
| 290 | + __syncthreads(); |
| 291 | +
|
| 292 | + // Normalize the matrix Y = Y / total (if total > 0) |
| 293 | + if (total > 0.0f) { |
| 294 | + for (int idx = threadIdx.x; idx < total_elements; idx += blockDim.x) { |
| 295 | + Y[idx] = Y[idx] / total; |
| 296 | + } |
| 297 | + } else { |
| 298 | + for (int i = threadIdx.x; i < k; i += blockDim.x) { |
| 299 | + for (int j = 0; j < k; j++) { |
| 300 | + out[i * (k * l_val) + j * l_val + r_th] = 0.0f; |
| 301 | + } |
| 302 | + } |
| 303 | + return; |
| 304 | + } |
| 305 | + __syncthreads(); |
| 306 | +
|
| 307 | + // Compute column sums of the normalized matrix |
| 308 | + for (int j = threadIdx.x; j < k; j += blockDim.x) { |
| 309 | + float sum_col = 0.0f; |
| 310 | + for (int i = 0; i < k; i++) { |
| 311 | + sum_col += Y[i * k + j]; |
| 312 | + } |
| 313 | + col_sum[j] = sum_col; |
| 314 | + } |
| 315 | + __syncthreads(); |
| 316 | +
|
| 317 | + // Compute conditional probabilities |
| 318 | + for (int i = threadIdx.x; i < k; i += blockDim.x) { |
| 319 | + float row_sum = 0.0f; |
| 320 | + for (int j = 0; j < k; j++) { |
| 321 | + row_sum += Y[i * k + j]; |
| 322 | + } |
| 323 | +
|
| 324 | + for (int j = 0; j < k; j++) { |
| 325 | + float cond = 0.0f; |
| 326 | + if (row_sum != 0.0f) { |
| 327 | + cond = Y[i * k + j] / row_sum; |
| 328 | + } |
| 329 | +
|
| 330 | + float final_val = 0.0f; |
| 331 | + if (col_sum[j] != 0.0f) { |
| 332 | + final_val = cond / col_sum[j]; |
| 333 | + } |
| 334 | +
|
| 335 | + // Write to output with (row, column, threshold) ordering |
| 336 | + out[i * (k * l_val) + j * l_val + r_th] = final_val; |
| 337 | + } |
| 338 | + } |
| 339 | + __syncthreads(); |
224 | 340 | } |
225 | 341 | """ |
226 | | -occur_count_kernel2 = cp.RawKernel(kernel_code2, "occur_reduction_kernel") |
| 342 | +occur_reduction_kernel_global = cp.RawKernel( |
| 343 | + occur_reduction_kernel_code_global, "occur_reduction_kernel_global" |
| 344 | +) |
0 commit comments