Skip to content

Commit 5fd8412

Browse files
authored
feat: add IVF_HNSW_FLAT index (#3671)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent af8de1d commit 5fd8412

8 files changed

Lines changed: 118 additions & 4 deletions

File tree

python/python/lance/dataset.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2068,7 +2068,13 @@ def create_index(
20682068
kwargs["metric_type"] = metric
20692069

20702070
index_type = index_type.upper()
2071-
valid_index_types = ["IVF_FLAT", "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"]
2071+
valid_index_types = [
2072+
"IVF_FLAT",
2073+
"IVF_PQ",
2074+
"IVF_HNSW_FLAT",
2075+
"IVF_HNSW_PQ",
2076+
"IVF_HNSW_SQ",
2077+
]
20722078
if index_type not in valid_index_types:
20732079
raise NotImplementedError(
20742080
f"Only {valid_index_types} index types supported. Got {index_type}"

python/python/tests/test_vector_index.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -561,6 +561,18 @@ def test_create_ivf_hnsw_sq_index(dataset, tmp_path):
561561
assert ann_ds.list_indices()[0]["fields"] == ["vector"]
562562

563563

564+
# Checks that an IVF_HNSW_FLAT index can be created through the Python API.
# The `dataset` fixture must start without an index; its table is copied to
# `tmp_path / "indexed.lance"`, an index is created on the "vector" column with
# 4 partitions, and the first listed index must cover exactly ["vector"].
# NOTE(review): `num_sub_vectors=16` looks like a PQ-style parameter — confirm
# whether it has any effect for IVF_HNSW_FLAT or can be dropped from this call.
def test_create_ivf_hnsw_flat_index(dataset, tmp_path):
565+
assert not dataset.has_index
566+
ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance")
567+
ann_ds = ann_ds.create_index(
568+
"vector",
569+
index_type="IVF_HNSW_FLAT",
570+
num_partitions=4,
571+
num_sub_vectors=16,
572+
)
573+
assert ann_ds.list_indices()[0]["fields"] == ["vector"]
574+
575+
564576
def test_multivec_ann(indexed_multivec_dataset: lance.LanceDataset):
565577
query = np.random.rand(5, 128)
566578
results = indexed_multivec_dataset.scanner(

python/src/dataset.rs

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1344,7 +1344,9 @@ impl Dataset {
13441344
"NGRAM" => IndexType::NGram,
13451345
"LABEL_LIST" => IndexType::LabelList,
13461346
"INVERTED" | "FTS" => IndexType::Inverted,
1347-
"IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector,
1347+
"IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_FLAT" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => {
1348+
IndexType::Vector
1349+
}
13481350
_ => {
13491351
return Err(PyValueError::new_err(format!(
13501352
"Index type '{index_type}' is not supported."
@@ -2153,6 +2155,12 @@ fn prepare_vector_index_params(
21532155
m_type, ivf_params, pq_params,
21542156
))),
21552157

2158+
"IVF_HNSW_FLAT" => Ok(Box::new(VectorIndexParams::ivf_hnsw(
2159+
m_type,
2160+
ivf_params,
2161+
hnsw_params,
2162+
))),
2163+
21562164
"IVF_HNSW_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_hnsw_pq_params(
21572165
m_type,
21582166
ivf_params,

rust/lance-index/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ pub enum IndexType {
9999
IvfPq = 103,
100100
IvfHnswSq = 104,
101101
IvfHnswPq = 105,
102+
IvfHnswFlat = 106,
102103
}
103104

104105
impl std::fmt::Display for IndexType {
@@ -115,6 +116,7 @@ impl std::fmt::Display for IndexType {
115116
Self::IvfSq => write!(f, "IVF_SQ"),
116117
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
117118
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
119+
Self::IvfHnswFlat => write!(f, "IVF_HNSW_FLAT"),
118120
}
119121
}
120122
}
@@ -136,6 +138,7 @@ impl TryFrom<i32> for IndexType {
136138
v if v == Self::IvfPq as i32 => Ok(Self::IvfPq),
137139
v if v == Self::IvfHnswSq as i32 => Ok(Self::IvfHnswSq),
138140
v if v == Self::IvfHnswPq as i32 => Ok(Self::IvfHnswPq),
141+
v if v == Self::IvfHnswFlat as i32 => Ok(Self::IvfHnswFlat),
139142
_ => Err(Error::InvalidInput {
140143
source: format!("the input value {} is not a valid IndexType", value).into(),
141144
location: location!(),
@@ -164,6 +167,7 @@ impl IndexType {
164167
| Self::IvfPq
165168
| Self::IvfHnswSq
166169
| Self::IvfHnswPq
170+
| Self::IvfHnswFlat
167171
| Self::IvfFlat
168172
| Self::IvfSq
169173
)
@@ -191,7 +195,8 @@ impl IndexType {
191195
| Self::IvfSq
192196
| Self::IvfPq
193197
| Self::IvfHnswSq
194-
| Self::IvfHnswPq => 1,
198+
| Self::IvfHnswPq
199+
| Self::IvfHnswFlat => 1,
195200
}
196201
}
197202
}

rust/lance/src/index.rs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1091,6 +1091,18 @@ impl DatasetIndexInternalExt for Dataset {
10911091
Ok(Arc::new(ivf) as Arc<dyn VectorIndex>)
10921092
}
10931093

1094+
"IVF_HNSW_FLAT" => {
1095+
let ivf = IVFIndex::<HNSW, FlatQuantizer>::try_new(
1096+
self.object_store.clone(),
1097+
self.indices_dir(),
1098+
uuid.to_owned(),
1099+
Arc::downgrade(&self.session),
1100+
fri,
1101+
)
1102+
.await?;
1103+
Ok(Arc::new(ivf) as Arc<dyn VectorIndex>)
1104+
}
1105+
10941106
"IVF_HNSW_SQ" => {
10951107
let ivf = IVFIndex::<HNSW, ScalarQuantizer>::try_new(
10961108
self.object_store.clone(),

rust/lance/src/index/vector.rs

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,19 @@ impl VectorIndexParams {
145145
}
146146
}
147147

148+
/// Create index parameters with `IVF` and `HNSW` parameters, respectively,
/// and no quantization stage.
/// This is used for `IVF_HNSW_FLAT` index.
///
/// The resulting parameters use `distance_type` as the metric type and
/// `IndexFileVersion::V3` as the index file version.
pub fn ivf_hnsw(
149+
distance_type: DistanceType,
150+
ivf: IvfBuildParams,
151+
hnsw: HnswBuildParams,
152+
) -> Self {
153+
let stages = vec![StageParams::Ivf(ivf), StageParams::Hnsw(hnsw)];
154+
Self {
155+
stages,
156+
metric_type: distance_type,
157+
version: IndexFileVersion::V3,
158+
}
159+
}
160+
148161
/// Create index parameters with `IVF`, `PQ` and `HNSW` parameters, respectively.
149162
/// This is used for `IVF_HNSW_PQ` index.
150163
pub fn with_ivf_hnsw_pq_params(
@@ -392,6 +405,21 @@ pub(crate) async fn build_vector_index(
392405
});
393406
}
394407
}
408+
} else {
409+
// without quantization
410+
IvfIndexBuilder::<HNSW, FlatQuantizer>::new(
411+
dataset.clone(),
412+
column.to_owned(),
413+
dataset.indices_dir().child(uuid),
414+
params.metric_type,
415+
Box::new(shuffler),
416+
Some(ivf_params.clone()),
417+
Some(()),
418+
hnsw_params.clone(),
419+
fri,
420+
)?
421+
.build()
422+
.await?;
395423
}
396424
} else {
397425
return Err(Error::Index {

rust/lance/src/index/vector/ivf.rs

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -447,6 +447,29 @@ pub(crate) async fn optimize_vector_indices_v2(
447447
.build()
448448
.await?;
449449
}
450+
// IVF_HNSW_FLAT
451+
(SubIndexType::Hnsw, QuantizationType::Flat) => {
452+
IvfIndexBuilder::<HNSW, FlatQuantizer>::new(
453+
dataset.clone(),
454+
vector_column.to_owned(),
455+
index_dir,
456+
distance_type,
457+
shuffler,
458+
None,
459+
None,
460+
// TODO: get the HNSW parameters from the existing indices
461+
HnswBuildParams::default(),
462+
fri,
463+
)?
464+
.with_ivf(ivf_model.clone())
465+
.with_quantizer(quantizer.try_into()?)
466+
.with_existing_indices(indices_to_merge)
467+
.retrain(options.retrain)
468+
.shuffle_data(unindexed)
469+
.await?
470+
.build()
471+
.await?;
472+
}
450473
// IVF_HNSW_SQ
451474
(SubIndexType::Hnsw, QuantizationType::Scalar) => {
452475
IvfIndexBuilder::<HNSW, ScalarQuantizer>::new(

rust/lance/src/index/vector/ivf/v2.rs

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -348,7 +348,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> Index for IVFIndex<S,
348348
(SubIndexType::Flat, QuantizationType::Scalar) => IndexType::IvfSq,
349349
(SubIndexType::Hnsw, QuantizationType::Product) => IndexType::IvfHnswPq,
350350
(SubIndexType::Hnsw, QuantizationType::Scalar) => IndexType::IvfHnswSq,
351-
_ => IndexType::Vector,
351+
(SubIndexType::Hnsw, QuantizationType::Flat) => IndexType::IvfHnswFlat,
352352
}
353353
}
354354

@@ -1244,6 +1244,26 @@ mod tests {
12441244
test_optimize_strategy(params).await;
12451245
}
12461246

1247+
// Exercises the IVF + HNSW index built without a quantization stage
// (`VectorIndexParams::ivf_hnsw`) for the L2, Cosine and Dot distance types,
// using 4 IVF partitions and default HNSW build parameters.
// Each case runs `test_index` and `test_optimize_strategy`; the multivector
// variant `test_index_multivec` is run only for Cosine distance.
// NOTE(review): `recall_requirement` is presumably the minimum recall the
// helpers assert — confirm against `test_index`.
#[rstest]
1248+
#[case(4, DistanceType::L2, 0.9)]
1249+
#[case(4, DistanceType::Cosine, 0.9)]
1250+
#[case(4, DistanceType::Dot, 0.85)]
1251+
#[tokio::test]
1252+
async fn test_create_ivf_hnsw_flat(
1253+
#[case] nlist: usize,
1254+
#[case] distance_type: DistanceType,
1255+
#[case] recall_requirement: f32,
1256+
) {
1257+
let ivf_params = IvfBuildParams::new(nlist);
1258+
let hnsw_params = HnswBuildParams::default();
1259+
let params = VectorIndexParams::ivf_hnsw(distance_type, ivf_params, hnsw_params);
1260+
test_index(params.clone(), nlist, recall_requirement, None).await;
1261+
if distance_type == DistanceType::Cosine {
1262+
test_index_multivec(params.clone(), nlist, recall_requirement).await;
1263+
}
1264+
test_optimize_strategy(params).await;
1265+
}
1266+
12471267
#[rstest]
12481268
#[case(4, DistanceType::L2, 0.9)]
12491269
#[case(4, DistanceType::Cosine, 0.9)]

0 commit comments

Comments
 (0)