@@ -148,6 +148,69 @@ def test_rank_genes_groups_wilcoxon_subset_and_bonferroni(reference):
148148 assert np .all (adjusted <= 1.0 )
149149
150150
151+ @pytest .mark .parametrize (
152+ ("groups" , "reference" ),
153+ [
154+ (["0" ], "rest" ),
155+ (["0" , "2" ], "rest" ),
156+ (["0" ], "1" ),
157+ (["0" , "2" ], "1" ),
158+ ],
159+ )
160+ @pytest .mark .parametrize ("tie_correct" , [False , True ])
161+ @pytest .mark .parametrize ("pre_load" , [False , True ])
162+ def test_rank_genes_groups_wilcoxon_subset_matches_scanpy (
163+ groups , reference , tie_correct , pre_load
164+ ):
165+ np .random .seed (42 )
166+ adata_gpu = sc .datasets .blobs (n_variables = 8 , n_centers = 5 , n_observations = 200 )
167+ adata_gpu .obs ["blobs" ] = adata_gpu .obs ["blobs" ].astype ("category" )
168+ adata_cpu = adata_gpu .copy ()
169+
170+ rsc .tl .rank_genes_groups (
171+ adata_gpu ,
172+ "blobs" ,
173+ method = "wilcoxon" ,
174+ groups = groups ,
175+ reference = reference ,
176+ use_raw = False ,
177+ tie_correct = tie_correct ,
178+ pre_load = pre_load ,
179+ )
180+ sc .tl .rank_genes_groups (
181+ adata_cpu ,
182+ "blobs" ,
183+ method = "wilcoxon" ,
184+ groups = groups ,
185+ reference = reference ,
186+ use_raw = False ,
187+ tie_correct = tie_correct ,
188+ )
189+
190+ gpu_result = adata_gpu .uns ["rank_genes_groups" ]
191+ cpu_result = adata_cpu .uns ["rank_genes_groups" ]
192+
193+ assert gpu_result ["names" ].dtype .names == cpu_result ["names" ].dtype .names
194+ for group in gpu_result ["names" ].dtype .names :
195+ gpu_names = list (gpu_result ["names" ][group ])
196+ cpu_names = list (cpu_result ["names" ][group ])
197+ for field in ("scores" , "logfoldchanges" , "pvals" , "pvals_adj" ):
198+ gpu_map = dict (
199+ zip (gpu_names , np .asarray (gpu_result [field ][group ], dtype = float ))
200+ )
201+ cpu_map = dict (
202+ zip (cpu_names , np .asarray (cpu_result [field ][group ], dtype = float ))
203+ )
204+ for gene , gpu_val in gpu_map .items ():
205+ np .testing .assert_allclose (
206+ gpu_val ,
207+ cpu_map [gene ],
208+ rtol = 1e-6 ,
209+ atol = 1e-8 ,
210+ err_msg = f"{ field } mismatch for gene { gene } group { group } " ,
211+ )
212+
213+
151214@pytest .mark .parametrize (
152215 "reference_before,reference_after" ,
153216 [("rest" , "rest" ), ("1" , "One" )],
0 commit comments