Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/release-notes/0.12.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

```{rubric} Bug fixes
```
* Fixed `harmony_integrate` breakage caused by undocumented changes to fused kernel float handling in` CuPy 13.4.1` {pr}`351` {smaller}`S Dicks`
* Fixed `harmony_integrate` breakage caused by undocumented changes to fused kernel float handling in `CuPy 13.4.1` {pr}`351` {smaller}`S Dicks`
1 change: 1 addition & 0 deletions docs/release-notes/0.12.7.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

```{rubric} Performance
```
* Speed up `pp.harmony_integrate` even more {pr}`379` {smaller}`S Dicks`

```{rubric} Bug fixes
```
Expand Down
22 changes: 10 additions & 12 deletions src/rapids_singlecell/preprocessing/_harmony/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def harmonize(
use_gemm: bool = False,
colsum_algo: COLSUM_ALGO | None = None,
random_state: int = 0,
verbose: bool = False,
) -> cp.array:
"""
Integrate data using Harmony algorithm.
Expand Down Expand Up @@ -112,6 +113,9 @@ def harmonize(
random_state
Random seed for reproducing results.

verbose
Whether to print benchmarking results for the column sum algorithm and the number of iterations until convergence.

Returns
-------
The integrated embedding by Harmony, of the same shape as the input embedding.
Expand Down Expand Up @@ -145,7 +149,7 @@ def harmonize(
colsum_func_big = _choose_colsum_algo_heuristic(n_cells, n_clusters, None)
if colsum_algo == "benchmark":
colsum_func_small = _choose_colsum_algo_benchmark(
int(n_cells * block_proportion), n_clusters, Z.dtype
int(n_cells * block_proportion), n_clusters, Z.dtype, verbose=verbose
)
else:
colsum_func_small = _choose_colsum_algo_heuristic(
Expand Down Expand Up @@ -180,6 +184,7 @@ def harmonize(

# Main harmony iterations
is_converged = False

for i in range(max_iter_harmony):
# Clustering step
_clustering(
Expand All @@ -198,7 +203,6 @@ def harmonize(
block_proportion=block_proportion,
colsum_func=colsum_func_small,
)

# Correction step
Z_hat = _correction(
Z,
Expand All @@ -212,13 +216,13 @@ def harmonize(
cat_offsets=cat_offsets,
cell_indices=cell_indices,
)

# Normalize corrected data
Z_norm = _normalize_cp(Z_hat, p=2)
# Check for convergence
if _is_convergent_harmony(objectives_harmony, tol=tol_harmony):
is_converged = True
print(f"Harmony converged in {i} iterations")
if verbose:
print(f"Harmony converged in {i + 1} iterations")
break

if not is_converged:
Expand Down Expand Up @@ -479,7 +483,6 @@ def _correction_original(
id_mat = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)
id_mat[0, 0] = 0
Lambda = ridge_lambda * id_mat

for k in range(n_clusters):
if Phi is not None:
Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
Expand All @@ -504,12 +507,10 @@ def _correction_original(
)
W = cp.dot(inv_mat, Phi_t_diag_R_X)
W[0, :] = 0

if Phi is not None:
Z -= cp.dot(Phi_t_diag_R.T, W)
cp.cublas.gemm("T", "N", Phi_t_diag_R, W, alpha=-1, beta=1, out=Z)
else:
_Z_correction(Z, W, cats, R_col)

return Z


Expand Down Expand Up @@ -537,7 +538,6 @@ def _correction_fast(

Z = X.copy()
P = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)

for k in range(n_clusters):
O_k = O[:, k]
N_k = cp.sum(O_k)
Expand All @@ -557,7 +557,6 @@ def _correction_fast(
# Set off-diagonal entries
P_t_B_inv[1:, 0] = P[0, 1:] * c_inv
inv_mat = cp.dot(P_t_B_inv, P)

if Phi is not None:
Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
Phi_t_diag_R_X = cp.dot(Phi_t_diag_R, X)
Expand All @@ -577,10 +576,9 @@ def _correction_fast(
W[0, :] = 0

if Phi is not None:
Z -= cp.dot(Phi_t_diag_R.T, W)
cp.cublas.gemm("T", "N", Phi_t_diag_R, W, alpha=-1, beta=1, out=Z)
else:
_Z_correction(Z, W, cats, R_col)

return Z


Expand Down
10 changes: 5 additions & 5 deletions src/rapids_singlecell/preprocessing/_harmony/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
_get_aggregated_matrix_kernel,
_get_scatter_add_kernel_optimized,
_get_scatter_add_kernel_with_bias_block,
_get_scatter_add_kernel_with_bias_cat0,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -196,14 +195,15 @@ def _scatter_add_cp_bias_csr(
n_pcs = X.shape[1]

threads_per_block = 1024
blocks = int((n_pcs + 1) / 2)
scatter_kernel0 = _get_scatter_add_kernel_with_bias_cat0(X.dtype)
scatter_kernel0((blocks, 8), (threads_per_block,), (X, n_cells, n_pcs, out, bias))
# blocks = int((n_pcs + 1) / 2)
# scatter_kernel0 = _get_scatter_add_kernel_with_bias_cat0(X.dtype)
# scatter_kernel0((blocks, 8), (threads_per_block,), (X, n_cells, n_pcs, out, bias))
out[0] = X.T @ bias
blocks = int((n_batches) * (n_pcs + 1) / 2)
scatter_kernel = _get_scatter_add_kernel_with_bias_block(X.dtype)
scatter_kernel(
(blocks,),
(1024,),
(threads_per_block,),
(X, cat_offsets, cell_indices, n_cells, n_pcs, n_batches, out, bias),
)

Expand Down
Loading