Skip to content

Commit 33e439d

Browse files
authored
harmony hotfix (#351)
* harmony hotfix * update
1 parent 6e321c8 commit 33e439d

6 files changed

Lines changed: 29 additions & 23 deletions

File tree

docs/release-notes/0.12.3.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
1-
### 0.12.3 {small}`the-future`
2-
3-
```{rubric} Features
4-
```
5-
6-
```{rubric} Performance
7-
```
1+
### 0.12.3 {small}`2025-04-11`
82

93
```{rubric} Bug fixes
104
```
11-
12-
```{rubric} Misc
13-
```
5+
* Fixed `harmony_integrate` breakage caused by undocumented changes to fused kernel float handling in` CuPy 13.4.1` {pr}`351` {smaller}`S Dicks`

docs/release-notes/0.12.4.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
### 0.12.4 {small}`the-future`
2+
3+
```{rubric} Features
4+
```
5+
6+
```{rubric} Performance
7+
```
8+
9+
```{rubric} Bug fixes
10+
```
11+
12+
```{rubric} Misc
13+
```

docs/release-notes/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Release notes
44

55
## Version 0.12.0
6+
```{include} /release-notes/0.12.3.md
7+
```
68
```{include} /release-notes/0.12.2.md
79
```
810
```{include} /release-notes/0.12.1.md

src/rapids_singlecell/preprocessing/_harmony/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def harmonize(
9292
batch_mat: pd.DataFrame,
9393
batch_key: str | list[str],
9494
*,
95-
n_clusters: int = None,
95+
n_clusters: int | None = None,
9696
max_iter_harmony: int = 10,
9797
max_iter_clustering: int = 200,
9898
tol_harmony: float = 1e-4,
@@ -259,14 +259,13 @@ def _initialize_centroids(
259259
kmeans.fit(Z_norm)
260260
Y = kmeans.cluster_centers_.astype(Z_norm.dtype)
261261
Y_norm = _normalize_cp(Y, p=2)
262-
263262
# Initialize R
264-
R = _calc_R(-2 / sigma, cp.dot(Z_norm, Y_norm.T))
263+
term = cp.float64(-2 / sigma).astype(Z_norm.dtype)
264+
R = _calc_R(term, cp.dot(Z_norm, Y_norm.T))
265265
R = _normalize_cp(R, p=1)
266266

267267
E = cp.dot(Pr_b, cp.sum(R, axis=0, keepdims=True))
268268
O = cp.dot(Phi.T, R)
269-
270269
objectives_harmony = []
271270
_compute_objective(
272271
Y_norm,
@@ -278,7 +277,6 @@ def _initialize_centroids(
278277
E=E,
279278
objective_arr=objectives_harmony,
280279
)
281-
282280
return R, E, O, objectives_harmony
283281

284282

@@ -308,7 +306,7 @@ def _clustering(
308306
n_cells = Z_norm.shape[0]
309307
objectives_clustering = []
310308
block_size = int(n_cells * block_proportion)
311-
term = -2 / sigma
309+
term = cp.float64(-2 / sigma).astype(Z_norm.dtype)
312310
for _ in range(max_iter):
313311
# Compute Cluster Centroids
314312
Y = cp.dot(R.T, Z_norm) # Compute centroids

src/rapids_singlecell/preprocessing/_harmony/_fuses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def _get_pen(e: cp.ndarray, o: cp.ndarray, theta: cp.ndarray) -> cp.ndarray:
1414

1515

1616
@cp.fuse
17-
def _calc_R(term: cp.ndarray, dotproduct: cp.ndarray) -> cp.ndarray:
17+
def _calc_R(term: float, dotproduct: cp.ndarray) -> cp.ndarray:
1818
return cp.exp(term * (1 - dotproduct))
1919

2020

tests/test_harmony.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
from __future__ import annotations
22

3+
import cupy as cp
34
import pytest
45
import scanpy as sc
56

67
import rapids_singlecell as rsc
78

89

910
@pytest.mark.parametrize("correction_method", ["fast", "original"])
10-
def test_harmony_integrate(correction_method):
11+
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
12+
def test_harmony_integrate(correction_method, dtype):
1113
"""
1214
Test that Harmony integrate works.
1315
1416
This is a very simple test that just checks to see if the Harmony
1517
integrate wrapper successfully added a new field to ``adata.obsm``
1618
and makes sure it has the same dimensions as the original PCA table.
1719
"""
18-
adata = sc.datasets.pbmc3k()
19-
sc.pp.recipe_zheng17(adata)
20-
sc.tl.pca(adata)
21-
adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"]
22-
rsc.pp.harmony_integrate(adata, "batch", correction_method=correction_method)
20+
adata = sc.datasets.pbmc68k_reduced()
21+
rsc.pp.harmony_integrate(
22+
adata, "bulk_labels", correction_method=correction_method, dtype=dtype
23+
)
2324
assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape

0 commit comments

Comments
 (0)