Skip to content

Commit 9b0c274

Browse files
authored
feat: support to create FTS index on list of strings (#3622)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 30c3356 commit 9b0c274

6 files changed

Lines changed: 181 additions & 39 deletions

File tree

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/python/lance/dataset.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,11 +1752,15 @@ def create_scalar_index(
17521752
if not pa.types.is_string(field_type):
17531753
raise TypeError(f"NGRAM index column {column} must be a string")
17541754
elif index_type in ["INVERTED", "FTS"]:
1755-
if not pa.types.is_string(field_type) and not pa.types.is_large_string(
1756-
field_type
1755+
value_type = field_type
1756+
if pa.types.is_list(field_type) or pa.types.is_large_list(field_type):
1757+
value_type = field_type.value_type
1758+
if not pa.types.is_string(value_type) and not pa.types.is_large_string(
1759+
value_type
17571760
):
17581761
raise TypeError(
1759-
f"INVERTED index column {column} must be string or large string"
1762+
f"INVERTED index column {column} must be string, large string"
1763+
f" or list of strings, but got {value_type}"
17601764
)
17611765

17621766
if pa.types.is_duration(field_type):

python/python/tests/test_scalar_index.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,28 @@ def test_indexed_filter_with_fts_index(tmp_path):
371371
assert results["_rowid"].to_pylist() == [2, 3]
372372

373373

374+
def test_fts_on_list(tmp_path):
    """Check that an INVERTED (FTS) index can be built on a list-of-strings column.

    Builds a five-row table whose ``text`` column holds lists of strings,
    indexes it with positions enabled, then runs a plain term query and a
    phrase query against it.
    """
    documents = [
        ["lance database", "the", "search"],
        ["lance database"],
        ["lance", "search"],
        ["database", "search"],
        ["unrelated", "doc"],
    ]
    table = pa.table({"text": documents})
    dataset = lance.write_dataset(table, tmp_path)
    dataset.create_scalar_index("text", "INVERTED", with_position=True)

    # Term query: the first three rows each have a list element containing "lance".
    term_hits = dataset.to_table(full_text_query="lance")
    assert term_hits.num_rows == 3

    # Phrase query: only the first two rows have an element with "lance database".
    phrase = PhraseQuery("lance database", "text")
    phrase_hits = dataset.to_table(full_text_query=phrase)
    assert phrase_hits.num_rows == 2
394+
395+
374396
def test_fts_fuzzy_query(tmp_path):
375397
data = pa.table(
376398
{

rust/lance-index/src/scalar/inverted/builder.rs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ use crate::scalar::{IndexReader, IndexStore, IndexWriter, InvertedIndexParams};
1212
use crate::vector::graph::OrderedFloat;
1313
use arrow::array::{ArrayBuilder, AsArray, Int32Builder, StringBuilder};
1414
use arrow::datatypes;
15-
use arrow_array::{Int32Array, RecordBatch, StringArray};
16-
use arrow_schema::SchemaRef;
15+
use arrow_array::{Array, Int32Array, RecordBatch, StringArray, UInt64Array};
16+
use arrow_schema::{Field, Schema, SchemaRef};
1717
use crossbeam_queue::ArrayQueue;
1818
use datafusion::execution::SendableRecordBatchStream;
1919
use deepsize::DeepSizeOf;
@@ -22,10 +22,11 @@ use itertools::Itertools;
2222
use lance_arrow::iter_str_array;
2323
use lance_core::cache::FileMetadataCache;
2424
use lance_core::utils::tokio::{get_num_compute_intensive_cpus, CPU_RUNTIME};
25-
use lance_core::{Result, ROW_ID};
25+
use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
2626
use lance_io::object_store::ObjectStore;
2727
use lazy_static::lazy_static;
2828
use object_store::path::Path;
29+
use snafu::location;
2930
use tempfile::{tempdir, TempDir};
3031
use tracing::instrument;
3132

@@ -108,6 +109,23 @@ impl InvertedIndexBuilder {
108109

109110
#[instrument(level = "debug", skip_all)]
110111
async fn update_index(&mut self, stream: SendableRecordBatchStream) -> Result<()> {
112+
let flatten_stream = stream.map(|batch| {
113+
let batch = batch?;
114+
let doc_col = batch.column(0);
115+
match doc_col.data_type() {
116+
datatypes::DataType::Utf8 | datatypes::DataType::LargeUtf8 => Ok(batch),
117+
datatypes::DataType::List(_) => {
118+
flatten_string_list::<i32>(&batch, doc_col)
119+
}
120+
datatypes::DataType::LargeList(_) => {
121+
flatten_string_list::<i64>(&batch, doc_col)
122+
}
123+
_ => {
124+
Err(Error::Index { message: format!("expect data type String, LargeString or List of String/LargeString, but got {}", doc_col.data_type()), location: location!() })
125+
}
126+
}
127+
});
128+
111129
let num_shards = *LANCE_FTS_NUM_SHARDS;
112130

113131
// init the token maps
@@ -159,13 +177,15 @@ impl InvertedIndexBuilder {
159177
for _ in 0..num_shards {
160178
let _ = tokenizer_pool.push(tokenizer.clone());
161179
}
162-
let mut stream = stream
180+
let mut stream = flatten_stream
163181
.map(move |batch| {
164182
let senders = senders.clone();
165183
let tokenizer_pool = tokenizer_pool.clone();
166184
CPU_RUNTIME.spawn_blocking(move || {
167185
let batch = batch?;
168-
let doc_iter = iter_str_array(batch.column(0));
186+
187+
let doc_col = batch.column(0);
188+
let doc_iter = iter_str_array(doc_col);
169189
let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
170190
let docs = doc_iter
171191
.zip(row_id_col.values().iter())
@@ -721,3 +741,42 @@ pub fn inverted_list_schema(with_position: bool) -> SchemaRef {
721741
}
722742
Arc::new(arrow_schema::Schema::new(fields))
723743
}
744+
745+
fn flatten_string_list<Offset: arrow::array::OffsetSizeTrait>(
746+
batch: &RecordBatch,
747+
doc_col: &Arc<dyn Array>,
748+
) -> Result<RecordBatch> {
749+
let docs = doc_col.as_list::<Offset>();
750+
let row_ids = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
751+
752+
let row_ids = row_ids
753+
.values()
754+
.iter()
755+
.zip(docs.iter())
756+
.flat_map(|(row_id, doc)| std::iter::repeat_n(*row_id, doc.map(|d| d.len()).unwrap_or(0)));
757+
758+
let row_ids = Arc::new(UInt64Array::from_iter_values(row_ids));
759+
let docs = match docs.value_type() {
760+
datatypes::DataType::Utf8 | datatypes::DataType::LargeUtf8 => docs.values().clone(),
761+
_ => {
762+
return Err(Error::Index {
763+
message: format!(
764+
"expect data type String or LargeString but got {}",
765+
docs.value_type()
766+
),
767+
location: location!(),
768+
});
769+
}
770+
};
771+
772+
let schema = Schema::new(vec![
773+
Field::new(
774+
batch.schema().field(0).name(),
775+
docs.data_type().clone(),
776+
true,
777+
),
778+
ROW_ID_FIELD.clone(),
779+
]);
780+
let batch = RecordBatch::try_new(Arc::new(schema), vec![docs, row_ids])?;
781+
Ok(batch)
782+
}

rust/lance/src/dataset.rs

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,7 +1742,7 @@ mod tests {
17421742
use crate::index::vector::VectorIndexParams;
17431743
use crate::utils::test::TestDatasetGenerator;
17441744

1745-
use arrow::array::{as_struct_array, AsArray};
1745+
use arrow::array::{as_struct_array, AsArray, GenericListBuilder, GenericStringBuilder};
17461746
use arrow::compute::concat_batches;
17471747
use arrow::datatypes::UInt64Type;
17481748
use arrow_array::{
@@ -5045,7 +5045,11 @@ mod tests {
50455045
assert_eq!(row_ids, &[0]);
50465046
}
50475047

5048-
async fn create_fts_dataset<Offset: arrow::array::OffsetSizeTrait>(
5048+
async fn create_fts_dataset<
5049+
Offset: arrow::array::OffsetSizeTrait,
5050+
ListOffset: arrow::array::OffsetSizeTrait,
5051+
>(
5052+
is_list: bool,
50495053
with_position: bool,
50505054
tokenizer: TokenizerConfig,
50515055
) -> Dataset {
@@ -5055,19 +5059,46 @@ mod tests {
50555059

50565060
let mut params = InvertedIndexParams::default().with_position(with_position);
50575061
params.tokenizer_config = tokenizer;
5058-
let doc_col = GenericStringArray::<Offset>::from(vec![
5059-
"lance database the search",
5060-
"lance database",
5061-
"lance search",
5062-
"database search",
5063-
"unrelated doc",
5064-
"unrelated",
5065-
"mots accentués",
5066-
]);
5062+
let doc_col: Arc<dyn Array> = if is_list {
5063+
let string_builder = GenericStringBuilder::<Offset>::new();
5064+
let mut list_col = GenericListBuilder::<ListOffset, _>::new(string_builder);
5065+
// Create a list of strings
5066+
list_col.values().append_value("lance database"); // for testing phrase query
5067+
list_col.values().append_value("the");
5068+
list_col.values().append_value("search");
5069+
list_col.append(true);
5070+
list_col.values().append_value("lance database"); // for testing phrase query
5071+
list_col.append(true);
5072+
list_col.values().append_value("lance");
5073+
list_col.values().append_value("search");
5074+
list_col.append(true);
5075+
list_col.values().append_value("database");
5076+
list_col.values().append_value("search");
5077+
list_col.append(true);
5078+
list_col.values().append_value("unrelated doc");
5079+
list_col.append(true);
5080+
list_col.values().append_value("unrelated");
5081+
list_col.append(true);
5082+
list_col.values().append_value("mots");
5083+
list_col.values().append_value("accentués");
5084+
list_col.append(true);
5085+
list_col.append(false);
5086+
Arc::new(list_col.finish())
5087+
} else {
5088+
Arc::new(GenericStringArray::<Offset>::from(vec![
5089+
"lance database the search",
5090+
"lance database",
5091+
"lance search",
5092+
"database search",
5093+
"unrelated doc",
5094+
"unrelated",
5095+
"mots accentués",
5096+
]))
5097+
};
50675098
let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64);
50685099
let batch = RecordBatch::try_new(
50695100
arrow_schema::Schema::new(vec![
5070-
arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), false),
5101+
arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true),
50715102
arrow_schema::Field::new("id", DataType::UInt64, false),
50725103
])
50735104
.into(),
@@ -5086,8 +5117,15 @@ mod tests {
50865117
dataset
50875118
}
50885119

5089-
async fn test_fts_index<Offset: arrow::array::OffsetSizeTrait>() {
5090-
let ds = create_fts_dataset::<Offset>(false, TokenizerConfig::default()).await;
5120+
async fn test_fts_index<
5121+
Offset: arrow::array::OffsetSizeTrait,
5122+
ListOffset: arrow::array::OffsetSizeTrait,
5123+
>(
5124+
is_list: bool,
5125+
) {
5126+
let ds =
5127+
create_fts_dataset::<Offset, ListOffset>(is_list, false, TokenizerConfig::default())
5128+
.await;
50915129
let result = ds
50925130
.scan()
50935131
.project(&["id"])
@@ -5152,7 +5190,9 @@ mod tests {
51525190
assert!(err.contains("position is not found but required for phrase queries, try recreating the index with position"),"{}",err);
51535191

51545192
// recreate the index with position
5155-
let ds = create_fts_dataset::<Offset>(true, TokenizerConfig::default()).await;
5193+
let ds =
5194+
create_fts_dataset::<Offset, ListOffset>(is_list, true, TokenizerConfig::default())
5195+
.await;
51565196
let result = ds
51575197
.scan()
51585198
.project(&["id"])
@@ -5235,17 +5275,21 @@ mod tests {
52355275

52365276
#[tokio::test]
52375277
async fn test_fts_index_with_string() {
5238-
test_fts_index::<i32>().await;
5278+
test_fts_index::<i32, i32>(false).await;
5279+
test_fts_index::<i32, i32>(true).await;
5280+
test_fts_index::<i32, i64>(true).await;
52395281
}
52405282

52415283
#[tokio::test]
52425284
async fn test_fts_index_with_large_string() {
5243-
test_fts_index::<i64>().await;
5285+
test_fts_index::<i64, i32>(false).await;
5286+
test_fts_index::<i64, i32>(true).await;
5287+
test_fts_index::<i64, i64>(true).await;
52445288
}
52455289

52465290
#[tokio::test]
52475291
async fn test_fts_accented_chars() {
5248-
let ds = create_fts_dataset::<i32>(false, TokenizerConfig::default()).await;
5292+
let ds = create_fts_dataset::<i32, i32>(false, false, TokenizerConfig::default()).await;
52495293
let result = ds
52505294
.scan()
52515295
.project(&["id"])
@@ -5269,8 +5313,12 @@ mod tests {
52695313
assert_eq!(result.num_rows(), 0);
52705314

52715315
// with ascii folding enabled, the search should be accent-insensitive
5272-
let ds =
5273-
create_fts_dataset::<i32>(false, TokenizerConfig::default().ascii_folding(true)).await;
5316+
let ds = create_fts_dataset::<i32, i32>(
5317+
false,
5318+
false,
5319+
TokenizerConfig::default().ascii_folding(true),
5320+
)
5321+
.await;
52745322
let result = ds
52755323
.scan()
52765324
.project(&["id"])

rust/lance/src/dataset/scanner.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,13 +1577,22 @@ impl Scanner {
15771577
let query = if columns.is_empty() {
15781578
// the field is not specified,
15791579
// try to search over all indexed fields
1580-
let string_columns = self.dataset.schema().fields.iter().filter_map(|f| {
1581-
if f.data_type() == DataType::Utf8 || f.data_type() == DataType::LargeUtf8 {
1582-
Some(&f.name)
1583-
} else {
1584-
None
1585-
}
1586-
});
1580+
let string_columns =
1581+
self.dataset
1582+
.schema()
1583+
.fields
1584+
.iter()
1585+
.filter_map(|f| match f.data_type() {
1586+
DataType::Utf8 | DataType::LargeUtf8 => Some(&f.name),
1587+
DataType::List(field) | DataType::LargeList(field) => {
1588+
if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
1589+
Some(&f.name)
1590+
} else {
1591+
None
1592+
}
1593+
}
1594+
_ => None,
1595+
});
15871596

15881597
let mut indexed_columns = Vec::new();
15891598
for column in string_columns {

0 commit comments

Comments
 (0)