Skip to content

Commit a866e4e

Browse files
committed
Record leiden/louvain modularity in adata.uns
1 parent 64e4bdc commit a866e4e

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

src/rapids_singlecell/tools/_clustering.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,9 @@ def leiden(
229229
resolutions = [resolution]
230230
else:
231231
resolutions = resolution
232+
modularities = []
232233
for resolution in resolutions:
233-
leiden_parts, _ = culeiden(
234+
leiden_parts, modularity = culeiden(
234235
g,
235236
resolution=resolution,
236237
random_state=random_state,
@@ -241,6 +242,7 @@ def leiden(
241242
leiden_parts = leiden_parts.to_backend("pandas").compute()
242243
else:
243244
leiden_parts = leiden_parts.to_pandas()
245+
modularities.append(modularity)
244246

245247
# Format output
246248
groups = leiden_parts.sort_values("vertex")[["partition"]].to_numpy().ravel()
@@ -270,10 +272,13 @@ def leiden(
270272
# store information on the clustering parameters
271273
adata.uns[key_added] = {}
272274
adata.uns[key_added]["params"] = {
273-
"resolution": resolutions,
275+
"resolution": resolutions if len(resolutions) > 1 else resolutions[0],
274276
"random_state": random_state,
275277
"n_iterations": n_iterations,
276278
}
279+
adata.uns[key_added]["modularity"] = (
280+
modularities if len(modularities) > 1 else modularities[0]
281+
)
277282
return adata if copy else None
278283

279284

@@ -383,8 +388,9 @@ def louvain(
383388
resolutions = [resolution]
384389
else:
385390
resolutions = resolution
391+
modularities = []
386392
for resolution in resolutions:
387-
louvain_parts, _ = culouvain(
393+
louvain_parts, modularity = culouvain(
388394
g,
389395
resolution=resolution,
390396
max_level=n_iterations,
@@ -394,6 +400,7 @@ def louvain(
394400
louvain_parts = louvain_parts.to_backend("pandas").compute()
395401
else:
396402
louvain_parts = louvain_parts.to_pandas()
403+
modularities.append(modularity)
397404

398405
# Format output
399406
groups = louvain_parts.sort_values("vertex")[["partition"]].to_numpy().ravel()
@@ -422,10 +429,13 @@ def louvain(
422429
Comms.destroy()
423430
adata.uns[key_added] = {}
424431
adata.uns[key_added]["params"] = {
425-
"resolution": resolutions,
432+
"resolution": resolutions if len(resolutions) > 1 else resolutions[0],
426433
"n_iterations": n_iterations,
427434
"threshold": threshold,
428435
}
436+
adata.uns[key_added]["modularity"] = (
437+
modularities if len(modularities) > 1 else modularities[0]
438+
)
429439
return adata if copy else None
430440

431441

0 commit comments

Comments
 (0)