diff --git a/docs/src/operations/dql/vector-search.md b/docs/src/operations/dql/vector-search.md
new file mode 100644
index 00000000..f4a306cc
--- /dev/null
+++ b/docs/src/operations/dql/vector-search.md
@@ -0,0 +1,185 @@
+# Vector Search (`lance_vector_search`)
+
+Executes an approximate nearest-neighbour (ANN) search over a Lance vector column from Spark SQL.
+Implemented as a **table-valued function**, so it composes cleanly with `WHERE`, `JOIN`,
+`GROUP BY`, and projections — no grammar extension required.
+
+!!! warning "Spark Extension Required"
+    This feature requires the Lance Spark SQL extension to be enabled. See
+    [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details.
+
+!!! tip "Also see"
+    - [CREATE INDEX](../ddl/create-index.md) — build the vector (`ivf_*`) index the search uses.
+    - [Select](select.md) — general read path.
+
+## Syntax
+
+The function takes four required positional arguments plus five optional ones:
+
+```
+lance_vector_search(
+  table,           -- STRING        required  catalog-qualified name OR filesystem URI
+  column,          -- STRING        required  name of the vector column
+  query,           -- ARRAY<FLOAT>  required  query vector, dimension must match column
+  k,               -- INT           required  number of neighbours (> 0)
+  [metric],        -- STRING        optional  l2 | cosine | dot | hamming (default: the index's metric)
+  [nprobes],       -- INT           optional  IVF probe count, default 20
+  [refine_factor], -- INT           optional  PQ re-rank factor, default 1
+  [ef],            -- INT           optional  HNSW search depth
+  [use_index]      -- BOOLEAN       optional  default true; false = brute force
+)
+```
+
+Spark 3.5+ also accepts **named** arguments (`query => array(...)`, `k => 10`, …).
+Spark 3.4 only accepts positional arguments.
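The named-argument form is easy to generate programmatically from PySpark. As an illustrative sketch (this helper is hypothetical and not part of the extension), a few lines of Python can render a Spark 3.5+ call string:

```python
def vector_search_sql(table, column, query, k, projection="*", **opts):
    """Render a Spark 3.5+ named-argument lance_vector_search call.

    `opts` may carry the optional knobs (metric, nprobes, refine_factor,
    ef, use_index). Hypothetical helper for illustration only.
    """
    def lit(v):
        # Render a Python value as a SQL literal.
        if isinstance(v, bool):
            return "true" if v else "false"
        if isinstance(v, str):
            return "'" + v.replace("'", "''") + "'"
        return str(v)

    args = {
        "table": table,
        "column": column,
        "query": "array({})".format(", ".join(str(float(x)) for x in query)),
        "k": k,
    }
    args.update(opts)  # optional knobs keep their given names
    rendered = ",\n  ".join(
        "{} => {}".format(name, v if name == "query" else lit(v))
        for name, v in args.items()
    )
    return "SELECT {} FROM lance_vector_search(\n  {}\n)".format(projection, rendered)

sql = vector_search_sql("lance.db.items", "embedding", [0.1, 0.2], 10, metric="cosine")
# pass `sql` to spark.sql(...) on a session with the Lance extension enabled
```

On Spark 3.4, fall back to passing all arguments positionally, as the rendered named form will not parse.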
+
+## Basic Usage
+
+=== "SQL"
+    ```sql
+    SELECT id, category
+    FROM lance_vector_search(
+      'lance.db.items',
+      'embedding',
+      array(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8),
+      10
+    );
+    ```
+
+=== "PySpark"
+    ```python
+    from pyspark.sql import functions as F
+
+    q = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
+    spark.sql(f"""
+        SELECT id, category
+        FROM lance_vector_search(
+            'lance.db.items',
+            'embedding',
+            array({', '.join(str(x) for x in q)}),
+            10
+        )
+    """).show()
+    ```
+
+## Arguments
+
+### `table`
+A catalog-qualified name (e.g. `lance.db.items`) **or** a filesystem URI
+(e.g. `s3://bucket/path/to/items.lance`). Catalog-qualified names are resolved through the
+currently configured Spark catalog; URIs are passed straight through to the Lance DataSource.
+
+### `column`
+The name of the vector column to search. Must be a vector column (see
+[CREATE TABLE → Vector Columns](../ddl/create-table.md)).
+
+### `query`
+The query vector, as a Spark `ARRAY<FLOAT>` or `ARRAY<DOUBLE>` literal / foldable expression. The
+length must match the column's `arrow.fixed-size-list.size` metadata, otherwise Lance raises an
+error at scan time. Double-precision arrays are automatically down-cast to float32.
+
+### `k`
+Number of neighbours to return. Must be positive. Lance may return fewer rows if the table is
+smaller than `k` (or the pre-filter eliminates enough rows).
+
+### `metric`
+The distance metric to use. See the metric table in
+[CREATE INDEX → Distance Metrics](../ddl/create-index.md#distance-metrics).
+If omitted, the metric stored inside the index is used.
+
+### `nprobes`
+Number of IVF partitions to probe. Higher values improve recall at the cost of latency. Default
+`20`. Only relevant for IVF-family indexes.
+
+### `refine_factor`
+PQ re-rank factor. The scan retrieves `k × refine_factor` PQ-approximate candidates, then re-ranks
+them against the exact stored vectors. `1` (default) disables re-ranking.
+
+### `ef`
+HNSW candidate-list size at search time.
Higher values improve recall. Relevant for +`ivf_hnsw_*` indexes. + +### `use_index` +`true` (default) uses the ANN index; `false` forces a brute-force scan of every fragment. Useful +for recall evaluation or when no index exists yet. + +## Composing with the Rest of SQL + +### Projection + +Project any subset of the source columns after the TVF: + +```sql +SELECT id FROM lance_vector_search('lance.db.items', 'embedding', array(...), 10); +``` + +### Pre-filters + +Filters on scalar columns that sit directly above the TVF are pushed into Lance and applied +**before** the kNN search — meaning `k` applies to the filtered subset, not the whole table. + +```sql +SELECT id, category +FROM lance_vector_search('lance.db.items', 'embedding', array(...), 10) +WHERE category = 'books' AND price < 50.0; +``` + +### Joins / group-by + +The TVF result is a regular Dataset, so all downstream operators work unchanged: + +```sql +SELECT s.id, s.category, i.name +FROM lance_vector_search('lance.db.items', 'embedding', array(...), 50) s +JOIN lance.db.inventory i ON i.id = s.id +WHERE i.in_stock; +``` + +## Brute Force vs. Indexed Search + +When you want ground truth — for recall evaluation or for tables that are too small to justify an +index — pass `use_index => false`: + +```sql +SELECT id +FROM lance_vector_search('lance.db.items', 'embedding', array(...), 10, 'l2', 20, 1, 64, false); +``` + +Brute force scans every row in every fragment. It returns exact top-k per fragment; Spark unions +the per-fragment results. + +## Tuning Recall vs. 
Latency + +| Knob | Effect on recall | Effect on latency | +|-------------------|------------------|-------------------| +| `nprobes` ↑ | ↑ | ↑ | +| `ef` ↑ | ↑ | ↑ | +| `refine_factor` ↑ | ↑ | ↑ | +| `num_partitions` ↑ at index time | neutral | ↓ (each probe is smaller) | +| `m` / `ef_construction` ↑ at index time | ↑ | neutral (one-time cost) | + +A common starting recipe for IVF-PQ on a few million rows: +`num_partitions = 256`, `num_sub_vectors = 16`, `nprobes = 20`, `refine_factor = 10`. + +## Errors + +| Condition | Result | +|--------------------------------------------------------|---------------------------------------------------------------| +| `k <= 0` | `IllegalArgumentException("… 'k' must be positive")` | +| Unknown metric (`'manhattan'`, etc.) | `IllegalArgumentException("… unsupported metric …")` | +| Non-constant `query` / `k` / `column` | `IllegalArgumentException("… must be a constant expression")` | +| `column` not a vector column | Raised by Lance at scan time (dimension mismatch). | +| `table` not found | `IllegalArgumentException("… could not resolve table …")` | + +## Notes and Limitations + +- **Fragment-local top-k**: the scan today performs search per fragment and unions the results, so + the raw TVF output may contain up to `k × num_fragments` rows. Add a global + `ORDER BY … LIMIT k` on top if you need the true global top-k. +- **Single column**: the `column` argument is a single string — you cannot combine two vector + columns in one call. +- **Query vector is a driver-side literal**: Spark evaluates the `query` expression on the driver + when planning the scan. Non-foldable expressions (e.g. a column reference) are rejected. +- **Named arguments**: require Spark 3.5+. On Spark 3.4 pass all arguments positionally. +- **`_distance` column**: not yet exposed in the TVF's output schema — see the roadmap issue for + progress. 
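The brute-force mode doubles as a ground-truth oracle for tuning: run the same query once with `use_index => false` to get the exact top-k, once with the index, and compare the id sets. A self-contained sketch of that recall computation in plain Python (illustrative names only, no Spark dependency):

```python
def brute_force_topk(rows, query, k):
    """Exact L2 top-k over (id, vector) pairs -- what use_index=false computes."""
    scored = sorted(rows, key=lambda r: sum((a - b) ** 2 for a, b in zip(r[1], query)))
    return [rid for rid, _ in scored[:k]]

def recall_at_k(exact_ids, ann_ids, k):
    """Fraction of the exact top-k that the ANN result recovered."""
    return len(set(exact_ids[:k]) & set(ann_ids[:k])) / k

rows = [(i, [float(i), 2.0 * i]) for i in range(100)]
query = [10.2, 20.4]
exact = brute_force_topk(rows, query, 10)  # ground truth (use_index => false)
ann = exact[:9] + [99]                     # hypothetical ANN result missing one hit
print(recall_at_k(exact, ann, 10))         # 0.9
```

Sweep `nprobes` (and `ef` / `refine_factor`) until `recall_at_k` against the brute-force baseline reaches your target, then stop: every further increment only adds latency.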
diff --git a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..835f5fdd 100644 --- a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -17,6 +17,7 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.read.LanceVectorSearchTableFunction class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -28,5 +29,12 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) + + // lance_vector_search(table, column, query, k, ...) table-valued function + extensions.injectTableFunction( + ( + LanceVectorSearchTableFunction.IDENTIFIER, + LanceVectorSearchTableFunction.INFO, + LanceVectorSearchTableFunction.BUILDER)) } } diff --git a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java new file mode 100644 index 00000000..6a0c09ed --- /dev/null +++ b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java @@ -0,0 +1,16 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read; + +public class LanceVectorSearchTest extends BaseLanceVectorSearchTest {} diff --git a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..835f5fdd 100644 --- a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -17,6 +17,7 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.read.LanceVectorSearchTableFunction class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -28,5 +29,12 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) + + // lance_vector_search(table, column, query, k, ...) 
table-valued function + extensions.injectTableFunction( + ( + LanceVectorSearchTableFunction.IDENTIFIER, + LanceVectorSearchTableFunction.INFO, + LanceVectorSearchTableFunction.BUILDER)) } } diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java new file mode 100644 index 00000000..6a0c09ed --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/LanceVectorSearchTest.java @@ -0,0 +1,16 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.read; + +public class LanceVectorSearchTest extends BaseLanceVectorSearchTest {} diff --git a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..835f5fdd 100644 --- a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -17,6 +17,7 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.read.LanceVectorSearchTableFunction class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -28,5 +29,12 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) + + // lance_vector_search(table, column, query, k, ...) 
table-valued function + extensions.injectTableFunction( + ( + LanceVectorSearchTableFunction.IDENTIFIER, + LanceVectorSearchTableFunction.INFO, + LanceVectorSearchTableFunction.BUILDER)) } } diff --git a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d..835f5fdd 100644 --- a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -17,6 +17,7 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.read.LanceVectorSearchTableFunction class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -28,5 +29,12 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) + + // lance_vector_search(table, column, query, k, ...) 
table-valued function + extensions.injectTableFunction( + ( + LanceVectorSearchTableFunction.IDENTIFIER, + LanceVectorSearchTableFunction.INFO, + LanceVectorSearchTableFunction.BUILDER)) } } diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/DistanceTypes.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/DistanceTypes.scala new file mode 100644 index 00000000..09009407 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/DistanceTypes.scala @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read + +import org.lance.index.DistanceType + +object DistanceTypes { + + val Supported: Seq[String] = Seq("l2", "cosine", "dot", "hamming") + + def parse(metric: String, errPrefix: String): DistanceType = + metric.trim.toLowerCase match { + case "l2" | "euclidean" => DistanceType.L2 + case "cosine" => DistanceType.Cosine + case "dot" | "inner_product" | "ip" => DistanceType.Dot + case "hamming" => DistanceType.Hamming + case other => + throw new IllegalArgumentException( + s"$errPrefix: unsupported metric '$other'. 
" + + s"Expected one of: ${Supported.mkString(", ")}.") + } +} diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/LanceVectorSearchTableFunction.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/LanceVectorSearchTableFunction.scala new file mode 100644 index 00000000..779ba7a3 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/read/LanceVectorSearchTableFunction.scala @@ -0,0 +1,350 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, LongType} +import org.apache.spark.unsafe.types.UTF8String +import org.lance.ipc.Query +import org.lance.spark.{LanceDataset, LanceDataSource, LanceSparkReadOptions} +import org.lance.spark.utils.QueryUtils + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +/** + * `lance_vector_search` — table-valued function exposing Lance ANN / kNN pushdown to Spark SQL. 
+ * + * Usage (positional; first four required): + * {{{ + * SELECT id, category + * FROM lance_vector_search( + * 'lance.db.items', -- table ref (catalog-qualified id OR filesystem URI) + * 'embedding', -- vector column + * array(0.1f, 0.2f, ...), -- query vector (float / double array) + * 10 -- k + * [, 'cosine' -- metric: l2 | cosine | dot | hamming + * [, 20 -- nprobes (IVF) + * [, 1 -- refine_factor (PQ) + * [, 64 -- ef (HNSW) + * [, true -- use_index (false = brute force) + * ]]]]] + * ) + * WHERE category = 'books' + * ORDER BY _distance; -- only if column was projected through by the scan + * }}} + * + * Named arguments (Spark 3.5+) are also accepted and recognised by parameter name: + * `table`, `column`, `query`, `k`, `metric`, `nprobes`, `refine_factor`, `ef`, `use_index`. + * + * The resulting LogicalPlan is a standard Lance DataSource read with the + * [[LanceSparkReadOptions.CONFIG_NEAREST]] option populated — so all existing + * pushdown, filter, and projection paths apply unchanged. + */ +object LanceVectorSearchTableFunction { + + val NAME = "lance_vector_search" + + val IDENTIFIER: FunctionIdentifier = FunctionIdentifier(NAME) + + val INFO: ExpressionInfo = new ExpressionInfo( + "org.lance.spark.read.LanceVectorSearchTableFunction", + null, + NAME, + "_FUNC_(table, column, query, k" + + "[, metric[, nprobes[, refine_factor[, ef[, use_index]]]]]) - " + + "Approximate nearest-neighbour search over a Lance vector column.", + "", + """ + | Examples: + | > SELECT id, _distance FROM _FUNC_('lance.db.items', 'embedding', + | array(0.1f, 0.2f, 0.3f), 10, 'cosine'); + """.stripMargin, + "", + "table_funcs", + "", + "", + "built-in") + + val BUILDER: Seq[Expression] => LogicalPlan = (args: Seq[Expression]) => buildPlan(args) + + /** + * Core builder. Separated from [[BUILDER]] to make unit testing straightforward. 
+ */ + def buildPlan(args: Seq[Expression]): LogicalPlan = { + val parsed = parseArgs(args) + + val query = buildQuery(parsed) + val queryJson = QueryUtils.queryToString(query) + + val spark = SparkSession.active + val (datasetUri, storageOptions) = resolveDatasetLocation(spark, parsed.table) + + val reader = spark.read.format(LanceDataSource.name) + // Apply catalog-derived storage options first so the caller-facing `nearest` + // cannot be clobbered by per-table config. + storageOptions.foreach { case (k, v) => reader.option(k, v) } + reader.option(LanceSparkReadOptions.CONFIG_NEAREST, queryJson) + + reader.load(datasetUri).queryExecution.analyzed + } + + // ─── Argument parsing ────────────────────────────────────────────────────── + + private[read] case class ParsedArgs( + table: String, + column: String, + query: Array[Float], + k: Int, + metric: Option[String], + nprobes: Option[Int], + refineFactor: Option[Int], + ef: Option[Int], + useIndex: Option[Boolean]) + + /** Test-only hook so unit tests can exercise argument parsing without a SparkSession. 
*/ + private[read] def parseArgsForTest(args: Seq[Expression]): ParsedArgs = parseArgs(args) + + private def parseArgs(args: Seq[Expression]): ParsedArgs = { + val byName = scala.collection.mutable.Map.empty[String, Expression] + val positional = scala.collection.mutable.ArrayBuffer.empty[Expression] + args.foreach { expr => + NamedArgExtractor.unapply(expr) match { + case Some((name, value)) => byName(name.toLowerCase) = value + case None => positional += expr + } + } + + def atName(name: String): Option[Expression] = byName.get(name) + def atPos(idx: Int): Option[Expression] = + if (idx < positional.length) Some(positional(idx)) else None + def pick(name: String, idx: Int): Option[Expression] = atName(name).orElse(atPos(idx)) + + val tableExpr = pick("table", 0) + .getOrElse(throw missing("table")) + val columnExpr = pick("column", 1) + .getOrElse(throw missing("column")) + val queryExpr = pick("query", 2) + .getOrElse(throw missing("query")) + val kExpr = pick("k", 3) + .getOrElse(throw missing("k")) + + val table = evalString(tableExpr, "table") + val column = evalString(columnExpr, "column") + val queryVec = evalFloatArray(queryExpr, "query") + val k = evalInt(kExpr, "k") + require(k > 0, s"lance_vector_search: 'k' must be positive, got $k") + + val metric = pick("metric", 4).map(evalString(_, "metric")) + val nprobes = pick("nprobes", 5).map(evalInt(_, "nprobes")) + val refineFactor = pick("refine_factor", 6).map(evalInt(_, "refine_factor")) + val ef = pick("ef", 7).map(evalInt(_, "ef")) + val useIndex = pick("use_index", 8).map(evalBoolean(_, "use_index")) + + ParsedArgs(table, column, queryVec, k, metric, nprobes, refineFactor, ef, useIndex) + } + + private def missing(name: String): IllegalArgumentException = + new IllegalArgumentException(s"lance_vector_search: missing required argument '$name'") + + private def evalLiteral(expr: Expression, argName: String): Any = { + val folded = expr match { + case l: Literal => l + case other if other.foldable => 
Literal(other.eval(), other.dataType) + case other => + throw new IllegalArgumentException( + s"lance_vector_search: argument '$argName' must be a constant expression, got $other") + } + folded.value + } + + private def evalString(expr: Expression, argName: String): String = + evalLiteral(expr, argName) match { + case null => + throw new IllegalArgumentException(s"lance_vector_search: '$argName' cannot be null") + case s: UTF8String => s.toString + case s: String => s + case other => + throw new IllegalArgumentException( + s"lance_vector_search: '$argName' must be a STRING literal, got ${other.getClass.getName}") + } + + private def evalInt(expr: Expression, argName: String): Int = evalLiteral(expr, argName) match { + case null => + throw new IllegalArgumentException(s"lance_vector_search: '$argName' cannot be null") + case i: java.lang.Integer => i.intValue() + case i: Int => i + case l: java.lang.Long => l.intValue() + case other => + throw new IllegalArgumentException( + s"lance_vector_search: '$argName' must be an integer, got ${other.getClass.getName}") + } + + private def evalBoolean(expr: Expression, argName: String): Boolean = + evalLiteral(expr, argName) match { + case null => + throw new IllegalArgumentException(s"lance_vector_search: '$argName' cannot be null") + case b: java.lang.Boolean => b.booleanValue() + case b: Boolean => b + case other => + throw new IllegalArgumentException( + s"lance_vector_search: '$argName' must be a BOOLEAN, got ${other.getClass.getName}") + } + + private def evalFloatArray(expr: Expression, argName: String): Array[Float] = { + val value = evalLiteral(expr, argName) + value match { + case null => + throw new IllegalArgumentException(s"lance_vector_search: '$argName' cannot be null") + case arr: ArrayData => + val dt = expr.dataType match { + case org.apache.spark.sql.types.ArrayType(elementType, _) => elementType + case other => + throw new IllegalArgumentException( + s"lance_vector_search: '$argName' must be an ARRAY of 
numeric values, got $other")
+        }
+        def checkNotNull(i: Int): Unit =
+          if (arr.isNullAt(i)) {
+            throw new IllegalArgumentException(
+              s"lance_vector_search: '$argName' must not contain null elements (index $i)")
+          }
+        dt match {
+          case FloatType =>
+            val out = new Array[Float](arr.numElements())
+            var i = 0
+            while (i < out.length) {
+              checkNotNull(i)
+              out(i) = arr.getFloat(i)
+              i += 1
+            }
+            out
+          case DoubleType =>
+            val out = new Array[Float](arr.numElements())
+            var i = 0
+            while (i < out.length) {
+              checkNotNull(i)
+              out(i) = arr.getDouble(i).toFloat
+              i += 1
+            }
+            out
+          case IntegerType | LongType =>
+            val out = new Array[Float](arr.numElements())
+            var i = 0
+            while (i < out.length) {
+              checkNotNull(i)
+              out(i) = if (dt == IntegerType) arr.getInt(i).toFloat else arr.getLong(i).toFloat
+              i += 1
+            }
+            out
+          case other =>
+            throw new IllegalArgumentException(
+              s"lance_vector_search: '$argName' must be an ARRAY of FLOAT, DOUBLE, INT or BIGINT elements, " +
+                s"got ARRAY<$other>")
+        }
+      case other =>
+        throw new IllegalArgumentException(
+          s"lance_vector_search: '$argName' must be an ARRAY literal, got ${other.getClass.getName}")
+    }
+  }
+
+  // ─── Query assembly ───────────────────────────────────────────────────────
+
+  private def buildQuery(p: ParsedArgs): Query = {
+    val b = new Query.Builder()
+      .setColumn(p.column)
+      .setKey(p.query)
+      .setK(p.k)
+      .setUseIndex(p.useIndex.getOrElse(true))
+    p.metric.foreach(m => b.setDistanceType(DistanceTypes.parse(m, "lance_vector_search")))
+    p.nprobes.foreach(b.setMinimumNprobes)
+    p.refineFactor.foreach(b.setRefineFactor)
+    p.ef.foreach(b.setEf)
+    b.build()
+  }
+
+  // ─── Table resolution ─────────────────────────────────────────────────────
+
+  /**
+   * Resolves a user-supplied table reference to a Lance dataset URI plus any storage options
+   * inherited from the catalog. Accepts either a catalog-qualified name (e.g. `lance.db.t`) or a
+   * plain filesystem URI. Catalog lookup uses [[SparkSession.table]] and walks the analysed plan
+   * for a [[LanceDataset]].
+ */ + private def resolveDatasetLocation( + spark: SparkSession, + tableRef: String): (String, Map[String, String]) = { + val trimmed = tableRef.trim + // Heuristic: only treat as a filesystem URI if it carries a scheme or an absolute/relative + // path prefix. A bare `.lance` suffix is *not* enough — a catalog identifier may legitimately + // end in `.lance` (e.g. `cat.db.my.lance`). + val looksLikeUri = trimmed.contains("://") || trimmed.startsWith("/") || + trimmed.startsWith("./") || trimmed.startsWith("../") + if (looksLikeUri) { + return (trimmed, Map.empty) + } + val plan = + try { + spark.table(trimmed).queryExecution.analyzed + } catch { + case NonFatal(e) => + throw new IllegalArgumentException( + s"lance_vector_search: could not resolve table '$tableRef' " + + "(treat it as a catalog identifier or a Lance URI).", + e) + } + val lanceTable = plan.collectFirst { + case rel: DataSourceV2Relation if rel.table.isInstanceOf[LanceDataset] => + rel.table.asInstanceOf[LanceDataset] + }.getOrElse(throw new IllegalArgumentException( + s"lance_vector_search: table '$tableRef' does not resolve to a Lance dataset.")) + val readOpts = lanceTable.readOptions() + val storage = Option(readOpts.getStorageOptions) + .map(_.asScala.toMap) + .getOrElse(Map.empty[String, String]) + (readOpts.getDatasetUri, storage) + } + + /** + * Extracts a name → expression pair from a [[NamedArgumentExpression]] in Spark 3.5+. + * Uses reflection so this file compiles against older Spark versions where the class does not + * exist — those versions simply never produce such expressions, so the extractor returns None. 
+ */ + private object NamedArgExtractor { + private val clazz: Class[_] = + try { + Class.forName("org.apache.spark.sql.catalyst.analysis.NamedArgumentExpression") + } catch { + case _: ClassNotFoundException => null + } + private val keyMethod: java.lang.reflect.Method = + if (clazz == null) null else clazz.getMethod("key") + private val valueMethod: java.lang.reflect.Method = + if (clazz == null) null else clazz.getMethod("value") + + def unapply(expr: Expression): Option[(String, Expression)] = { + if (clazz == null || !clazz.isInstance(expr)) { + None + } else { + Some(( + keyMethod.invoke(expr).asInstanceOf[String], + valueMethod.invoke(expr).asInstanceOf[Expression])) + } + } + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseLanceVectorSearchTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseLanceVectorSearchTest.java new file mode 100644 index 00000000..2fd12102 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseLanceVectorSearchTest.java @@ -0,0 +1,280 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.read; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * End-to-end coverage for the {@code lance_vector_search} SQL table-valued function. + * + *
<p>
All tests run in brute-force mode ({@code use_index=false}) so this class has no dependency on + * the vector-index DDL feature — it exercises the TVF mechanics in isolation: argument parsing, + * positional vs. named call sites, error surface, and {@code WHERE} composition. + */ +public abstract class BaseLanceVectorSearchTest { + + protected static final int DIM = 16; + protected static final int ROWS = 256; + protected static final long SEED = 1234567L; + + protected String catalogName = "lance_vec"; + protected String tableName; + protected String fullTable; + + protected SparkSession spark; + + @TempDir Path tempDir; + + @BeforeEach + public void setup() throws IOException { + Path rootPath = tempDir.resolve(UUID.randomUUID().toString()); + Files.createDirectories(rootPath); + String testRoot = rootPath.toString(); + this.spark = + SparkSession.builder() + .appName("lance-vector-search-test") + .master("local[2]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", testRoot) + .config("spark.sql.catalog." + catalogName + ".single_level_ns", "true") + .getOrCreate(); + this.tableName = "vec_" + UUID.randomUUID().toString().replace("-", ""); + this.fullTable = catalogName + ".default." 
+ this.tableName;
+  }
+
+  @AfterEach
+  public void tearDown() throws IOException {
+    if (spark != null) {
+      spark.close();
+      spark = null;
+    }
+  }
+
+  // ─── Tests ────────────────────────────────────────────────────────────────
+
+  @Test
+  public void testTvfBruteForceReturnsPlantedNeighbor() {
+    prepareDataset();
+    Set<Integer> ids = collectIds(runTvfSql(/* k= */ 10, "l2"));
+    Assertions.assertTrue(
+        ids.contains(plantedRowId()),
+        "Planted neighbour id=" + plantedRowId() + " missing from top-k results " + ids);
+  }
+
+  @Test
+  public void testTvfPreFilter() {
+    prepareDataset();
+    Dataset<Row> result =
+        spark.sql(
+            "SELECT id, category FROM lance_vector_search('"
+                + fullTable
+                + "', 'emb', "
+                + queryVectorLiteral()
+                + ", 10, 'l2', 20, 1, 64, false) "
+                + "WHERE category = 'odd'");
+    List<Row> rows = result.collectAsList();
+    Assertions.assertFalse(rows.isEmpty(), "Pre-filter must leave at least one row");
+    for (Row r : rows) {
+      Assertions.assertEquals("odd", r.getString(1));
+    }
+  }
+
+  @Test
+  public void testTvfRejectsNonPositiveK() {
+    prepareDataset();
+    Exception ex =
+        Assertions.assertThrows(
+            Exception.class,
+            () ->
+                spark
+                    .sql(
+                        "SELECT * FROM lance_vector_search('"
+                            + fullTable
+                            + "', 'emb', "
+                            + queryVectorLiteral()
+                            + ", 0)")
+                    .collect());
+    String msg = rootCauseMessage(ex);
+    Assertions.assertTrue(
+        msg.contains("k") && msg.contains("positive"),
+        "Expected complaint about non-positive k, got: " + msg);
+  }
+
+  @Test
+  public void testTvfRejectsUnknownMetric() {
+    prepareDataset();
+    Exception ex =
+        Assertions.assertThrows(
+            Exception.class,
+            () ->
+                spark
+                    .sql(
+                        "SELECT * FROM lance_vector_search('"
+                            + fullTable
+                            + "', 'emb', "
+                            + queryVectorLiteral()
+                            + ", 5, 'manhattan')")
+                    .collect());
+    String msg = rootCauseMessage(ex);
+    Assertions.assertTrue(
+        msg.toLowerCase(Locale.ROOT).contains("metric"),
+        "Expected complaint about unsupported metric, got: " + msg);
+  }
+
+  @Test
+  public void testTvfRejectsNonExistentTable() {
+    Exception
ex =
+        Assertions.assertThrows(
+            Exception.class,
+            () ->
+                spark
+                    .sql(
+                        "SELECT * FROM lance_vector_search('"
+                            + catalogName
+                            + ".default.does_not_exist_"
+                            + UUID.randomUUID().toString().replace('-', '_')
+                            + "', 'emb', "
+                            + queryVectorLiteral()
+                            + ", 5)")
+                    .collect());
+    Assertions.assertNotNull(ex.getMessage());
+  }
+
+  // ─── Helpers ──────────────────────────────────────────────────────────────
+
+  /**
+   * Creates a table with a 16-dim vector column plus two scalar columns (id, category), inserts
+   * {@link #ROWS} deterministic rows, and "plants" the neighbour closest to {@link #queryVector()}
+   * at {@link #plantedRowId()}. Data is split across two inserts to force at least two fragments.
+   */
+  protected void prepareDataset() {
+    spark.sql(
+        String.format(
+            "CREATE TABLE %s (id INT NOT NULL, category STRING, emb ARRAY<FLOAT> NOT NULL) "
+                + "USING lance TBLPROPERTIES ('emb.arrow.fixed-size-list.size' = '%d')",
+            fullTable, DIM));
+    int half = ROWS / 2;
+    insertRange(0, half);
+    insertRange(half, ROWS);
+  }
+
+  private void insertRange(int from, int to) {
+    Random rng = new Random(SEED + from);
+    StringBuilder sql = new StringBuilder();
+    sql.append("INSERT INTO ").append(fullTable).append(" VALUES ");
+    boolean first = true;
+    for (int i = from; i < to; i++) {
+      if (!first) {
+        sql.append(", ");
+      }
+      first = false;
+      String cat = (i % 2 == 0) ?
"even" : "odd"; + sql.append("(") + .append(i) + .append(", '") + .append(cat) + .append("', array(") + .append(vectorLiteral(i, rng)) + .append("))"); + } + spark.sql(sql.toString()); + } + + private String vectorLiteral(int i, Random rng) { + float[] query = queryVector(); + StringBuilder sb = new StringBuilder(); + for (int d = 0; d < DIM; d++) { + if (d > 0) sb.append(", "); + float v; + if (i == plantedRowId()) { + v = query[d] + ((rng.nextFloat() - 0.5f) * 0.001f); + } else { + v = rng.nextFloat() * 10.0f - 5.0f; + } + sb.append(Float.toString(v)).append("f"); + } + return sb.toString(); + } + + protected int plantedRowId() { + return 42; + } + + protected float[] queryVector() { + float[] v = new float[DIM]; + for (int i = 0; i < DIM; i++) { + v[i] = (float) (0.1 * (i + 1)); + } + return v; + } + + protected String queryVectorLiteral() { + float[] v = queryVector(); + StringBuilder sb = new StringBuilder(); + sb.append("array("); + for (int i = 0; i < v.length; i++) { + if (i > 0) sb.append(", "); + sb.append("CAST(").append(v[i]).append(" AS FLOAT)"); + } + sb.append(")"); + return sb.toString(); + } + + protected Dataset runTvfSql(int k, String metric) { + return spark.sql( + "SELECT id FROM lance_vector_search('" + + fullTable + + "', 'emb', " + + queryVectorLiteral() + + ", " + + k + + ", '" + + metric + + "', 20, 1, 64, false)"); + } + + protected Set collectIds(Dataset df) { + return df.collectAsList().stream().map(r -> r.getInt(0)).collect(Collectors.toSet()); + } + + protected static String rootCauseMessage(Throwable t) { + Throwable cur = t; + while (cur.getCause() != null && cur.getCause() != cur) { + cur = cur.getCause(); + } + String msg = cur.getMessage(); + return msg == null ? 
cur.getClass().getName() : msg; + } +} diff --git a/lance-spark-base_2.12/src/test/scala/org/lance/spark/read/VectorSearchArgParsingTest.scala b/lance-spark-base_2.12/src/test/scala/org/lance/spark/read/VectorSearchArgParsingTest.scala new file mode 100644 index 00000000..e04e99bd --- /dev/null +++ b/lance-spark-base_2.12/src/test/scala/org/lance/spark/read/VectorSearchArgParsingTest.scala @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, LongType} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class VectorSearchArgParsingTest { + + private def floatArrayLit(vs: Float*): Literal = + Literal.create(new GenericArrayData(vs.toArray[Any]), ArrayType(FloatType)) + + @Test def parseArgsAcceptsRequiredPositional(): Unit = { + val args = Seq( + Literal("cat.db.t"), + Literal("emb"), + floatArrayLit(0.1f, 0.2f, 0.3f), + Literal(5)) + val parsed = LanceVectorSearchTableFunction.parseArgsForTest(args) + assertEquals("cat.db.t", parsed.table) + assertEquals("emb", parsed.column) + assertArrayEquals(Array(0.1f, 0.2f, 0.3f), parsed.query) + assertEquals(5, parsed.k) + assertTrue(parsed.metric.isEmpty) + assertTrue(parsed.useIndex.isEmpty) + } + + @Test def 
parseArgsAcceptsOptionalPositional(): Unit = { + val args = Seq( + Literal("t"), + Literal("c"), + floatArrayLit(1.0f), + Literal(3), + Literal("cosine"), + Literal(20), + Literal(2), + Literal(64), + Literal(false)) + val parsed = LanceVectorSearchTableFunction.parseArgsForTest(args) + assertEquals(Some("cosine"), parsed.metric) + assertEquals(Some(20), parsed.nprobes) + assertEquals(Some(2), parsed.refineFactor) + assertEquals(Some(64), parsed.ef) + assertEquals(Some(false), parsed.useIndex) + } + + @Test def parseArgsConvertsArrayDoubleAndIntQueries(): Unit = { + val asDouble = Literal.create( + new GenericArrayData(Array[Any](0.1d, 0.2d, 0.3d)), + ArrayType(DoubleType)) + val pd = LanceVectorSearchTableFunction.parseArgsForTest( + Seq(Literal("t"), Literal("c"), asDouble, Literal(3))) + assertEquals(3, pd.query.length) + assertTrue(math.abs(pd.query(0) - 0.1f) < 1e-6f) + + val asInt = Literal.create( + new GenericArrayData(Array[Any](1, 2, 3)), + ArrayType(IntegerType)) + val pi = LanceVectorSearchTableFunction.parseArgsForTest( + Seq(Literal("t"), Literal("c"), asInt, Literal(3))) + assertArrayEquals(Array(1.0f, 2.0f, 3.0f), pi.query) + + val asLong = Literal.create( + new GenericArrayData(Array[Any](1L, 2L)), + ArrayType(LongType)) + val pl = LanceVectorSearchTableFunction.parseArgsForTest( + Seq(Literal("t"), Literal("c"), asLong, Literal(3))) + assertArrayEquals(Array(1.0f, 2.0f), pl.query) + } + + @Test def parseArgsRejectsNullElements(): Unit = { + val withNull = Literal.create( + new GenericArrayData(Array[Any](0.1f, null, 0.3f)), + ArrayType(FloatType)) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceVectorSearchTableFunction.parseArgsForTest( + Seq(Literal("t"), Literal("c"), withNull, Literal(3)))) + assertTrue(ex.getMessage.contains("null")) + } + + @Test def parseArgsRejectsNonPositiveK(): Unit = { + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceVectorSearchTableFunction.parseArgsForTest( + 
Seq(Literal("t"), Literal("c"), floatArrayLit(1.0f), Literal(0)))) + assertTrue(ex.getMessage.contains("positive")) + } + + @Test def parseArgsReportsMissingArgByName(): Unit = { + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceVectorSearchTableFunction.parseArgsForTest( + Seq(Literal("t"), Literal("c"), floatArrayLit(1.0f)))) + assertTrue(ex.getMessage.contains("'k'")) + } +}