Skip to content

Commit a57ec81

Browse files
authored
feat(index): support float16 and float64 in IVF_FLAT (#6476)
## Feature

### What is the new feature?

This PR adds native `float16` and `float64` support for `IVF_FLAT` and `IVF_HNSW_FLAT`.

### Why do we need this feature?

Flat IVF indexing previously only worked end to end for `float32`. That meant users could not build, merge, reload, and query flat IVF indexes on `float16` or `float64` vectors without running into `Float32`-specific assumptions in flat storage, writer initialization, and merge/query paths.

### How does it work?

The implementation makes flat IVF paths dispatch on the actual Arrow element type from stored flat data instead of assuming `Float32`.

- `FlatFloatStorage` now dispatches distance calculators for `float16`, `float32`, and `float64`.
- Query/training helpers that previously special-cased `Float32` now accept the native float dtype where needed.
- Tests now cover flat storage distance, partition serde roundtrip, IVF create/query/remap, and distributed merge behavior for `float16` / `float64`.

## Validation

- `cargo fmt --all`
- `cargo check -p lance-index --lib`
- `cargo check -p lance --lib`
- `cargo test -p lance-index test_flat_float_storage_distance_f16 -- --nocapture`
- `cargo test -p lance-index test_merge_ivf_flat_preserves_float64_schema -- --nocapture`
- `cargo test -p lance test_build_ivf_flat -- --nocapture`
- `cargo test -p lance test_create_ivf_hnsw_flat -- --nocapture`
- `cargo test -p lance test_create_ivf_flat_f16 -- --nocapture`

## Benchmark Note

I also benchmarked float32 `IVF_FLAT` before vs. after; there were no obvious performance differences.
1 parent f2db129 commit a57ec81

7 files changed

Lines changed: 556 additions & 89 deletions

File tree

rust/lance-index/src/vector/distributed/index_merger.rs

Lines changed: 151 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ pub async fn init_writer_for_flat(
176176
object_store: &lance_io::object_store::ObjectStore,
177177
aux_out: &object_store::path::Path,
178178
d0: usize,
179+
item_type: &DataType,
179180
dt: DistanceType,
180181
format_version: LanceFileVersion,
181182
) -> Result<FileWriter> {
@@ -184,7 +185,7 @@ pub async fn init_writer_for_flat(
184185
Field::new(
185186
crate::vector::flat::storage::FLAT_COLUMN,
186187
DataType::FixedSizeList(
187-
Arc::new(Field::new("item", DataType::Float32, true)),
188+
Arc::new(Field::new("item", item_type.clone(), true)),
188189
d0 as i32,
189190
),
190191
true,
@@ -1129,9 +1130,13 @@ pub async fn merge_partial_vector_auxiliary_files(
11291130
.iter()
11301131
.find(|f| f.name() == crate::vector::flat::storage::FLAT_COLUMN)
11311132
.ok_or_else(|| Error::index("FLAT column missing".to_string()))?;
1132-
let d0 = match flat_field.data_type() {
1133-
DataType::FixedSizeList(_, sz) => *sz as usize,
1134-
_ => 0,
1133+
let (d0, item_type) = match flat_field.data_type() {
1134+
DataType::FixedSizeList(item, sz) => (*sz as usize, item.data_type().clone()),
1135+
_ => {
1136+
return Err(Error::index(
1137+
"FLAT column is not a FixedSizeList in shard schema".to_string(),
1138+
));
1139+
}
11351140
};
11361141
dim.get_or_insert(d0);
11371142
if let Some(dprev) = dim
@@ -1140,53 +1145,29 @@ pub async fn merge_partial_vector_auxiliary_files(
11401145
return Err(Error::index("Dimension mismatch across shards".to_string()));
11411146
}
11421147
if v2w_opt.is_none() {
1143-
let w = init_writer_for_flat(object_store, &aux_out, d0, dt, fv).await?;
1148+
let w = init_writer_for_flat(object_store, &aux_out, d0, &item_type, dt, fv)
1149+
.await?;
11441150
v2w_opt = Some(w);
11451151
}
11461152
}
11471153
SupportedIvfIndexType::IvfHnswFlat => {
1148-
// Treat HNSW_FLAT storage the same as FLAT: create schema with ROW_ID + flat vectors
1149-
// Determine dimension from shard schema (flat column) or fallback to STORAGE_METADATA_KEY
1154+
// Treat HNSW_FLAT storage the same as FLAT and preserve the actual flat item dtype.
11501155
let schema_arrow: ArrowSchema = reader.schema().as_ref().into();
1151-
// Try to find flat column and derive dim
1152-
let d0 = if let Some(flat_field) = schema_arrow
1156+
let Some(flat_field) = schema_arrow
11531157
.fields
11541158
.iter()
11551159
.find(|f| f.name() == crate::vector::flat::storage::FLAT_COLUMN)
1156-
{
1157-
match flat_field.data_type() {
1158-
DataType::FixedSizeList(_, sz) => *sz as usize,
1159-
_ => 0,
1160-
}
1161-
} else {
1162-
// Fallback to STORAGE_METADATA_KEY FlatMetadata
1163-
if let Some(storage_meta_json) = reader
1164-
.metadata()
1165-
.file_schema
1166-
.metadata
1167-
.get(STORAGE_METADATA_KEY)
1168-
{
1169-
let storage_metadata_vec: Vec<String> =
1170-
serde_json::from_str(storage_meta_json).map_err(|e| {
1171-
Error::index(format!("Failed to parse storage metadata: {}", e))
1172-
})?;
1173-
if let Some(first_meta) = storage_metadata_vec.first() {
1174-
if let Ok(flat_meta) = serde_json::from_str::<FlatMetadata>(first_meta)
1175-
{
1176-
flat_meta.dim
1177-
} else {
1178-
return Err(Error::index(
1179-
"FLAT metadata missing in storage metadata".to_string(),
1180-
));
1181-
}
1182-
} else {
1183-
return Err(Error::index(
1184-
"FLAT metadata missing in storage metadata".to_string(),
1185-
));
1186-
}
1187-
} else {
1160+
else {
1161+
return Err(Error::index(
1162+
"FLAT column missing from IVF_HNSW_FLAT shard schema".to_string(),
1163+
));
1164+
};
1165+
let (d0, item_type) = match flat_field.data_type() {
1166+
DataType::FixedSizeList(item, sz) => (*sz as usize, item.data_type().clone()),
1167+
_ => {
11881168
return Err(Error::index(
1189-
"FLAT column missing and no storage metadata".to_string(),
1169+
"FLAT column is not a FixedSizeList in IVF_HNSW_FLAT shard schema"
1170+
.to_string(),
11901171
));
11911172
}
11921173
};
@@ -1197,7 +1178,8 @@ pub async fn merge_partial_vector_auxiliary_files(
11971178
return Err(Error::index("Dimension mismatch across shards".to_string()));
11981179
}
11991180
if v2w_opt.is_none() {
1200-
let w = init_writer_for_flat(object_store, &aux_out, d0, dt, fv).await?;
1181+
let w = init_writer_for_flat(object_store, &aux_out, d0, &item_type, dt, fv)
1182+
.await?;
12011183
v2w_opt = Some(w);
12021184
}
12031185
}
@@ -1523,7 +1505,9 @@ pub async fn merge_partial_vector_auxiliary_files(
15231505
mod tests {
15241506
use super::*;
15251507

1526-
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt8Array, UInt64Array};
1508+
use arrow_array::{
1509+
FixedSizeListArray, Float32Array, Float64Array, RecordBatch, UInt8Array, UInt64Array,
1510+
};
15271511
use arrow_schema::Field;
15281512
use bytes::Bytes;
15291513
use futures::StreamExt;
@@ -1614,6 +1598,71 @@ mod tests {
16141598
Ok(total_rows)
16151599
}
16161600

1601+
async fn write_flat_partial_aux_f64(
1602+
store: &ObjectStore,
1603+
aux_path: &Path,
1604+
dim: i32,
1605+
lengths: &[u32],
1606+
base_row_id: u64,
1607+
distance_type: DistanceType,
1608+
) -> Result<usize> {
1609+
let arrow_schema = ArrowSchema::new(vec![
1610+
(*ROW_ID_FIELD).clone(),
1611+
Field::new(
1612+
crate::vector::flat::storage::FLAT_COLUMN,
1613+
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), dim),
1614+
true,
1615+
),
1616+
]);
1617+
1618+
let writer = store.create(aux_path).await?;
1619+
let mut v2w = V2Writer::try_new(
1620+
writer,
1621+
lance_core::datatypes::Schema::try_from(&arrow_schema)?,
1622+
V2WriterOptions::default(),
1623+
)?;
1624+
v2w.add_schema_metadata(DISTANCE_TYPE_KEY, distance_type.to_string());
1625+
1626+
let ivf_meta = pb::Ivf {
1627+
centroids: Vec::new(),
1628+
offsets: Vec::new(),
1629+
lengths: lengths.to_vec(),
1630+
centroids_tensor: None,
1631+
loss: None,
1632+
};
1633+
let buf = Bytes::from(ivf_meta.encode_to_vec());
1634+
let pos = v2w.add_global_buffer(buf).await?;
1635+
v2w.add_schema_metadata(IVF_METADATA_KEY, pos.to_string());
1636+
1637+
let total_rows: usize = lengths.iter().map(|v| *v as usize).sum();
1638+
let mut row_ids = Vec::with_capacity(total_rows);
1639+
let mut values = Vec::with_capacity(total_rows * dim as usize);
1640+
1641+
let mut current_row_id = base_row_id;
1642+
for (pid, len) in lengths.iter().enumerate() {
1643+
for _ in 0..*len {
1644+
row_ids.push(current_row_id);
1645+
current_row_id += 1;
1646+
for d in 0..dim {
1647+
values.push(pid as f64 + d as f64 * 0.01);
1648+
}
1649+
}
1650+
}
1651+
1652+
let row_id_arr = UInt64Array::from(row_ids);
1653+
let value_arr = Float64Array::from(values);
1654+
let fsl = FixedSizeListArray::try_new_from_values(value_arr, dim).unwrap();
1655+
let batch = RecordBatch::try_new(
1656+
Arc::new(arrow_schema),
1657+
vec![Arc::new(row_id_arr), Arc::new(fsl)],
1658+
)
1659+
.unwrap();
1660+
1661+
v2w.write_batch(&batch).await?;
1662+
v2w.finish().await?;
1663+
Ok(total_rows)
1664+
}
1665+
16171666
#[tokio::test]
16181667
async fn test_merge_ivf_flat_success_basic() {
16191668
let object_store = ObjectStore::memory();
@@ -1831,6 +1880,64 @@ mod tests {
18311880
}
18321881
}
18331882

1883+
#[tokio::test]
1884+
async fn test_merge_ivf_flat_preserves_float64_schema() {
1885+
let object_store = ObjectStore::memory();
1886+
let index_dir = Path::from("index/float64_uuid");
1887+
1888+
let partial0 = index_dir.child("partial_0");
1889+
let partial1 = index_dir.child("partial_1");
1890+
let aux0 = partial0.child(INDEX_AUXILIARY_FILE_NAME);
1891+
let aux1 = partial1.child(INDEX_AUXILIARY_FILE_NAME);
1892+
1893+
let lengths = vec![2_u32, 2_u32];
1894+
let dim = 3_i32;
1895+
1896+
write_flat_partial_aux_f64(&object_store, &aux0, dim, &lengths, 0, DistanceType::L2)
1897+
.await
1898+
.unwrap();
1899+
write_flat_partial_aux_f64(&object_store, &aux1, dim, &lengths, 100, DistanceType::L2)
1900+
.await
1901+
.unwrap();
1902+
1903+
merge_partial_vector_auxiliary_files(
1904+
&object_store,
1905+
&[aux0.clone(), aux1.clone()],
1906+
&index_dir,
1907+
Arc::new(RecordingProgress::default()),
1908+
)
1909+
.await
1910+
.unwrap();
1911+
1912+
let aux_out = index_dir.child(INDEX_AUXILIARY_FILE_NAME);
1913+
let sched = ScanScheduler::new(
1914+
Arc::new(object_store.clone()),
1915+
SchedulerConfig::max_bandwidth(&object_store),
1916+
);
1917+
let fh = sched
1918+
.open_file(&aux_out, &CachedFileSize::unknown())
1919+
.await
1920+
.unwrap();
1921+
let reader = V2Reader::try_open(
1922+
fh,
1923+
None,
1924+
Arc::default(),
1925+
&lance_core::cache::LanceCache::no_cache(),
1926+
V2ReaderOptions::default(),
1927+
)
1928+
.await
1929+
.unwrap();
1930+
1931+
let flat_field = reader
1932+
.schema()
1933+
.field(crate::vector::flat::storage::FLAT_COLUMN)
1934+
.unwrap();
1935+
let DataType::FixedSizeList(item, _) = flat_field.data_type() else {
1936+
panic!("flat column should be a fixed size list");
1937+
};
1938+
assert_eq!(item.data_type(), &DataType::Float64);
1939+
}
1940+
18341941
#[allow(clippy::too_many_arguments)]
18351942
async fn write_pq_partial_aux(
18361943
store: &ObjectStore,

0 commit comments

Comments
 (0)