@@ -55,6 +55,7 @@ def harmonize(
5555 use_gemm : bool = False ,
5656 colsum_algo : COLSUM_ALGO | None = None ,
5757 random_state : int = 0 ,
58+ verbose : bool = False ,
5859) -> cp .array :
5960 """
6061 Integrate data using Harmony algorithm.
@@ -112,6 +113,9 @@ def harmonize(
112113 random_state
113114 Random seed for reproducing results.
114115
116+ verbose
117+ Whether to print benchmarking results for the column sum algorithm and the number of iterations until convergence.
118+
115119 Returns
116120 -------
117121 The integrated embedding by Harmony, of the same shape as the input embedding.
@@ -145,7 +149,7 @@ def harmonize(
145149 colsum_func_big = _choose_colsum_algo_heuristic (n_cells , n_clusters , None )
146150 if colsum_algo == "benchmark" :
147151 colsum_func_small = _choose_colsum_algo_benchmark (
148- int (n_cells * block_proportion ), n_clusters , Z .dtype
152+ int (n_cells * block_proportion ), n_clusters , Z .dtype , verbose = verbose
149153 )
150154 else :
151155 colsum_func_small = _choose_colsum_algo_heuristic (
@@ -180,6 +184,7 @@ def harmonize(
180184
181185 # Main harmony iterations
182186 is_converged = False
187+
183188 for i in range (max_iter_harmony ):
184189 # Clustering step
185190 _clustering (
@@ -198,7 +203,6 @@ def harmonize(
198203 block_proportion = block_proportion ,
199204 colsum_func = colsum_func_small ,
200205 )
201-
202206 # Correction step
203207 Z_hat = _correction (
204208 Z ,
@@ -212,13 +216,13 @@ def harmonize(
212216 cat_offsets = cat_offsets ,
213217 cell_indices = cell_indices ,
214218 )
215-
216219 # Normalize corrected data
217220 Z_norm = _normalize_cp (Z_hat , p = 2 )
218221 # Check for convergence
219222 if _is_convergent_harmony (objectives_harmony , tol = tol_harmony ):
220223 is_converged = True
221- print (f"Harmony converged in { i } iterations" )
224+ if verbose :
225+ print (f"Harmony converged in { i + 1 } iterations" )
222226 break
223227
224228 if not is_converged :
@@ -479,7 +483,6 @@ def _correction_original(
479483 id_mat = cp .eye (n_batches + 1 , n_batches + 1 , dtype = X .dtype )
480484 id_mat [0 , 0 ] = 0
481485 Lambda = ridge_lambda * id_mat
482-
483486 for k in range (n_clusters ):
484487 if Phi is not None :
485488 Phi_t_diag_R = Phi_1 .T * R [:, k ].reshape (1 , - 1 )
@@ -504,12 +507,10 @@ def _correction_original(
504507 )
505508 W = cp .dot (inv_mat , Phi_t_diag_R_X )
506509 W [0 , :] = 0
507-
508510 if Phi is not None :
509- Z -= cp .dot ( Phi_t_diag_R . T , W )
511+ cp .cublas . gemm ( "T" , "N" , Phi_t_diag_R , W , alpha = - 1 , beta = 1 , out = Z )
510512 else :
511513 _Z_correction (Z , W , cats , R_col )
512-
513514 return Z
514515
515516
@@ -537,7 +538,6 @@ def _correction_fast(
537538
538539 Z = X .copy ()
539540 P = cp .eye (n_batches + 1 , n_batches + 1 , dtype = X .dtype )
540-
541541 for k in range (n_clusters ):
542542 O_k = O [:, k ]
543543 N_k = cp .sum (O_k )
@@ -557,7 +557,6 @@ def _correction_fast(
557557 # Set off-diagonal entries
558558 P_t_B_inv [1 :, 0 ] = P [0 , 1 :] * c_inv
559559 inv_mat = cp .dot (P_t_B_inv , P )
560-
561560 if Phi is not None :
562561 Phi_t_diag_R = Phi_1 .T * R [:, k ].reshape (1 , - 1 )
563562 Phi_t_diag_R_X = cp .dot (Phi_t_diag_R , X )
@@ -577,10 +576,9 @@ def _correction_fast(
577576 W [0 , :] = 0
578577
579578 if Phi is not None :
580- Z -= cp .dot ( Phi_t_diag_R . T , W )
579+ cp .cublas . gemm ( "T" , "N" , Phi_t_diag_R , W , alpha = - 1 , beta = 1 , out = Z )
581580 else :
582581 _Z_correction (Z , W , cats , R_col )
583-
584582 return Z
585583
586584
0 commit comments