Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/0.12.7.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

```{rubric} Features
```
* Adds `algorithm_kwds` to `pp.neighbors` & `pp.bbknn` to fine-tune `ivfflat`, `ivfpq` & `nn_descent` {pr}`381` {smaller}`S Dicks`

```{rubric} Performance
```
Expand Down
94 changes: 81 additions & 13 deletions src/rapids_singlecell/preprocessing/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def _brute_knn(
X: cp_sparse.spmatrix | cp.ndarray,
Y: cp_sparse.spmatrix | cp.ndarray,
k: int,
*,
metric: _Metrics,
metric_kwds: Mapping,
algorithm_kwds: Mapping,
) -> tuple[cp.ndarray, cp.ndarray]:
from cuml.neighbors import NearestNeighbors

Expand All @@ -91,7 +93,13 @@ def _brute_knn(


def _cagra_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
X: cp.ndarray,
Y: cp.ndarray,
k: int,
*,
metric: _Metrics,
metric_kwds: Mapping,
algorithm_kwds: Mapping,
) -> tuple[cp.ndarray, cp.ndarray]:
if not _cuvs_switch():
try:
Expand Down Expand Up @@ -135,8 +143,20 @@ def _cagra_knn(
return neighbors, distances


def _compute_nlist(N):
base = math.sqrt(N)
next_pow2 = 2 ** math.ceil(math.log2(base))
return int(next_pow2 * 2)


def _ivf_flat_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
X: cp.ndarray,
Y: cp.ndarray,
k: int,
*,
metric: _Metrics,
metric_kwds: Mapping,
algorithm_kwds: Mapping,
) -> tuple[cp.ndarray, cp.ndarray]:
if not _cuvs_switch():
from pylibraft.neighbors import ivf_flat
Expand All @@ -151,12 +171,16 @@ def _ivf_flat_knn(
build_kwargs = {} # cuvs does not need handle/resources
search_kwargs = {}

n_lists = int(math.sqrt(X.shape[0]))
# Extract n_lists and nprobes from algorithm_kwds, with defaults
n_lists = algorithm_kwds.get("n_lists", _compute_nlist(X.shape[0]))
n_probes = algorithm_kwds.get("n_probes", 20)
print(f"n_lists: {n_lists}, n_probes: {n_probes}")
index_params = ivf_flat.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_flat.build(index_params, X, **build_kwargs)
distances, neighbors = ivf_flat.search(
ivf_flat.SearchParams(), index, Y, k, **search_kwargs
)

# Create SearchParams with nprobes if provided
search_params = ivf_flat.SearchParams(n_probes=n_probes)
distances, neighbors = ivf_flat.search(search_params, index, Y, k, **search_kwargs)

if resources is not None:
resources.sync()
Expand All @@ -168,7 +192,13 @@ def _ivf_flat_knn(


def _ivf_pq_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
X: cp.ndarray,
Y: cp.ndarray,
k: int,
*,
metric: _Metrics,
metric_kwds: Mapping,
algorithm_kwds: Mapping,
) -> tuple[cp.ndarray, cp.ndarray]:
if not _cuvs_switch():
from pylibraft.neighbors import ivf_pq
Expand All @@ -183,12 +213,16 @@ def _ivf_pq_knn(
build_kwargs = {}
search_kwargs = {}

n_lists = int(math.sqrt(X.shape[0]))
# Extract n_lists and nprobes from algorithm_kwds, with defaults
n_lists = algorithm_kwds.get("n_lists", _compute_nlist(X.shape[0]))
n_probes = algorithm_kwds.get("n_probes", 20)

index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_pq.build(index_params, X, **build_kwargs)
distances, neighbors = ivf_pq.search(
ivf_pq.SearchParams(), index, Y, k, **search_kwargs
)
print(f"n_lists: {n_lists}, n_probes: {n_probes}")
# Create SearchParams with nprobes if provided
search_params = ivf_pq.SearchParams(n_probes=n_probes)
distances, neighbors = ivf_pq.search(search_params, index, Y, k, **search_kwargs)
if resources is not None:
resources.sync()

Expand All @@ -199,7 +233,13 @@ def _ivf_pq_knn(


def _nn_descent_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
X: cp.ndarray,
Y: cp.ndarray,
k: int,
*,
metric: _Metrics,
metric_kwds: Mapping,
algorithm_kwds: Mapping,
) -> tuple[cp.ndarray, cp.ndarray]:
from cuvs import __version__ as cuvs_version

Expand All @@ -210,8 +250,13 @@ def _nn_descent_knn(
)
from cuvs.neighbors import nn_descent

# Extract intermediate_graph_degree from algorithm_kwds, with default
intermediate_graph_degree = algorithm_kwds.get("intermediate_graph_degree", None)

idxparams = nn_descent.IndexParams(
graph_degree=k, metric="sqeuclidean" if metric == "euclidean" else metric
graph_degree=k,
intermediate_graph_degree=intermediate_graph_degree,
metric="sqeuclidean" if metric == "euclidean" else metric,
)
idx = nn_descent.build(
idxparams,
Expand Down Expand Up @@ -392,6 +437,7 @@ def neighbors(
algorithm: _Algorithms = "brute",
metric: _Metrics = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
algorithm_kwds: Mapping[str, Any] = MappingProxyType({}),
key_added: str | None = None,
copy: bool = False,
) -> AnnData | None:
Expand Down Expand Up @@ -437,6 +483,16 @@ def neighbors(
A known metric's name or a callable that returns a distance.
metric_kwds
Options for the metric.
algorithm_kwds
Options for the algorithm. For 'ivfflat' and 'ivfpq' algorithms, the following
parameters can be specified:
* 'n_lists': Number of inverted lists for IVF indexing. Default is 2 * next_power_of_2(sqrt(n_samples)).
* 'n_probes': Number of lists to probe during search. Default is 20. Higher values
increase accuracy but reduce speed.
For 'nn_descent' algorithm, the following parameters can be specified:
* 'intermediate_graph_degree': The degree of the intermediate graph. Default is None.
It is recommended to set it to `>= 1.5 * n_neighbors`.

key_added
If not specified, the neighbors data is stored in .uns['neighbors'],
distances and connectivities are stored in .obsp['distances'] and
Expand Down Expand Up @@ -484,6 +540,7 @@ def neighbors(
k=n_neighbors,
metric=metric,
metric_kwds=metric_kwds,
algorithm_kwds=algorithm_kwds,
)

n_nonzero = n_obs * n_neighbors
Expand Down Expand Up @@ -516,6 +573,7 @@ def neighbors(
random_state=random_state,
metric=metric,
**({"metric_kwds": metric_kwds} if metric_kwds else {}),
**({"algorithm_kwds": algorithm_kwds} if algorithm_kwds else {}),
**({"use_rep": use_rep} if use_rep is not None else {}),
**({"n_pcs": n_pcs} if n_pcs is not None else {}),
)
Expand Down Expand Up @@ -543,6 +601,7 @@ def bbknn(
algorithm: _Algorithms_bbknn = "brute",
metric: _Metrics = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
algorithm_kwds: Mapping[str, Any] = MappingProxyType({}),
trim: int | None = None,
key_added: str | None = None,
copy: bool = False,
Expand Down Expand Up @@ -588,6 +647,13 @@ def bbknn(
A known metric's name or a callable that returns a distance.
metric_kwds
Options for the metric.
algorithm_kwds
Options for the algorithm. For 'ivfflat' and 'ivfpq' algorithms, the following
parameters can be specified:

* 'n_lists': Number of inverted lists for IVF indexing. Default is 2 * next_power_of_2(sqrt(n_samples)).
* 'nprobes': Number of lists to probe during search. Default is 1. Higher values
increase accuracy but reduce speed.
trim
Trim the neighbours of each cell to these many top connectivities.
May help with population independence and improve the tidiness of clustering.
Expand Down Expand Up @@ -660,6 +726,7 @@ def bbknn(
k=neighbors_within_batch,
metric=metric,
metric_kwds=metric_kwds,
algorithm_kwds=algorithm_kwds,
)

col_range = cp.arange(
Expand Down Expand Up @@ -705,6 +772,7 @@ def bbknn(
metric=metric,
trim=trim,
**({"metric_kwds": metric_kwds} if metric_kwds else {}),
**({"algorithm_kwds": algorithm_kwds} if algorithm_kwds else {}),
**({"use_rep": use_rep} if use_rep is not None else {}),
**({"n_pcs": n_pcs} if n_pcs is not None else {}),
)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ def test_algo(algo):
neighbors(adata, n_neighbors=5, algorithm=algo)


def test_nn_descent_intermediate_graph_degree():
adata = pbmc68k_reduced()
neighbors(
adata,
n_neighbors=5,
algorithm="nn_descent",
algorithm_kwds={"intermediate_graph_degree": 10},
)


@pytest.mark.parametrize("algo", ["ivfflat", "ivfpq"])
def test_ivf_algorithm_kwds(algo):
adata = pbmc68k_reduced()
neighbors(
adata,
n_neighbors=5,
algorithm=algo,
algorithm_kwds={"n_lists": 10, "n_probes": 10},
)


@pytest.mark.parametrize("algo", ["nn_descent", "ivfpq"])
def test_indices_approx_nn(algo):
adata = pbmc68k_reduced()
Expand Down
Loading