@@ -9,22 +9,22 @@ extern "C" {
99 * Each block processes one group pair, threads collaborate within the block
1010 */
1111__global__ void compute_group_distances (
12- const double * __restrict__ embedding,
12+ const float * __restrict__ embedding,
1313 const int * __restrict__ cat_offsets,
1414 const int * __restrict__ cell_indices,
1515 const int * __restrict__ pair_left,
1616 const int * __restrict__ pair_right,
17- double * __restrict__ d_other,
17+ float * __restrict__ d_other,
1818 int k,
1919 int n_features)
2020{
21- extern __shared__ double shared_sums[];
21+ extern __shared__ float shared_sums[];
2222
2323 const int thread_id = threadIdx .x ;
2424 const int block_id = blockIdx .x ;
2525 const int block_size = blockDim .x ;
2626
27- double local_sum = 0.0 ;
27+ float local_sum = 0 .0f ;
2828
2929 const int a = pair_left[block_id];
3030 const int b = pair_right[block_id];
@@ -46,14 +46,14 @@ __global__ void compute_group_distances(
4646 for (int jb = start_b; jb < end_b; ++jb) {
4747 const int idx_j = cell_indices[jb];
4848
49- double dist_sq = 0.0 ;
49+ float dist_sq = 0 .0f ;
5050 #pragma unroll
5151 for (int feat = 0 ; feat < n_features; ++feat) {
52- double diff = embedding[idx_i * n_features + feat] -
52+ float diff = embedding[idx_i * n_features + feat] -
5353 embedding[idx_j * n_features + feat];
5454 dist_sq += diff * diff;
5555 }
56- local_sum += sqrt (dist_sq);
56+ local_sum += sqrtf (dist_sq);
5757 }
5858 }
5959
@@ -70,7 +70,7 @@ __global__ void compute_group_distances(
7070
7171 if (thread_id == 0 ) {
7272 // Store mean between-group distance
73- d_other[block_id] = shared_sums[0 ] / (double )(n_a * n_b);
73+ d_other[block_id] = shared_sums[0 ] / (float )(n_a * n_b);
7474 }
7575}
7676
0 commit comments