@@ -92,7 +92,7 @@ def harmonize(
9292 batch_mat : pd .DataFrame ,
9393 batch_key : str | list [str ],
9494 * ,
95- n_clusters : int = None ,
95+ n_clusters : int | None = None ,
9696 max_iter_harmony : int = 10 ,
9797 max_iter_clustering : int = 200 ,
9898 tol_harmony : float = 1e-4 ,
@@ -259,14 +259,13 @@ def _initialize_centroids(
259259 kmeans .fit (Z_norm )
260260 Y = kmeans .cluster_centers_ .astype (Z_norm .dtype )
261261 Y_norm = _normalize_cp (Y , p = 2 )
262-
263262 # Initialize R
264- R = _calc_R (- 2 / sigma , cp .dot (Z_norm , Y_norm .T ))
263+ term = cp .float64 (- 2 / sigma ).astype (Z_norm .dtype )
264+ R = _calc_R (term , cp .dot (Z_norm , Y_norm .T ))
265265 R = _normalize_cp (R , p = 1 )
266266
267267 E = cp .dot (Pr_b , cp .sum (R , axis = 0 , keepdims = True ))
268268 O = cp .dot (Phi .T , R )
269-
270269 objectives_harmony = []
271270 _compute_objective (
272271 Y_norm ,
@@ -278,7 +277,6 @@ def _initialize_centroids(
278277 E = E ,
279278 objective_arr = objectives_harmony ,
280279 )
281-
282280 return R , E , O , objectives_harmony
283281
284282
@@ -308,7 +306,7 @@ def _clustering(
308306 n_cells = Z_norm .shape [0 ]
309307 objectives_clustering = []
310308 block_size = int (n_cells * block_proportion )
311- term = - 2 / sigma
309+ term = cp . float64 ( - 2 / sigma ). astype ( Z_norm . dtype )
312310 for _ in range (max_iter ):
313311 # Compute Cluster Centroids
314312 Y = cp .dot (R .T , Z_norm ) # Compute centroids
0 commit comments