Skip to content

Commit 980e0e3

Browse files
authored
Harmony dot (#379)
* add dot update * remove print * add release note
1 parent ce93fda commit 980e0e3

4 files changed

Lines changed: 17 additions & 18 deletions

File tree

docs/release-notes/0.12.3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
```{rubric} Bug fixes
44
```
5-
* Fixed `harmony_integrate` breakage caused by undocumented changes to fused kernel float handling in` CuPy 13.4.1` {pr}`351` {smaller}`S Dicks`
5+
* Fixed `harmony_integrate` breakage caused by undocumented changes to fused kernel float handling in `CuPy 13.4.1` {pr}`351` {smaller}`S Dicks`

docs/release-notes/0.12.7.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
```{rubric} Performance
77
```
8+
* Speed up `pp.harmony_integrate` even more {pr}`379` {smaller}`S Dicks`
89

910
```{rubric} Bug fixes
1011
```

src/rapids_singlecell/preprocessing/_harmony/__init__.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def harmonize(
5555
use_gemm: bool = False,
5656
colsum_algo: COLSUM_ALGO | None = None,
5757
random_state: int = 0,
58+
verbose: bool = False,
5859
) -> cp.array:
5960
"""
6061
Integrate data using Harmony algorithm.
@@ -112,6 +113,9 @@ def harmonize(
112113
random_state
113114
Random seed for reproducing results.
114115
116+
verbose
117+
Whether to print benchmarking results for the column sum algorithm and the number of iterations until convergence.
118+
115119
Returns
116120
-------
117121
The integrated embedding by Harmony, of the same shape as the input embedding.
@@ -145,7 +149,7 @@ def harmonize(
145149
colsum_func_big = _choose_colsum_algo_heuristic(n_cells, n_clusters, None)
146150
if colsum_algo == "benchmark":
147151
colsum_func_small = _choose_colsum_algo_benchmark(
148-
int(n_cells * block_proportion), n_clusters, Z.dtype
152+
int(n_cells * block_proportion), n_clusters, Z.dtype, verbose=verbose
149153
)
150154
else:
151155
colsum_func_small = _choose_colsum_algo_heuristic(
@@ -180,6 +184,7 @@ def harmonize(
180184

181185
# Main harmony iterations
182186
is_converged = False
187+
183188
for i in range(max_iter_harmony):
184189
# Clustering step
185190
_clustering(
@@ -198,7 +203,6 @@ def harmonize(
198203
block_proportion=block_proportion,
199204
colsum_func=colsum_func_small,
200205
)
201-
202206
# Correction step
203207
Z_hat = _correction(
204208
Z,
@@ -212,13 +216,13 @@ def harmonize(
212216
cat_offsets=cat_offsets,
213217
cell_indices=cell_indices,
214218
)
215-
216219
# Normalize corrected data
217220
Z_norm = _normalize_cp(Z_hat, p=2)
218221
# Check for convergence
219222
if _is_convergent_harmony(objectives_harmony, tol=tol_harmony):
220223
is_converged = True
221-
print(f"Harmony converged in {i} iterations")
224+
if verbose:
225+
print(f"Harmony converged in {i + 1} iterations")
222226
break
223227

224228
if not is_converged:
@@ -479,7 +483,6 @@ def _correction_original(
479483
id_mat = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)
480484
id_mat[0, 0] = 0
481485
Lambda = ridge_lambda * id_mat
482-
483486
for k in range(n_clusters):
484487
if Phi is not None:
485488
Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
@@ -504,12 +507,10 @@ def _correction_original(
504507
)
505508
W = cp.dot(inv_mat, Phi_t_diag_R_X)
506509
W[0, :] = 0
507-
508510
if Phi is not None:
509-
Z -= cp.dot(Phi_t_diag_R.T, W)
511+
cp.cublas.gemm("T", "N", Phi_t_diag_R, W, alpha=-1, beta=1, out=Z)
510512
else:
511513
_Z_correction(Z, W, cats, R_col)
512-
513514
return Z
514515

515516

@@ -537,7 +538,6 @@ def _correction_fast(
537538

538539
Z = X.copy()
539540
P = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)
540-
541541
for k in range(n_clusters):
542542
O_k = O[:, k]
543543
N_k = cp.sum(O_k)
@@ -557,7 +557,6 @@ def _correction_fast(
557557
# Set off-diagonal entries
558558
P_t_B_inv[1:, 0] = P[0, 1:] * c_inv
559559
inv_mat = cp.dot(P_t_B_inv, P)
560-
561560
if Phi is not None:
562561
Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
563562
Phi_t_diag_R_X = cp.dot(Phi_t_diag_R, X)
@@ -577,10 +576,9 @@ def _correction_fast(
577576
W[0, :] = 0
578577

579578
if Phi is not None:
580-
Z -= cp.dot(Phi_t_diag_R.T, W)
579+
cp.cublas.gemm("T", "N", Phi_t_diag_R, W, alpha=-1, beta=1, out=Z)
581580
else:
582581
_Z_correction(Z, W, cats, R_col)
583-
584582
return Z
585583

586584

src/rapids_singlecell/preprocessing/_harmony/_helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
_get_aggregated_matrix_kernel,
1818
_get_scatter_add_kernel_optimized,
1919
_get_scatter_add_kernel_with_bias_block,
20-
_get_scatter_add_kernel_with_bias_cat0,
2120
)
2221

2322
if TYPE_CHECKING:
@@ -196,14 +195,15 @@ def _scatter_add_cp_bias_csr(
196195
n_pcs = X.shape[1]
197196

198197
threads_per_block = 1024
199-
blocks = int((n_pcs + 1) / 2)
200-
scatter_kernel0 = _get_scatter_add_kernel_with_bias_cat0(X.dtype)
201-
scatter_kernel0((blocks, 8), (threads_per_block,), (X, n_cells, n_pcs, out, bias))
198+
# blocks = int((n_pcs + 1) / 2)
199+
# scatter_kernel0 = _get_scatter_add_kernel_with_bias_cat0(X.dtype)
200+
# scatter_kernel0((blocks, 8), (threads_per_block,), (X, n_cells, n_pcs, out, bias))
201+
out[0] = X.T @ bias
202202
blocks = int((n_batches) * (n_pcs + 1) / 2)
203203
scatter_kernel = _get_scatter_add_kernel_with_bias_block(X.dtype)
204204
scatter_kernel(
205205
(blocks,),
206-
(1024,),
206+
(threads_per_block,),
207207
(X, cat_offsets, cell_indices, n_cells, n_pcs, n_batches, out, bias),
208208
)
209209

0 commit comments

Comments
 (0)