Skip to content

Commit a767eda

Browse files
authored
feat!: add minimum probes and maximum probes to IVF search (#3903)
Currently, IVF search takes an `nprobes` parameter that selects how many partitions are searched. This can be a problem with highly selective prefilters because the partitions closest to the query vector might not contain any matching results. Setting a larger `nprobes` value will lead to more results being returned, but all queries will be more expensive. I think we will eventually want to work around this problem through a combination of column statistics, selectivity estimation, and partition bloom filters. However, in the short term, this PR splits `nprobes` into `minimum_nprobes` and `maximum_nprobes`. At first we will search `minimum_nprobes` partitions, like we do today. If we don't find enough results after searching those partitions, then we will keep searching, up to `maximum_nprobes` partitions, until we do. This PR also changes the default nprobes to 20. The previous default was 1, which was unlikely to give good recall and could set misleading expectations. The new default of 20 is the same as the default in lancedb. We also consider the prefilter max_results, which will allow us to stop early if we have found all results matching the prefilter. Also, if the number of prefilter results is very small, we will skip the late search entirely and just return all prefiltered row ids to be sorted in the refine stage. The approach is slightly complex, but one of my primary goals is to avoid impacting our current search performance, which relies on our ability to select partitions and execute the prefilter in parallel (so, for example, we won't know the actual selectivity until after we've already loaded several partitions). BREAKING CHANGE: the default for nprobes is now 20 instead of 1
1 parent 11af7bf commit a767eda

22 files changed

Lines changed: 1459 additions & 311 deletions

File tree

java/core/lance-jni/src/blocking_scanner.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ fn inner_create_scanner<'local>(
190190
let k = env.get_int_as_usize_from_method(&java_obj, "getK")?;
191191
let _ = scanner.nearest(&column, &key, k);
192192

193-
let nprobes = env.get_int_as_usize_from_method(&java_obj, "getNprobes")?;
194-
scanner.nprobs(nprobes);
193+
let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?;
194+
scanner.minimum_nprobes(minimum_nprobes);
195+
196+
let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?;
197+
if let Some(maximum_nprobes) = maximum_nprobes {
198+
scanner.maximum_nprobes(maximum_nprobes);
199+
}
195200

196201
if let Some(ef) = env.get_optional_usize_from_method(&java_obj, "getEf")? {
197202
scanner.ef(ef);

java/core/lance-jni/src/utils.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result<Option<Query>>
9999
let key = Arc::new(Float32Array::from(key_array));
100100

101101
let k = env.get_int_as_usize_from_method(&java_obj, "getK")?;
102-
let nprobes = env.get_int_as_usize_from_method(&java_obj, "getNprobes")?;
102+
let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?;
103+
let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?;
103104

104105
let ef = env.get_optional_usize_from_method(&java_obj, "getEf")?;
105106

@@ -120,7 +121,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result<Option<Query>>
120121
k,
121122
lower_bound: None,
122123
upper_bound: None,
123-
nprobes,
124+
minimum_nprobes,
125+
maximum_nprobes,
124126
ef,
125127
refine_factor,
126128
metric_type: distance_type,

java/core/src/main/java/com/lancedb/lance/ipc/Query.java

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public class Query {
2525
private final String column;
2626
private final float[] key;
2727
private final int k;
28-
private final int nprobes;
28+
private final int minimumNprobes;
29+
private final Optional<Integer> maximumNprobes;
2930
private final Optional<Integer> ef;
3031
private final Optional<Integer> refineFactor;
3132
private final DistanceType distanceType;
@@ -36,9 +37,15 @@ private Query(Builder builder) {
3637
Preconditions.checkArgument(!builder.column.isEmpty(), "Column must not be empty");
3738
this.key = Preconditions.checkNotNull(builder.key, "Key must be set");
3839
Preconditions.checkArgument(builder.k > 0, "K must be greater than 0");
39-
Preconditions.checkArgument(builder.nprobes > 0, "Nprobes must be greater than 0");
40+
Preconditions.checkArgument(
41+
builder.minimumNprobes > 0, "Minimum Nprobes must be greater than 0");
42+
Preconditions.checkArgument(
43+
!builder.maximumNprobes.isPresent()
44+
|| builder.maximumNprobes.get() >= builder.minimumNprobes,
45+
"Maximum Nprobes must be greater than minimum Nprobes");
4046
this.k = builder.k;
41-
this.nprobes = builder.nprobes;
47+
this.minimumNprobes = builder.minimumNprobes;
48+
this.maximumNprobes = builder.maximumNprobes;
4249
this.ef = builder.ef;
4350
this.refineFactor = builder.refineFactor;
4451
this.distanceType = Preconditions.checkNotNull(builder.distanceType, "Metric type must be set");
@@ -57,8 +64,12 @@ public int getK() {
5764
return k;
5865
}
5966

60-
public int getNprobes() {
61-
return nprobes;
67+
public int getMinimumNprobes() {
68+
return minimumNprobes;
69+
}
70+
71+
public Optional<Integer> getMaximumNprobes() {
72+
return maximumNprobes;
6273
}
6374

6475
public Optional<Integer> getEf() {
@@ -83,7 +94,8 @@ public String toString() {
8394
.append("column", column)
8495
.append("key", key)
8596
.append("k", k)
86-
.append("nprobes", nprobes)
97+
.append("minimumNprobes", minimumNprobes)
98+
.append("maximumNprobes", maximumNprobes.orElse(null))
8799
.append("ef", ef.orElse(null))
88100
.append("refineFactor", refineFactor.orElse(null))
89101
.append("distanceType", distanceType)
@@ -95,7 +107,8 @@ public static class Builder {
95107
private String column;
96108
private float[] key;
97109
private int k = 10;
98-
private int nprobes = 1;
110+
private int minimumNprobes = 20;
111+
private Optional<Integer> maximumNprobes = Optional.empty();
99112
private Optional<Integer> ef = Optional.empty();
100113
private Optional<Integer> refineFactor = Optional.empty();
101114
private DistanceType distanceType = DistanceType.L2;
@@ -137,11 +150,46 @@ public Builder setK(int k) {
137150
/**
138151
* Sets the number of probes to load and search.
139152
*
153+
* <p>This is a convenience method that sets both the minimum and maximum number of probes to
154+
* the same value.
155+
*
140156
* @param nprobes The number of probes.
141157
* @return The Builder instance for method chaining.
142158
*/
143159
public Builder setNprobes(int nprobes) {
144-
this.nprobes = nprobes;
160+
this.minimumNprobes = nprobes;
161+
this.maximumNprobes = Optional.of(nprobes);
162+
return this;
163+
}
164+
165+
/**
166+
* Sets the minimum number of partitions to search.
167+
*
168+
* <p>This many partitions will always be loaded and searched on the query. Increasing this
169+
* number can improve recall at the cost of latency.
170+
*
171+
* @param minimumNprobes The minimum number of partitions to search.
172+
* @return The Builder instance for method chaining.
173+
*/
174+
public Builder setMinimumNprobes(int minimumNprobes) {
175+
this.minimumNprobes = minimumNprobes;
176+
return this;
177+
}
178+
179+
/**
180+
* Sets the maximum number of partitions to search.
181+
*
182+
* <p>These partitions will only be loaded and searched if we have not found the desired number
183+
* of results after searching the minimum number of partitions. Increasing this number can avoid
184+
* false negatives on queries with a highly selective prefilter. This setting does not affect
185+
* the recall of the query and will only affect the latency if the prefilter is highly
186+
* selective.
187+
*
188+
* @param maximumNprobes The maximum number of partitions to search.
189+
* @return The Builder instance for method chaining.
190+
*/
191+
public Builder setMaximumNprobes(int maximumNprobes) {
192+
this.maximumNprobes = Optional.of(maximumNprobes);
145193
return this;
146194
}
147195

python/python/lance/dataset.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ def scanner(
404404
"column": <embedding col name>,
405405
"q": <query vector as pa.Float32Array>,
406406
"k": 10,
407-
"nprobes": 1,
407+
"minimum_nprobes": 20,
408+
"maximum_nprobes": 50,
408409
"refine_factor": 1
409410
}
410411
@@ -643,7 +644,8 @@ def to_table(
643644
"q": <query vector as pa.Float32Array>,
644645
"k": 10,
645646
"metric": "cosine",
646-
"nprobes": 1,
647+
"minimum_nprobes": 20,
648+
"maximum_nprobes": 50,
647649
"refine_factor": 1
648650
}
649651
@@ -3402,6 +3404,8 @@ def nearest(
34023404
k: Optional[int] = None,
34033405
metric: Optional[str] = None,
34043406
nprobes: Optional[int] = None,
3407+
minimum_nprobes: Optional[int] = None,
3408+
maximum_nprobes: Optional[int] = None,
34053409
refine_factor: Optional[int] = None,
34063410
use_index: bool = True,
34073411
ef: Optional[int] = None,
@@ -3435,6 +3439,26 @@ def nearest(
34353439
raise ValueError(f"Nearest-K must be > 0 but got {k}")
34363440
if nprobes is not None and int(nprobes) <= 0:
34373441
raise ValueError(f"Nprobes must be > 0 but got {nprobes}")
3442+
if minimum_nprobes is not None and int(minimum_nprobes) < 0:
3443+
raise ValueError(f"Minimum nprobes must be >= 0 but got {minimum_nprobes}")
3444+
if maximum_nprobes is not None and int(maximum_nprobes) < 0:
3445+
raise ValueError(f"Maximum nprobes must be >= 0 but got {maximum_nprobes}")
3446+
3447+
if nprobes is not None:
3448+
if minimum_nprobes is not None or maximum_nprobes is not None:
3449+
raise ValueError(
3450+
"nprobes cannot be set in combination with minimum_nprobes or "
3451+
"maximum_nprobes"
3452+
)
3453+
else:
3454+
minimum_nprobes = nprobes
3455+
maximum_nprobes = nprobes
3456+
if (
3457+
minimum_nprobes is not None
3458+
and maximum_nprobes is not None
3459+
and minimum_nprobes > maximum_nprobes
3460+
):
3461+
raise ValueError("minimum_nprobes must be <= maximum_nprobes")
34383462
if refine_factor is not None and int(refine_factor) < 1:
34393463
raise ValueError(f"Refine factor must be 1 or more got {refine_factor}")
34403464
if ef is not None and int(ef) <= 0:
@@ -3446,7 +3470,8 @@ def nearest(
34463470
"q": q,
34473471
"k": k,
34483472
"metric": metric,
3449-
"nprobes": nprobes,
3473+
"minimum_nprobes": minimum_nprobes,
3474+
"maximum_nprobes": maximum_nprobes,
34503475
"refine_factor": refine_factor,
34513476
"use_index": use_index,
34523477
"ef": ef,

python/python/tests/test_vector_index.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import pytest
1515
from lance import LanceFragment
1616
from lance.dataset import VectorIndexReader
17-
18-
torch = pytest.importorskip("torch")
1917
from lance.util import validate_vector_index # noqa: E402
2018
from lance.vector import vec_to_table # noqa: E402
2119

@@ -255,6 +253,8 @@ def test_index_with_nans(tmp_path):
255253

256254

257255
def test_torch_index_with_nans(tmp_path):
256+
torch = pytest.importorskip("torch")
257+
258258
# 1024 rows, the entire table should be sampled
259259
tbl = create_table(nvec=1000, nans=24)
260260

@@ -271,6 +271,8 @@ def test_torch_index_with_nans(tmp_path):
271271

272272

273273
def test_index_with_no_centroid_movement(tmp_path):
274+
torch = pytest.importorskip("torch")
275+
274276
# this test makes the centroids essentially [1..]
275277
# this makes sure the early stop condition in the index building code
276278
# doesn't do divide by zero
@@ -343,6 +345,10 @@ def test_create_index_using_cuda(tmp_path, nullify):
343345

344346

345347
def test_create_index_unsupported_accelerator(tmp_path):
348+
# Even attempting to use an accelerator will trigger torch import
349+
# so make sure it's available
350+
pytest.importorskip("torch")
351+
346352
tbl = create_table()
347353
dataset = lance.write_dataset(tbl, tmp_path)
348354
with pytest.raises(ValueError):
@@ -896,6 +902,7 @@ def query_index(ds, ntimes, q=None):
896902
nearest={
897903
"column": "vector",
898904
"q": q if q is not None else rng.standard_normal(ndim),
905+
"minimum_nprobes": 1,
899906
},
900907
)
901908

@@ -1033,6 +1040,8 @@ def test_dynamic_projection_with_vectors_index(tmp_path: Path):
10331040

10341041

10351042
def test_index_cast_centroids(tmp_path):
1043+
torch = pytest.importorskip("torch")
1044+
10361045
tbl = create_table(nvec=1000)
10371046

10381047
dataset = lance.write_dataset(tbl, tmp_path)
@@ -1260,3 +1269,53 @@ def test_vector_index_with_prefilter_and_scalar_index(indexed_dataset):
12601269
prefilter=True,
12611270
)
12621271
assert len(res) == 10
1272+
1273+
1274+
def test_vector_index_with_nprobes(indexed_dataset):
1275+
res = indexed_dataset.scanner(
1276+
nearest={
1277+
"column": "vector",
1278+
"q": np.random.randn(128),
1279+
"k": 10,
1280+
"nprobes": 7,
1281+
}
1282+
).explain_plan()
1283+
1284+
assert "minimum_nprobes=7" in res
1285+
assert "maximum_nprobes=Some(7)" in res
1286+
1287+
res = indexed_dataset.scanner(
1288+
nearest={
1289+
"column": "vector",
1290+
"q": np.random.randn(128),
1291+
"k": 10,
1292+
"minimum_nprobes": 7,
1293+
}
1294+
).explain_plan()
1295+
1296+
assert "minimum_nprobes=7" in res
1297+
assert "maximum_nprobes=None" in res
1298+
1299+
res = indexed_dataset.scanner(
1300+
nearest={
1301+
"column": "vector",
1302+
"q": np.random.randn(128),
1303+
"k": 10,
1304+
"minimum_nprobes": 7,
1305+
"maximum_nprobes": 10,
1306+
}
1307+
).explain_plan()
1308+
1309+
assert "minimum_nprobes=7" in res
1310+
assert "maximum_nprobes=Some(10)" in res
1311+
1312+
res = indexed_dataset.scanner(
1313+
nearest={
1314+
"column": "vector",
1315+
"q": np.random.randn(128),
1316+
"k": 10,
1317+
"maximum_nprobes": 30,
1318+
}
1319+
).analyze_plan()
1320+
1321+
print(res)

python/src/dataset.rs

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub mod commit;
9292
pub mod optimize;
9393
pub mod stats;
9494

95-
const DEFAULT_NPROBS: usize = 1;
95+
const DEFAULT_NPROBS: usize = 20;
9696
const DEFAULT_INDEX_CACHE_SIZE: usize = 256;
9797
const DEFAULT_METADATA_CACHE_SIZE: usize = 256;
9898

@@ -746,15 +746,41 @@ impl Dataset {
746746
10
747747
};
748748

749-
let nprobes: usize = if let Some(nprobes) = nearest.get_item("nprobes")? {
750-
if nprobes.is_none() {
751-
DEFAULT_NPROBS
752-
} else {
753-
nprobes.extract()?
749+
let mut minimum_nprobes = DEFAULT_NPROBS;
750+
let mut maximum_nprobes = None;
751+
752+
if let Some(nprobes) = nearest.get_item("nprobes")? {
753+
if !nprobes.is_none() {
754+
minimum_nprobes = nprobes.extract()?;
755+
maximum_nprobes = Some(minimum_nprobes);
754756
}
755-
} else {
756-
DEFAULT_NPROBS
757-
};
757+
}
758+
759+
if let Some(min_nprobes) = nearest.get_item("minimum_nprobes")? {
760+
if !min_nprobes.is_none() {
761+
minimum_nprobes = min_nprobes.extract()?;
762+
}
763+
}
764+
765+
if let Some(max_nprobes) = nearest.get_item("maximum_nprobes")? {
766+
if !max_nprobes.is_none() {
767+
maximum_nprobes = Some(max_nprobes.extract()?);
768+
}
769+
}
770+
771+
if minimum_nprobes > maximum_nprobes.unwrap_or(usize::MAX) {
772+
return Err(PyValueError::new_err(
773+
"minimum_nprobes must be <= maximum_nprobes",
774+
));
775+
}
776+
777+
if minimum_nprobes < 1 {
778+
return Err(PyValueError::new_err("minimum_nprobes must be >= 1"));
779+
}
780+
781+
if maximum_nprobes.unwrap_or(usize::MAX) < 1 {
782+
return Err(PyValueError::new_err("maximum_nprobes must be >= 1"));
783+
}
758784

759785
let metric_type: Option<MetricType> =
760786
if let Some(metric) = nearest.get_item("metric")? {
@@ -813,7 +839,10 @@ impl Dataset {
813839
};
814840
scanner
815841
.map(|s| {
816-
let mut s = s.nprobs(nprobes);
842+
let mut s = s.minimum_nprobes(minimum_nprobes);
843+
if let Some(maximum_nprobes) = maximum_nprobes {
844+
s = s.maximum_nprobes(maximum_nprobes);
845+
}
817846
if let Some(factor) = refine_factor {
818847
s = s.refine(factor);
819848
}

0 commit comments

Comments
 (0)