Skip to content

Commit 6b7d7b2

Browse files
committed
swich to f32
1 parent fb10db6 commit 6b7d7b2

2 files changed

Lines changed: 12 additions & 13 deletions

File tree

src/rapids_singlecell/pertpy_gpu/_distances_standalone.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def compute_pairwise_means_gpu(
7676
num_pairs = len(pair_left) # k * (k-1) pairs instead of k²
7777

7878
# Allocate output for off-diagonal distances only
79-
d_other_offdiag = cp.zeros(num_pairs, dtype=np.float64)
79+
d_other_offdiag = cp.zeros(num_pairs, dtype=np.float32)
8080

8181
# Choose optimal block size
8282
props = cp.cuda.runtime.getDeviceProperties(0)
@@ -85,7 +85,7 @@ def compute_pairwise_means_gpu(
8585
chosen_threads = None
8686
shared_mem_size = 0 # TODO: think of a better way to do this
8787
for tpb in (1024, 512, 256, 128, 64, 32):
88-
required = tpb * cp.dtype(cp.float64).itemsize
88+
required = tpb * cp.dtype(cp.float32).itemsize
8989
if required <= max_smem:
9090
chosen_threads = tpb
9191
shared_mem_size = required
@@ -111,7 +111,7 @@ def compute_pairwise_means_gpu(
111111
)
112112

113113
# Build full k x k matrix
114-
pairwise_means = cp.zeros((k, k), dtype=np.float64)
114+
pairwise_means = cp.zeros((k, k), dtype=np.float32)
115115

116116
# Fill the full matrix
117117
for i, idx in enumerate(pair_indices.get()):
@@ -322,10 +322,9 @@ def pairwise_edistance_gpu(
322322
df : pd.DataFrame
323323
Final edistance matrix
324324
"""
325-
# 1. Prepare data (same as original)
326325
_assert_categorical_obs(adata, key=groupby)
327326

328-
embedding = cp.array(adata.obsm[obsm_key]).astype(np.float64)
327+
embedding = cp.array(adata.obsm[obsm_key]).astype(np.float32) # Changed from float64
329328
original_groups = adata.obs[groupby]
330329
group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}
331330
group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)

src/rapids_singlecell/pertpy_gpu/kernels/edistance_kernels.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,22 @@ extern "C" {
99
* Each block processes one group pair, threads collaborate within the block
1010
*/
1111
__global__ void compute_group_distances(
12-
const double* __restrict__ embedding,
12+
const float* __restrict__ embedding,
1313
const int* __restrict__ cat_offsets,
1414
const int* __restrict__ cell_indices,
1515
const int* __restrict__ pair_left,
1616
const int* __restrict__ pair_right,
17-
double* __restrict__ d_other,
17+
float* __restrict__ d_other,
1818
int k,
1919
int n_features)
2020
{
21-
extern __shared__ double shared_sums[];
21+
extern __shared__ float shared_sums[];
2222

2323
const int thread_id = threadIdx.x;
2424
const int block_id = blockIdx.x;
2525
const int block_size = blockDim.x;
2626

27-
double local_sum = 0.0;
27+
float local_sum = 0.0f;
2828

2929
const int a = pair_left[block_id];
3030
const int b = pair_right[block_id];
@@ -46,14 +46,14 @@ __global__ void compute_group_distances(
4646
for (int jb = start_b; jb < end_b; ++jb) {
4747
const int idx_j = cell_indices[jb];
4848

49-
double dist_sq = 0.0;
49+
float dist_sq = 0.0f;
5050
#pragma unroll
5151
for (int feat = 0; feat < n_features; ++feat) {
52-
double diff = embedding[idx_i * n_features + feat] -
52+
float diff = embedding[idx_i * n_features + feat] -
5353
embedding[idx_j * n_features + feat];
5454
dist_sq += diff * diff;
5555
}
56-
local_sum += sqrt(dist_sq);
56+
local_sum += sqrtf(dist_sq);
5757
}
5858
}
5959

@@ -70,7 +70,7 @@ __global__ void compute_group_distances(
7070

7171
if (thread_id == 0) {
7272
// Store mean between-group distance
73-
d_other[block_id] = shared_sums[0] / (double)(n_a * n_b);
73+
d_other[block_id] = shared_sums[0] / (float)(n_a * n_b);
7474
}
7575
}
7676

0 commit comments

Comments
 (0)