diff --git a/.github/workflows/spark-search.yml b/.github/workflows/spark-search.yml new file mode 100644 index 000000000..713bcf1b7 --- /dev/null +++ b/.github/workflows/spark-search.yml @@ -0,0 +1,179 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Spark Search Docker + +on: + pull_request: + types: + - opened + - synchronize + - ready_for_review + - reopened + paths: + - ".github/workflows/spark-search.yml" + - "Makefile" + - "docker/**" + - "integration-tests/**" + - "lance-spark-base_2.12/src/main/java/org/lance/spark/search/**" + - "lance-spark-base_2.12/src/main/scala/org/lance/spark/search/**" + - "lance-spark-base_2.12/src/test/java/org/lance/spark/search/**" + - "lance-spark-*/src/main/scala/org/lance/spark/extensions/**" + - "pom.xml" + - "*/pom.xml" + workflow_dispatch: + inputs: + spark-version: + description: "Spark version to test" + required: true + default: "3.5" + scala-version: + description: "Scala version to test" + required: true + default: "2.13" + backends: + description: "Comma-separated test backends: local or local,rest-dir" + required: true + default: "local,rest-dir" + rest-uri: + description: "Optional REST namespace URI. If omitted, tests start a local REST directory namespace." + required: false + default: "" + rest-database: + description: "Optional database header value for an external REST namespace" + required: false + default: "" + docker-run-args: + description: "Extra docker run args for docker-test" + required: false + default: "" + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + SPARK_VERSION: ${{ github.event.inputs['spark-version'] || '3.5' }} + SCALA_VERSION: ${{ github.event.inputs['scala-version'] || '2.13' }} + SEARCH_TEST_BACKENDS: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.backends || 'local,rest-dir' }} + SEARCH_PYTEST_CMD: >- + pytest /home/lance/tests/test_lance_spark.py::TestDQLSearchTableFunctions + -v --timeout=180 + +jobs: + search-docker-test: + name: Search Docker Test + runs-on: ubuntu-24.04 + timeout-minutes: 90 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + - name: Set up Java + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: 17 + cache: "maven" + - name: Resolve Docker build args + id: docker-args + run: | + make print-docker-build-args SPARK_VERSION=${SPARK_VERSION} SCALA_VERSION=${SCALA_VERSION} >> $GITHUB_OUTPUT + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build test-base image (cached) + uses: docker/build-push-action@v6 + with: + context: docker + file: docker/Dockerfile.test-base + load: true + tags: lance-spark-test-base:${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + build-args: | + SPARK_DOWNLOAD_VERSION=${{ steps.docker-args.outputs.spark-download-version }} + SPARK_MAJOR_VERSION=${{ env.SPARK_VERSION }} + SCALA_VERSION=${{ env.SCALA_VERSION }} + PY4J_VERSION=${{ steps.docker-args.outputs.py4j-version }} + SPARK_SCALA_SUFFIX=${{ steps.docker-args.outputs.spark-scala-suffix }} + cache-from: type=gha,scope=search-test-base-${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + cache-to: type=gha,mode=max,scope=search-test-base-${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + - name: Build bundle + run: make bundle SPARK_VERSION=${SPARK_VERSION} SCALA_VERSION=${SCALA_VERSION} + - name: Build test image + run: | + make docker-build-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + LANCE_NAMESPACE_IMPL_VERSION=${{ steps.docker-args.outputs.lance-namespace-impl-version }} + - name: Run directory namespace search tests + if: ${{ contains(env.SEARCH_TEST_BACKENDS, 'local') }} + run: | + make docker-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + TEST_BACKENDS=local \ + PYTEST_CMD="${SEARCH_PYTEST_CMD}" + - name: Resolve REST namespace URI + id: rest + if: ${{ contains(env.SEARCH_TEST_BACKENDS, 'rest-dir') }} + env: + INPUT_REST_URI: ${{ github.event.inputs['rest-uri'] }} + INPUT_DOCKER_RUN_ARGS: ${{ github.event.inputs['docker-run-args'] }} + run: | + rest_uri="${INPUT_REST_URI}" + docker_run_args="${INPUT_DOCKER_RUN_ARGS}" + start_rest_dir="false" + rest_dir_root="" + rest_dir_port="" + + if [ -z "${rest_uri}" ]; then + rest_dir_port="10024" + rest_dir_root="/home/lance/rest-data" + rest_uri="http://127.0.0.1:${rest_dir_port}" + start_rest_dir="true" + fi + + echo "uri=${rest_uri}" >> "$GITHUB_OUTPUT" + echo "start_rest_dir=${start_rest_dir}" >> "$GITHUB_OUTPUT" + echo "rest_dir_root=${rest_dir_root}" >> "$GITHUB_OUTPUT" + echo "rest_dir_port=${rest_dir_port}" >> "$GITHUB_OUTPUT" + { + echo "docker_run_args<> "$GITHUB_OUTPUT" + - name: Run REST directory namespace search tests + if: ${{ contains(env.SEARCH_TEST_BACKENDS, 'rest-dir') }} + env: + LANCE_SPARK_REST_URI: ${{ steps.rest.outputs.uri }} + LANCE_SPARK_REST_API_KEY: ${{ secrets.LANCE_SPARK_REST_API_KEY }} + LANCE_SPARK_REST_DATABASE: ${{ github.event.inputs['rest-database'] }} + LANCE_SPARK_START_REST_DIR: ${{ steps.rest.outputs.start_rest_dir }} + LANCE_SPARK_REST_DIR_ROOT: ${{ steps.rest.outputs.rest_dir_root }} + LANCE_SPARK_REST_DIR_PORT: ${{ steps.rest.outputs.rest_dir_port }} + DOCKER_RUN_ARGS: ${{ steps.rest.outputs.docker_run_args }} + run: | + make docker-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + TEST_BACKENDS=rest-dir \ + LANCE_SPARK_REST_URI="${LANCE_SPARK_REST_URI}" \ + LANCE_SPARK_REST_API_KEY="${LANCE_SPARK_REST_API_KEY}" \ + LANCE_SPARK_REST_DATABASE="${LANCE_SPARK_REST_DATABASE}" \ + LANCE_SPARK_START_REST_DIR="${LANCE_SPARK_START_REST_DIR}" \ + LANCE_SPARK_REST_DIR_ROOT="${LANCE_SPARK_REST_DIR_ROOT}" \ + LANCE_SPARK_REST_DIR_PORT="${LANCE_SPARK_REST_DIR_PORT}" \ + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS}" \ + PYTEST_CMD="${SEARCH_PYTEST_CMD}" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 124cadbdf..08b76145d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -58,6 +58,38 @@ To auto-format the code, run: make format ``` +## Docker Integration Tests + +Build the Spark bundle and Docker integration-test image before running Docker tests: + +```shell +make bundle SPARK_VERSION=3.5 SCALA_VERSION=2.13 +make docker-build-test SPARK_VERSION=3.5 SCALA_VERSION=2.13 +make docker-test SPARK_VERSION=3.5 SCALA_VERSION=2.13 +``` + +Use `PYTEST_CMD` to run a targeted pytest path in the Docker image. For example, run only the SQL search table-function tests against the directory namespace: + +```shell +make docker-test SPARK_VERSION=3.5 SCALA_VERSION=2.13 \ + TEST_BACKENDS=local \ + PYTEST_CMD="pytest /home/lance/tests/test_lance_spark.py::TestDQLSearchTableFunctions -v --timeout=180" +``` + +To also validate a REST namespace backed by a directory namespace, let the Docker test container start the OSS Lance REST adapter: + +```shell +make docker-test SPARK_VERSION=3.5 SCALA_VERSION=2.13 \ + TEST_BACKENDS=local,rest-dir \ + LANCE_SPARK_START_REST_DIR=true \ + LANCE_SPARK_REST_URI=http://127.0.0.1:10024 \ + PYTEST_CMD="pytest /home/lance/tests/test_lance_spark.py::TestDQLSearchTableFunctions -v --timeout=180" +``` + +To run against an already-running compatible REST namespace server instead, omit `LANCE_SPARK_START_REST_DIR` and pass that server's URI with `LANCE_SPARK_REST_URI`. + +The `Spark Search Docker` GitHub Actions workflow runs the same targeted Docker tests. Pull requests run directory namespace and REST-directory namespace coverage automatically. Use workflow dispatch with `rest-uri` only when validating against an external REST namespace server. + ## Documentation ### Setup diff --git a/Makefile b/Makefile index e6346c6dd..e57660eb9 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,7 @@ endif DOCKER_CACHE_FROM ?= DOCKER_CACHE_TO ?= LANCE_NAMESPACE_IMPL_VERSION ?= $(shell sed -n 's:.*\(.*\).*:\1:p' pom.xml | head -n 1) +PYTEST_CMD ?= pytest /home/lance/tests/ -v --timeout=180 DOCKER_COMPOSE := $(shell \ if docker compose version >/dev/null 2>&1; then \ @@ -190,6 +191,12 @@ docker-test: $(if $(LANCEDB_API_KEY),-e LANCEDB_API_KEY=$(LANCEDB_API_KEY)) \ $(if $(LANCEDB_HOST_OVERRIDE),-e LANCEDB_HOST_OVERRIDE=$(LANCEDB_HOST_OVERRIDE)) \ $(if $(LANCEDB_REGION),-e LANCEDB_REGION=$(LANCEDB_REGION)) \ + $(if $(LANCE_SPARK_REST_URI),-e LANCE_SPARK_REST_URI=$(LANCE_SPARK_REST_URI)) \ + $(if $(LANCE_SPARK_REST_API_KEY),-e LANCE_SPARK_REST_API_KEY=$(LANCE_SPARK_REST_API_KEY)) \ + $(if $(LANCE_SPARK_REST_DATABASE),-e LANCE_SPARK_REST_DATABASE=$(LANCE_SPARK_REST_DATABASE)) \ + $(if $(LANCE_SPARK_START_REST_DIR),-e LANCE_SPARK_START_REST_DIR=$(LANCE_SPARK_START_REST_DIR)) \ + $(if $(LANCE_SPARK_REST_DIR_ROOT),-e LANCE_SPARK_REST_DIR_ROOT=$(LANCE_SPARK_REST_DIR_ROOT)) \ + $(if $(LANCE_SPARK_REST_DIR_PORT),-e LANCE_SPARK_REST_DIR_PORT=$(LANCE_SPARK_REST_DIR_PORT)) \ $(if $(TEST_BACKENDS),-e TEST_BACKENDS=$(TEST_BACKENDS)) \ $(if $(LANCE_FTS_FORMAT_VERSION),-e LANCE_FTS_FORMAT_VERSION=$(LANCE_FTS_FORMAT_VERSION)) \ $(if $(AWS_REGION),-e AWS_REGION=$(AWS_REGION)) \ @@ -203,8 +210,9 @@ docker-test: $(if $(AWS_SESSION_TOKEN),-e AWS_SESSION_TOKEN=$(AWS_SESSION_TOKEN)) \ $(if $(AWS_PROFILE),-e AWS_PROFILE=$(AWS_PROFILE)) \ $(if $(AWS_PROFILE),-v $(HOME)/.aws:/root/.aws:ro) \ + $(DOCKER_RUN_ARGS) \ lance-spark-test:$(SPARK_VERSION)_$(SCALA_VERSION) \ - "pytest /home/lance/tests/ -v --timeout=180" + "$(PYTEST_CMD)" # ============================================================================= # Benchmark @@ -295,6 +303,7 @@ help: @echo " docker-build-test-base - Build test base image (system deps + Spark)" @echo " docker-build-test - Build test image (base + bundle JAR)" @echo " docker-test - Run integration tests in lance-spark-test container" + @echo " Override PYTEST_CMD to run a targeted pytest command" @echo "" @echo "Benchmark:" @echo " benchmark-build - Build benchmark jar (shared by TPC-DS and TPC-H)" diff --git a/docker/Dockerfile.test b/docker/Dockerfile.test index dc38e5dcd..e95d4446b 100644 --- a/docker/Dockerfile.test +++ b/docker/Dockerfile.test @@ -35,6 +35,7 @@ RUN mkdir -p /home/lance/warehouse /home/lance/spark-events /home/lance/data # Copy tests RUN mkdir -p /home/lance/tests COPY integration-tests/ /home/lance/tests/ +RUN javac -cp "${SPARK_HOME}/jars/*" /home/lance/tests/LanceRestDirNamespaceServer.java WORKDIR ${SPARK_HOME} COPY docker/entrypoint.sh . diff --git a/docs/src/config.md b/docs/src/config.md index 56bcaf21a..2c5727347 100644 --- a/docs/src/config.md +++ b/docs/src/config.md @@ -49,6 +49,9 @@ Lance provides SQL extensions that add additional functionality beyond standard The following features require the Lance Spark SQL extension to be enabled: +- [VECTOR_SEARCH](operations/dql/vector-search.md) - Run vector similarity search through Lance namespace execution +- [SEARCH](operations/dql/search.md) - Run full-text search through Lance namespace execution +- [HYBRID_SEARCH](operations/dql/hybrid-search.md) - Combine vector and full-text search with reciprocal rank fusion - [ADD COLUMNS with backfill](operations/dml/add-columns.md) - Add new columns and backfill existing rows with data - [UPDATE COLUMNS with backfill](operations/dml/update-columns.md) - Update existing columns using data from a source - [OPTIMIZE](operations/ddl/optimize.md) - Compact table fragments for improved query performance diff --git a/docs/src/index.md b/docs/src/index.md index 48335ae1f..6361b9e8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -17,6 +17,7 @@ Specifically, you can use the Apache Spark Connector for Lance to: * **Read & Write Lance Datasets**: Seamlessly read and write datasets stored in the Lance format using Spark. * **Distributed, Parallel Scans**: Leverage Spark's distributed computing capabilities to perform parallel scans on Lance datasets. * **Column and Filter Pushdown**: Optimize query performance by pushing down column selections and filters to the data source. +* **SQL Search Table Functions**: Run [vector](operations/dql/vector-search.md), [full-text](operations/dql/search.md), and [hybrid](operations/dql/hybrid-search.md) search through Lance namespace execution. ## Quick Start @@ -28,4 +29,4 @@ make docker-build make docker-up ``` -And then open the notebook at `http://localhost:8888`. \ No newline at end of file +And then open the notebook at `http://localhost:8888`. diff --git a/docs/src/operations/ddl/create-index.md b/docs/src/operations/ddl/create-index.md index 457648fe5..ff50ffafc 100755 --- a/docs/src/operations/ddl/create-index.md +++ b/docs/src/operations/ddl/create-index.md @@ -147,6 +147,8 @@ Create an FTS index on a text column: ); ``` +Query the indexed column with the [SEARCH](../dql/search.md) table function. + ## Output The `CREATE INDEX` command returns the following information about the operation: diff --git a/docs/src/operations/dql/.pages b/docs/src/operations/dql/.pages index 561ff5ee7..c4f34397f 100644 --- a/docs/src/operations/dql/.pages +++ b/docs/src/operations/dql/.pages @@ -1,3 +1,6 @@ title: DQL nav: - select.md + - vector-search.md + - search.md + - hybrid-search.md diff --git a/docs/src/operations/dql/hybrid-search.md b/docs/src/operations/dql/hybrid-search.md new file mode 100644 index 000000000..18285c6a6 --- /dev/null +++ b/docs/src/operations/dql/hybrid-search.md @@ -0,0 +1,86 @@ +# HYBRID_SEARCH + +Run vector search and full-text search together from Spark SQL, then rerank the combined results with reciprocal rank fusion. + +!!! warning "Spark Extension Required" + `HYBRID_SEARCH` requires the Lance Spark SQL extension to be enabled. See [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details. + +!!! note "Namespace Tables Required" + `HYBRID_SEARCH` resolves the `table` argument through a Spark catalog and executes both side queries through the Lance namespace `queryTable` API. Use a Lance namespace catalog table such as `lance.default.documents`, not a raw Lance dataset path. + +!!! note "Named Arguments" + Named arguments require Spark 3.5 or later. On Spark 3.4, use the positional form. + +## Basic Usage + +`HYBRID_SEARCH` returns the selected table columns plus `_distance`, `_score`, and `_relevance_score`. Rows that only match one side have null for the other side's metric. + +=== "SQL" + ```sql + SELECT id, body, _distance, _score, _relevance_score + FROM HYBRID_SEARCH( + table => 'lance.default.documents', + query_vector => array(0.12, 0.34, 0.56, 0.78), + query => 'vector database', + vector_column => 'embedding', + search_columns => array('body'), + columns => array('id', 'body'), + num_results => 10, + candidates => 50, + rrf_k => 60.0 + ) + ORDER BY _relevance_score DESC; + ``` + +## Positional Form + +Use positional arguments for simple calls and Spark 3.4 compatibility. + +=== "SQL" + ```sql + SELECT * + FROM HYBRID_SEARCH('lance.default.documents', array(0.12, 0.34, 0.56), 'lance', 5); + ``` + +## Arguments + +| Argument | Type | Required | Description | +|----------|------|----------|-------------| +| `table` | String | Yes | Catalog table name to search. | +| `query_vector` | Array numeric literal | Yes | Query vector. | +| `query` or `search_query` | String | Yes | Full-text query string. | +| `vector_column` | String | No | Vector column name. Lance defaults to `vector` when omitted. | +| `search_columns` | Array string literal | No | Text columns to search. When omitted, Lance uses the indexed columns configured for the FTS index. | +| `num_results`, `limit`, or `k` | Integer | No | Number of final reranked results. Defaults to `10`. | +| `candidates`, `num_candidates`, or `candidate_count` | Integer | No | Number of rows to fetch from each side before reranking. Defaults to `num_results + offset`. Values below `num_results + offset` are raised to that minimum. | +| `rrf_k` | Float | No | Reciprocal rank fusion constant. Defaults to `60.0`. | +| `columns` | Array string literal | No | Output table columns. `_distance`, `_score`, and `_relevance_score` are always included. Use `array('*')` or omit this argument for all table columns. | +| `filter` | String | No | SQL filter expression evaluated by Lance on both side queries. | +| `offset` | Integer | No | Number of reranked results to skip after fusion. Defaults to `0`. | +| `version` | Long | No | Lance table version to search. | +| `distance_type` | String | No | Distance metric such as `l2`, `cosine`, or `dot`. | +| `nprobes`, `ef`, `refine_factor` | Integer | No | Vector index search tuning parameters. | +| `lower_bound`, `upper_bound` | Float | No | Distance bounds. | +| `bypass_vector_index`, `fast_search`, `prefilter`, `with_row_id` | Boolean | No | Lance query options. `with_row_id` adds `_rowid` to the output. | + +## Reranking + +Hybrid search performs reciprocal rank fusion in Spark: + +```text +_relevance_score = sum(1.0 / (rank + rrf_k)) +``` + +Ranks are zero-based in each side's result set. `candidates` controls how many rows are fetched from each side before reranking. + +## Output + +The result includes the requested table columns plus nullable `_distance` and `_score` float columns and a non-null `_relevance_score` float column. If `with_row_id => true`, or if `_rowid` is listed in `columns`, the result also includes Lance row ids. + +## Execution + +Spark plans `HYBRID_SEARCH` as a DataSource V2 batch read with one input partition. The partition reader issues one vector `queryTable` request and one full-text `queryTable` request through the Lance namespace API, merges the two result sets in Spark with reciprocal rank fusion, and returns the final rows. With a REST namespace the two side searches can be handled by the REST server, while the final fusion currently happens in the Spark task. + +## Validation + +The Docker integration suite covers `HYBRID_SEARCH` against the directory namespace and a REST namespace backed by a directory namespace. The `Spark Search Docker` GitHub Actions workflow runs both backends for pull requests. diff --git a/docs/src/operations/dql/search.md b/docs/src/operations/dql/search.md new file mode 100644 index 000000000..c112155ee --- /dev/null +++ b/docs/src/operations/dql/search.md @@ -0,0 +1,79 @@ +# SEARCH + +Run Lance full-text search from Spark SQL using Lance namespace execution. + +!!! warning "Spark Extension Required" + `SEARCH` requires the Lance Spark SQL extension to be enabled. See [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details. + +!!! note "Namespace Tables Required" + `SEARCH` resolves the `table` argument through a Spark catalog and executes through the Lance namespace `queryTable` API. Use a Lance namespace catalog table such as `lance.default.documents`, not a raw Lance dataset path. + +!!! note "Named Arguments" + Named arguments require Spark 3.5 or later. On Spark 3.4, use the positional form. + +## Basic Usage + +`SEARCH` returns the selected table columns plus `_score`. Create an FTS index before querying text columns. + +=== "SQL" + ```sql + ALTER TABLE lance.default.documents + CREATE INDEX body_fts USING fts (body) WITH ( + base_tokenizer = 'simple', + language = 'English', + max_token_length = 40, + lower_case = true, + stem = false, + remove_stop_words = false, + ascii_folding = false, + with_position = true + ); + + SELECT id, body, _score + FROM SEARCH( + table => 'lance.default.documents', + query => 'vector database', + search_columns => array('body'), + columns => array('id', 'body'), + limit => 10 + ) + ORDER BY _score DESC; + ``` + +See [CREATE INDEX](../ddl/create-index.md#full-text-search-index) for FTS index options. + +## Positional Form + +Use positional arguments for simple calls and Spark 3.4 compatibility. + +=== "SQL" + ```sql + SELECT * + FROM SEARCH('lance.default.documents', 'lance', 5); + ``` + +## Arguments + +| Argument | Type | Required | Description | +|----------|------|----------|-------------| +| `table` | String | Yes | Catalog table name to search. | +| `query` or `search_query` | String | Yes | Full-text query string. | +| `search_columns` | Array string literal | No | Text columns to search. When omitted, Lance uses the indexed columns configured for the FTS index. | +| `num_results`, `limit`, or `k` | Integer | No | Number of results. Defaults to `10`. | +| `columns` | Array string literal | No | Output table columns. `_score` is always included. Use `array('*')` or omit this argument for all table columns. | +| `filter` | String | No | SQL filter expression evaluated by Lance. | +| `offset` | Integer | No | Number of results to skip. | +| `version` | Long | No | Lance table version to search. | +| `with_row_id` | Boolean | No | Include Lance row ids in the result as `_rowid`. | + +## Output + +The result includes the requested table columns and a nullable `_score` float column. If `with_row_id => true`, or if `_rowid` is listed in `columns`, the result also includes Lance row ids. + +## Execution + +Spark plans `SEARCH` as a DataSource V2 batch read with one input partition. The partition reader calls the Lance namespace `queryTable` API. With a directory namespace the search runs in the Spark process executing that reader; with a REST namespace the REST server handles the namespace request. + +## Validation + +The Docker integration suite covers `SEARCH` against the directory namespace and a REST namespace backed by a directory namespace. The `Spark Search Docker` GitHub Actions workflow runs both backends for pull requests. diff --git a/docs/src/operations/dql/vector-search.md b/docs/src/operations/dql/vector-search.md new file mode 100644 index 000000000..14851da83 --- /dev/null +++ b/docs/src/operations/dql/vector-search.md @@ -0,0 +1,72 @@ +# VECTOR_SEARCH + +Run vector similarity search from Spark SQL using Lance namespace execution. + +!!! warning "Spark Extension Required" + `VECTOR_SEARCH` requires the Lance Spark SQL extension to be enabled. See [Spark SQL Extensions](../../config.md#spark-sql-extensions) for configuration details. + +!!! note "Namespace Tables Required" + `VECTOR_SEARCH` resolves the `table` argument through a Spark catalog and executes through the Lance namespace `queryTable` API. Use a Lance namespace catalog table such as `lance.default.items`, not a raw Lance dataset path. + +!!! note "Named Arguments" + Named arguments require Spark 3.5 or later. On Spark 3.4, use the positional form. + +!!! note "Replacing `nearest` Read Option" + The previous DataFrame `nearest` read option has been removed. Use `VECTOR_SEARCH` for vector similarity search so execution goes through the Lance namespace API. + +## Basic Usage + +`VECTOR_SEARCH` returns the selected table columns plus `_distance`. + +=== "SQL" + ```sql + SELECT id, title, _distance + FROM VECTOR_SEARCH( + table => 'lance.default.items', + query_vector => array(0.12, 0.34, 0.56, 0.78), + vector_column => 'embedding', + num_results => 10, + distance_type => 'l2', + columns => array('id', 'title') + ) + ORDER BY _distance; + ``` + +## Positional Form + +Use positional arguments for simple calls and Spark 3.4 compatibility. + +=== "SQL" + ```sql + SELECT * + FROM VECTOR_SEARCH('lance.default.items', array(0.12, 0.34, 0.56), 5); + ``` + +## Arguments + +| Argument | Type | Required | Description | +|----------|------|----------|-------------| +| `table` | String | Yes | Catalog table name to search. | +| `query_vector` | Array numeric literal | Yes | Query vector. | +| `vector_column` | String | No | Vector column name. Lance defaults to `vector` when omitted. | +| `num_results`, `limit`, or `k` | Integer | No | Number of results. Defaults to `10`. | +| `distance_type` | String | No | Distance metric such as `l2`, `cosine`, or `dot`. | +| `columns` | Array string literal | No | Output table columns. `_distance` is always included. Use `array('*')` or omit this argument for all table columns. | +| `filter` | String | No | SQL filter expression evaluated by Lance. | +| `offset` | Integer | No | Number of results to skip. Lance Spark requests `num_results + offset` rows from Lance before applying the offset. | +| `version` | Long | No | Lance table version to search. | +| `nprobes`, `ef`, `refine_factor` | Integer | No | Vector index search tuning parameters. | +| `lower_bound`, `upper_bound` | Float | No | Distance bounds. | +| `bypass_vector_index`, `fast_search`, `prefilter`, `with_row_id` | Boolean | No | Lance query options. `with_row_id` adds `_rowid` to the output. | + +## Output + +The result includes the requested table columns and a nullable `_distance` float column. If `with_row_id => true`, or if `_rowid` is listed in `columns`, the result also includes Lance row ids. + +## Execution + +Spark plans `VECTOR_SEARCH` as a DataSource V2 batch read with one input partition. The partition reader calls the Lance namespace `queryTable` API. With a directory namespace the search runs in the Spark process executing that reader; with a REST namespace the REST server handles the namespace request. + +## Validation + +The Docker integration suite covers `VECTOR_SEARCH` against the directory namespace and a REST namespace backed by a directory namespace. The `Spark Search Docker` GitHub Actions workflow runs both backends for pull requests. diff --git a/docs/src/performance.md b/docs/src/performance.md index 0cc13ef2e..aa0367dc9 100644 --- a/docs/src/performance.md +++ b/docs/src/performance.md @@ -106,7 +106,7 @@ Set via Spark write option `max_row_per_file` (default: 1,000,000). Controls the maximum number of rows per Lance fragment file. There is no specific recommended value, but be aware the default is 1 million rows. If you store many multimodal data columns (images, audio, embeddings) -without using [Lance blob encoding](operations/dml/dataframe-write.md#writing-blob-data), +without using [Lance blob encoding](operations/dml/insert-into.md#writing-blob-data), or store a lot of long text columns, the file size might become very large. From Lance's perspective, having very large files does not impact your read performance. But you may want to reduce this value depending on the limits in your choice of object storage. diff --git a/integration-tests/LanceRestDirNamespaceServer.java b/integration-tests/LanceRestDirNamespaceServer.java new file mode 100644 index 000000000..8178b9aa7 --- /dev/null +++ b/integration-tests/LanceRestDirNamespaceServer.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.lance.namespace.RestAdapter; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; + +public final class LanceRestDirNamespaceServer { + private LanceRestDirNamespaceServer() {} + + public static void main(String[] args) throws Exception { + String root = args.length > 0 ? args[0] : "/home/lance/rest-data"; + String host = args.length > 1 ? args[1] : "127.0.0.1"; + int port = args.length > 2 ? Integer.parseInt(args[2]) : 10024; + + Map backendConfig = new HashMap<>(); + backendConfig.put("root", root); + + RestAdapter adapter = new RestAdapter("dir", backendConfig, host, port); + Runtime.getRuntime().addShutdownHook(new Thread(adapter::close)); + adapter.start(); + System.out.printf( + "Lance REST directory namespace listening on http://%s:%d with root %s%n", + host, adapter.getPort(), root); + new CountDownLatch(1).await(); + } +} diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 45c9dffb9..43afcd970 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -9,11 +9,13 @@ """ import os +import socket import subprocess import time import uuid -import urllib.request import urllib.error +import urllib.parse +import urllib.request import pytest from pyspark.sql import SparkSession @@ -21,6 +23,10 @@ def pytest_configure(config): config.addinivalue_line("markers", "requires_rest: test only runs on REST-based backends") + config.addinivalue_line( + "markers", + "rest_dir_compatible: test can run against a local REST namespace backed by dir", + ) # --------------------------------------------------------------------------- @@ -177,6 +183,19 @@ def minio(): LANCEDB_API_KEY = os.environ.get("LANCEDB_API_KEY") LANCEDB_HOST_OVERRIDE = os.environ.get("LANCEDB_HOST_OVERRIDE") LANCEDB_REGION = os.environ.get("LANCEDB_REGION", "us-east-1") +LANCE_SPARK_REST_URI = os.environ.get("LANCE_SPARK_REST_URI") +LANCE_SPARK_REST_API_KEY = os.environ.get("LANCE_SPARK_REST_API_KEY") +LANCE_SPARK_REST_DATABASE = os.environ.get("LANCE_SPARK_REST_DATABASE") +LANCE_SPARK_START_REST_DIR = os.environ.get("LANCE_SPARK_START_REST_DIR", "").lower() in ( + "1", + "true", + "yes", +) +LANCE_SPARK_REST_DIR_ROOT = os.environ.get( + "LANCE_SPARK_REST_DIR_ROOT", + "/home/lance/rest-data", +) +LANCE_SPARK_REST_DIR_PORT = int(os.environ.get("LANCE_SPARK_REST_DIR_PORT", "10024")) AWS_S3_BUCKET_NAME = os.environ.get("AWS_S3_BUCKET_NAME") AWS_REGION = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") or "us-east-1" AWS_GLUE_CATALOG_ID = os.environ.get("AWS_GLUE_CATALOG_ID") @@ -193,6 +212,8 @@ def minio(): _all_backends = ["local", "azurite", "minio"] if LANCEDB_DB and LANCEDB_API_KEY: _all_backends.append("lancedb") +if LANCE_SPARK_REST_URI or LANCE_SPARK_START_REST_DIR: + _all_backends.append("rest-dir") if AWS_S3_BUCKET_NAME: _all_backends.append("glue") _backends = os.environ.get("TEST_BACKENDS", ",".join(_all_backends)).split(",") @@ -218,6 +239,8 @@ def spark(request): - **minio** – S3-compatible storage via the MinIO emulator - **lancedb** – LanceDB Cloud via REST API (requires ``LANCEDB_DB`` and ``LANCEDB_API_KEY`` env vars; skipped otherwise) + - **rest-dir** – local REST namespace backed by a directory namespace + (requires ``LANCE_SPARK_REST_URI`` and compatible tests) - **glue** – AWS Glue Data Catalog with S3 storage (requires ``AWS_S3_BUCKET_NAME`` and AWS credentials from the default AWS provider chain) @@ -252,6 +275,24 @@ def spark(request): f"spark.sql.catalog.{CATALOG}.headers.x-lancedb-database", LANCEDB_DB, ) ) + elif backend == "rest-dir": + rest_dir = request.getfixturevalue("rest_dir_namespace") + + builder = ( + builder + .config(f"spark.sql.catalog.{CATALOG}.impl", "rest") + .config(f"spark.sql.catalog.{CATALOG}.uri", rest_dir["uri"]) + ) + if LANCE_SPARK_REST_API_KEY: + builder = builder.config( + f"spark.sql.catalog.{CATALOG}.headers.x-api-key", + LANCE_SPARK_REST_API_KEY, + ) + if LANCE_SPARK_REST_DATABASE: + builder = builder.config( + f"spark.sql.catalog.{CATALOG}.headers.x-lancedb-database", + LANCE_SPARK_REST_DATABASE, + ) elif backend == "glue": if not AWS_S3_BUCKET_NAME or not AWS_GLUE_ROOT: pytest.skip("AWS_S3_BUCKET_NAME is required for Glue backend") @@ -327,6 +368,44 @@ def spark(request): session.stop() +@pytest.fixture(scope="session") +def rest_dir_namespace(): + """Provide a REST namespace backed by the local directory namespace.""" + if not LANCE_SPARK_START_REST_DIR: + if not LANCE_SPARK_REST_URI: + pytest.skip("LANCE_SPARK_REST_URI is required for rest-dir backend") + yield {"uri": LANCE_SPARK_REST_URI} + return + + os.makedirs(LANCE_SPARK_REST_DIR_ROOT, exist_ok=True) + uri = LANCE_SPARK_REST_URI or f"http://127.0.0.1:{LANCE_SPARK_REST_DIR_PORT}" + parsed = urllib.parse.urlparse(uri) + host = parsed.hostname or "127.0.0.1" + port = parsed.port or LANCE_SPARK_REST_DIR_PORT + log_path = "/tmp/lance-rest-dir-namespace.log" + classpath = "/home/lance/tests:/opt/spark/jars/*" + + with open(log_path, "w", encoding="utf-8") as log: + proc = subprocess.Popen( + [ + "java", + "-cp", + classpath, + "LanceRestDirNamespaceServer", + LANCE_SPARK_REST_DIR_ROOT, + "127.0.0.1", + str(port), + ], + stdout=log, + stderr=subprocess.STDOUT, + ) + try: + _wait_for_tcp(host, port, proc, "Lance REST directory namespace") + yield {"uri": uri} + finally: + _stop_process(proc) + + @pytest.fixture def test_table(request, spark): """Provide a unique table name for each test to avoid isolation issues. @@ -350,8 +429,11 @@ def test_table(request, spark): @pytest.fixture(autouse=True) def _skip_by_backend(request, spark): """Auto-skip tests marked ``requires_rest`` on non-REST backends.""" + backend = getattr(spark, "_lance_backend", None) + if backend == "rest-dir" and not request.node.get_closest_marker("rest_dir_compatible"): + pytest.skip("rest-dir backend is only enabled for compatible tests") if request.node.get_closest_marker("requires_rest"): - if getattr(spark, "_lance_backend", None) != "lancedb": + if backend not in ("lancedb", "rest-dir"): pytest.skip("requires REST-based backend") @@ -390,3 +472,17 @@ def _stop_process(proc, timeout=10): except subprocess.TimeoutExpired: proc.kill() proc.wait() + + +def _wait_for_tcp(host, port, proc, name, timeout=30): + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with socket.create_connection((host, port), timeout=1): + return + except OSError: + if proc.poll() is not None: + raise RuntimeError(f"{name} exited unexpectedly") + time.sleep(0.5) + _stop_process(proc) + raise RuntimeError(f"{name} did not become healthy within {timeout} s") diff --git a/integration-tests/test_lance_spark.py b/integration-tests/test_lance_spark.py index 7d7c527b0..a2f30e4b1 100644 --- a/integration-tests/test_lance_spark.py +++ b/integration-tests/test_lance_spark.py @@ -117,6 +117,12 @@ def _assert_lance_index_metadata(spark, table_name, index_name, expected_type): return metadata +def _require_sql_search_backend(spark): + backend = getattr(spark, "_lance_backend", None) + if backend not in ("local", "rest-dir"): + pytest.skip("SQL search table functions are covered on local dir and rest-dir backends") + + # ============================================================================= # DDL (Data Definition Language) Tests # ============================================================================= @@ -1583,6 +1589,166 @@ def test_primary_key_persists_after_insert(self, spark): # DQL (Data Query Language) Tests # ============================================================================= +@pytest.mark.rest_dir_compatible +class TestDQLSearchTableFunctions: + """Test namespace-backed SQL search table functions.""" + + def test_vector_search_table_function(self, spark): + """Test VECTOR_SEARCH against namespace query execution.""" + _require_sql_search_backend(spark) + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT NOT NULL, + vector ARRAY NOT NULL + ) USING lance + TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4') + """) + spark.sql(""" + INSERT INTO default.test_table VALUES + (0, array(0.0, 0.0, 0.0, 0.0)), + (1, array(1.0, 1.0, 1.0, 1.0)), + (2, array(10.0, 10.0, 10.0, 10.0)) + """) + + rows = spark.sql(""" + SELECT id, _distance + FROM VECTOR_SEARCH('default.test_table', array(0.0, 0.0, 0.0, 0.0), 2) + ORDER BY _distance, id + """).collect() + + assert [row.id for row in rows] == [0, 1] + assert rows[0]["_distance"] == pytest.approx(0.0) + assert rows[1]["_distance"] > rows[0]["_distance"] + + def test_search_table_function(self, spark): + """Test SEARCH against a Lance FTS index.""" + _require_sql_search_backend(spark) + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT NOT NULL, + body STRING + ) USING lance + """) + spark.sql(""" + INSERT INTO default.test_table VALUES + (1, 'lance vector search'), + (2, 'spark connector table function'), + (3, 'lance full text search') + """) + spark.sql(""" + ALTER TABLE default.test_table + CREATE INDEX body_fts USING fts (body) + WITH ( + base_tokenizer='simple', + language='English', + max_token_length=40, + lower_case=true, + stem=false, + remove_stop_words=false, + ascii_folding=false, + with_position=true + ) + """) + + rows = spark.sql(""" + SELECT id, body, _score + FROM SEARCH('default.test_table', 'lance', 10) + ORDER BY id + """).collect() + + assert [row.id for row in rows] == [1, 3] + assert "lance" in rows[0].body + assert rows[0]["_score"] > 0.0 + + def test_hybrid_search_table_function(self, spark): + """Test HYBRID_SEARCH client-side RRF fusion over namespace results.""" + _require_sql_search_backend(spark) + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT NOT NULL, + body STRING, + vector ARRAY NOT NULL + ) USING lance + TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4') + """) + spark.sql(""" + INSERT INTO default.test_table VALUES + (1, 'lance vector search', array(0.0, 0.0, 0.0, 0.0)), + (2, 'spark connector table function', array(1.0, 1.0, 1.0, 1.0)), + (3, 'lance full text search', array(10.0, 10.0, 10.0, 10.0)) + """) + spark.sql(""" + ALTER TABLE default.test_table + CREATE INDEX body_fts USING fts (body) + WITH ( + base_tokenizer='simple', + language='English', + max_token_length=40, + lower_case=true, + stem=false, + remove_stop_words=false, + ascii_folding=false, + with_position=true + ) + """) + + rows = spark.sql(""" + SELECT id, body, _distance, _score, _relevance_score + FROM HYBRID_SEARCH('default.test_table', array(0.0, 0.0, 0.0, 0.0), 'lance', 3) + ORDER BY _relevance_score DESC, id + """).collect() + + assert [row.id for row in rows] == [1, 3, 2] + assert rows[0]["_distance"] == pytest.approx(0.0) + assert rows[0]["_score"] > 0.0 + assert rows[0]["_relevance_score"] > rows[1]["_relevance_score"] + assert rows[1]["_score"] > 0.0 + assert rows[2]["_score"] is None + assert rows[2]["_distance"] > rows[0]["_distance"] + + def test_vector_search_named_args_projection_rowid_and_offset(self, spark): + """Test named vector options used by Databricks-style SQL calls.""" + _require_sql_search_backend(spark) + if SPARK_VERSION < Version("3.5"): + pytest.skip("named table function arguments require Spark 3.5+") + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT NOT NULL, + vector ARRAY NOT NULL + ) USING lance + TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4') + """) + spark.sql(""" + INSERT INTO default.test_table VALUES + (0, array(0.0, 0.0, 0.0, 0.0)), + (1, array(1.0, 1.0, 1.0, 1.0)), + (2, array(10.0, 10.0, 10.0, 10.0)) + """) + + rows = spark.sql(""" + SELECT id, _rowid, _distance + FROM VECTOR_SEARCH( + table => 'default.test_table', + query_vector => array(0.0, 0.0, 0.0, 0.0), + vector_column => 'vector', + columns => array('id'), + num_results => 1, + offset => 1, + with_row_id => true + ) + ORDER BY _distance, id + """).collect() + + assert len(rows) == 1 + assert rows[0].id == 1 + assert rows[0]["_rowid"] >= 0 + assert rows[0]["_distance"] > 0.0 + + class TestDQLSelect: """Test DQL SELECT operations.""" diff --git a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 87bc6403c..924575c4f 100644 --- a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,9 +14,12 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.search.LanceSearchTableFunctions class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -30,6 +33,28 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // propagate blob source credentials from read scans to the write side extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectTableFunction( + ( + FunctionIdentifier("vector_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "vector_search"), + LanceSearchTableFunctions.vectorSearch _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "search"), + LanceSearchTableFunctions.search _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("hybrid_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "hybrid_search"), + LanceSearchTableFunctions.hybridSearch _)) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java similarity index 80% rename from lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java rename to lance-spark-3.4_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java index c2fef4e45..c2d97ba80 100644 --- a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java +++ b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.lance.spark.read; +package org.lance.spark.search; -public class SparkConnectorReadWithVectorSearchTest - extends BaseSparkConnectorReadWithVectorSearchTest {} +public class SparkSearchTableFunctionTest extends BaseSparkSearchTableFunctionTest {} diff --git a/lance-spark-3.4_2.13/pom.xml b/lance-spark-3.4_2.13/pom.xml index 825562087..85c2b194e 100644 --- a/lance-spark-3.4_2.13/pom.xml +++ b/lance-spark-3.4_2.13/pom.xml @@ -119,6 +119,7 @@ ../lance-spark-base_2.12/src/test/java + ../lance-spark-3.4_2.12/src/test/java diff --git a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 87bc6403c..924575c4f 100644 --- a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,9 +14,12 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.search.LanceSearchTableFunctions class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -30,6 +33,28 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // propagate blob source credentials from read scans to the write side extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectTableFunction( + ( + FunctionIdentifier("vector_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "vector_search"), + LanceSearchTableFunctions.vectorSearch _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "search"), + LanceSearchTableFunctions.search _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("hybrid_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "hybrid_search"), + LanceSearchTableFunctions.hybridSearch _)) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchRestNamespaceSmokeTest.java similarity index 80% rename from lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java rename to lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchRestNamespaceSmokeTest.java index c2fef4e45..7d4be1d2c 100644 --- a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/read/SparkConnectorReadWithVectorSearchTest.java +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchRestNamespaceSmokeTest.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.lance.spark.read; +package org.lance.spark.search; -public class SparkConnectorReadWithVectorSearchTest - extends BaseSparkConnectorReadWithVectorSearchTest {} +public class SparkSearchRestNamespaceSmokeTest extends BaseSparkSearchRestNamespaceSmokeTest {} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java new file mode 100644 index 000000000..c2d97ba80 --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/search/SparkSearchTableFunctionTest.java @@ -0,0 +1,16 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +public class SparkSearchTableFunctionTest extends BaseSparkSearchTableFunctionTest {} diff --git a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 87bc6403c..924575c4f 100644 --- a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,9 +14,12 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.search.LanceSearchTableFunctions class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -30,6 +33,28 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // propagate blob source credentials from read scans to the write side extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectTableFunction( + ( + FunctionIdentifier("vector_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "vector_search"), + LanceSearchTableFunctions.vectorSearch _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "search"), + LanceSearchTableFunctions.search _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("hybrid_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "hybrid_search"), + LanceSearchTableFunctions.hybridSearch _)) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 87bc6403c..924575c4f 100644 --- a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,9 +14,12 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy +import org.lance.spark.search.LanceSearchTableFunctions class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { @@ -30,6 +33,28 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // propagate blob source credentials from read scans to the write side extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectTableFunction( + ( + FunctionIdentifier("vector_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "vector_search"), + LanceSearchTableFunctions.vectorSearch _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "search"), + LanceSearchTableFunctions.search _)) + extensions.injectTableFunction( + ( + FunctionIdentifier("hybrid_search"), + new ExpressionInfo( + "org.lance.spark.search.LanceSearchTableFunctions", + "hybrid_search"), + LanceSearchTableFunctions.hybridSearch _)) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java index ecd7013f8..5f14e0290 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java @@ -14,15 +14,10 @@ package org.lance.spark; import org.lance.ReadOptions; -import org.lance.ipc.Query; import org.lance.namespace.LanceNamespace; -import org.lance.spark.utils.QueryUtils; import com.google.common.base.Preconditions; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.HashMap; import java.util.List; @@ -57,8 +52,7 @@ public class LanceSparkReadOptions implements Serializable { public static final String CONFIG_METADATA_CACHE_SIZE = "metadata_cache_size"; public static final String CONFIG_BATCH_SIZE = "batch_size"; public static final String CONFIG_TOP_N_PUSH_DOWN = "topN_push_down"; - - public static final String CONFIG_NEAREST = "nearest"; + private static final String DEPRECATED_CONFIG_NEAREST = "nearest"; /** * Whether executors should rebuild the namespace client and re-fetch storage options via {@code @@ -111,7 +105,6 @@ public class LanceSparkReadOptions implements Serializable { private final Integer indexCacheSize; private final Integer metadataCacheSize; private final int batchSize; - private transient Query nearest; private final boolean topNPushDown; private final Map storageOptions; @@ -141,7 +134,6 @@ private LanceSparkReadOptions(Builder builder) { this.indexCacheSize = builder.indexCacheSize; this.metadataCacheSize = builder.metadataCacheSize; this.batchSize = builder.batchSize; - this.nearest = builder.nearest; this.topNPushDown = builder.topNPushDown; this.storageOptions = new HashMap<>(builder.storageOptions); this.namespace = builder.namespace; @@ -254,10 +246,6 @@ public int getBatchSize() { return batchSize; } - public Query getNearest() { - return nearest; - } - public boolean isTopNPushDown() { return topNPushDown; } @@ -266,10 +254,6 @@ public Map getStorageOptions() { return storageOptions; } - public String getNearestJson() { - return QueryUtils.queryToString(nearest); - } - public LanceNamespace getNamespace() { return namespace; } @@ -321,7 +305,6 @@ public LanceSparkReadOptions withVersion(long newVersion) { .indexCacheSize(this.indexCacheSize) .metadataCacheSize(this.metadataCacheSize) .batchSize(this.batchSize) - .nearest(this.nearest) .topNPushDown(this.topNPushDown) .storageOptions(this.storageOptions) .namespace(this.namespace) @@ -357,17 +340,6 @@ public ReadOptions toReadOptions() { return builder.build(); } - private void writeObject(ObjectOutputStream out) throws IOException { - out.defaultWriteObject(); - out.writeObject(QueryUtils.queryToString(nearest)); - } - - private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { - in.defaultReadObject(); - String json = (String) in.readObject(); - this.nearest = QueryUtils.stringToQuery(json); - } - @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) { @@ -378,7 +350,6 @@ public boolean equals(Object o) { && batchSize == that.batchSize && topNPushDown == that.topNPushDown && executorCredentialRefresh == that.executorCredentialRefresh - && Objects.equals(nearest, that.nearest) && Objects.equals(datasetUri, that.datasetUri) && Objects.equals(blockSize, that.blockSize) && Objects.equals(version, that.version) @@ -398,7 +369,6 @@ public int hashCode() { indexCacheSize, metadataCacheSize, batchSize, - nearest, topNPushDown, storageOptions, tableId, @@ -410,7 +380,6 @@ public static class Builder { private String datasetUri; private boolean pushDownFilters = DEFAULT_PUSH_DOWN_FILTERS; private Integer blockSize; - private Query nearest; private Long version; private Integer indexCacheSize; private Integer metadataCacheSize; @@ -439,20 +408,6 @@ public Builder blockSize(Integer blockSize) { return this; } - public Builder nearest(Query nearest) { - this.nearest = nearest; - return this; - } - - public Builder nearest(String json) { - try { - this.nearest = QueryUtils.stringToQuery(json); - } catch (Exception e) { - throw new IllegalArgumentException("Failed to parse nearest query from json: " + json, e); - } - return this; - } - public Builder version(Long version) { this.version = version; return this; @@ -539,6 +494,9 @@ public Builder withCatalogDefaults(LanceSparkCatalogConfig catalogConfig) { * both call sites stay in sync and catalog-level configs reach the typed fields. */ private void parseTypedFlags(Map opts) { + Preconditions.checkArgument( + !opts.containsKey(DEPRECATED_CONFIG_NEAREST), + "The nearest read option is no longer supported; use VECTOR_SEARCH table function"); if (opts.containsKey(CONFIG_PUSH_DOWN_FILTERS)) { this.pushDownFilters = Boolean.parseBoolean(opts.get(CONFIG_PUSH_DOWN_FILTERS)); } @@ -562,9 +520,6 @@ private void parseTypedFlags(Map opts) { if (opts.containsKey(CONFIG_TOP_N_PUSH_DOWN)) { this.topNPushDown = Boolean.parseBoolean(opts.get(CONFIG_TOP_N_PUSH_DOWN)); } - if (opts.containsKey(CONFIG_NEAREST)) { - nearest(opts.get(CONFIG_NEAREST)); - } if (opts.containsKey(CONFIG_EXECUTOR_CREDENTIAL_REFRESH)) { this.executorCredentialRefresh = Boolean.parseBoolean(opts.get(CONFIG_EXECUTOR_CREDENTIAL_REFRESH)); diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java index f2cecea28..c2cb39a4c 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java @@ -128,10 +128,6 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in scanOptions.filter(inputPartition.getWhereCondition().get()); } scanOptions.batchSize(readOptions.getBatchSize()); - if (readOptions.getNearest() != null) { - scanOptions.nearest(readOptions.getNearest()); - scanOptions.prefilter(true); - } if (inputPartition.getLimit().isPresent()) { scanOptions.limit(inputPartition.getLimit().get()); } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java index 362c358e1..fa8b0602f 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java @@ -298,7 +298,6 @@ private List pruneByRowAddrFilters(List allSplits) { *
  • Filters are present (unknown selectivity makes row count estimation unreliable) *
  • TopN sort orders are present (all fragments needed for global sort) *
  • Aggregation is pushed (e.g., COUNT(*) LIMIT — row counts don't apply) - *
  • Vector search (nearest) is active (needs global search across all fragments) *
  • Fragment row counts are unavailable * * @@ -312,7 +311,6 @@ private List pruneByLimit( || whereConditions.isPresent() || topNSortOrders.isPresent() || pushedAggregation.isPresent() - || readOptions.getNearest() != null || fragmentRowCounts.isEmpty()) { return allSplits; } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchInputPartition.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchInputPartition.java new file mode 100644 index 000000000..46d6f154d --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchInputPartition.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.types.StructType; + +public class LanceHybridSearchInputPartition implements InputPartition { + private static final long serialVersionUID = -7982173981273123897L; + + private final StructType schema; + private final LanceHybridSearchQuery query; + + public LanceHybridSearchInputPartition(StructType schema, LanceHybridSearchQuery query) { + this.schema = schema; + this.query = query; + } + + public StructType getSchema() { + return schema; + } + + public LanceHybridSearchQuery getQuery() { + return query; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchPartitionReaderFactory.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchPartitionReaderFactory.java new file mode 100644 index 000000000..06f8fcdc3 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchPartitionReaderFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; + +public class LanceHybridSearchPartitionReaderFactory implements PartitionReaderFactory { + private static final long serialVersionUID = 6821379812739812739L; + + @Override + public PartitionReader createReader(InputPartition partition) { + return new LanceHybridSearchRowPartitionReader(asHybridPartition(partition)); + } + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return false; + } + + private LanceHybridSearchInputPartition asHybridPartition(InputPartition partition) { + if (!(partition instanceof LanceHybridSearchInputPartition)) { + throw new IllegalArgumentException( + "Unknown InputPartition type. Expecting LanceHybridSearchInputPartition"); + } + return (LanceHybridSearchInputPartition) partition; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchQuery.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchQuery.java new file mode 100644 index 000000000..79546add8 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchQuery.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; + +public class LanceHybridSearchQuery implements Serializable { + private static final long serialVersionUID = -2476891273912739812L; + + private final LanceSearchQuery vectorQuery; + private final LanceSearchQuery fullTextQuery; + private final StructType vectorSchema; + private final StructType fullTextSchema; + private final Integer k; + private final Integer offset; + private final Float rrfK; + + public LanceHybridSearchQuery( + LanceSearchQuery vectorQuery, + LanceSearchQuery fullTextQuery, + StructType vectorSchema, + StructType fullTextSchema, + Integer k, + Integer offset, + Float rrfK) { + if (vectorQuery == null) { + throw new IllegalArgumentException("vector query is required"); + } + if (fullTextQuery == null) { + throw new IllegalArgumentException("full text query is required"); + } + if (vectorSchema == null) { + throw new IllegalArgumentException("vector schema is required"); + } + if (fullTextSchema == null) { + throw new IllegalArgumentException("full text schema is required"); + } + if (k == null || k <= 0) { + throw new IllegalArgumentException("k must be positive"); + } + if (offset != null && offset < 0) { + throw new IllegalArgumentException("offset must be non-negative"); + } + if (rrfK == null || rrfK <= 0.0f) { + throw new IllegalArgumentException("rrf_k must be positive"); + } + this.vectorQuery = vectorQuery; + this.fullTextQuery = fullTextQuery; + this.vectorSchema = vectorSchema; + this.fullTextSchema = fullTextSchema; + this.k = k; + this.offset = offset == null ? Integer.valueOf(0) : offset; + this.rrfK = rrfK; + } + + public LanceSearchQuery getVectorQuery() { + return vectorQuery; + } + + public LanceSearchQuery getFullTextQuery() { + return fullTextQuery; + } + + public StructType getVectorSchema() { + return vectorSchema; + } + + public StructType getFullTextSchema() { + return fullTextSchema; + } + + public Integer getK() { + return k; + } + + public Integer getOffset() { + return offset; + } + + public Float getRrfK() { + return rrfK; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchRowPartitionReader.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchRowPartitionReader.java new file mode 100644 index 000000000..f343807d5 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchRowPartitionReader.java @@ -0,0 +1,245 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class LanceHybridSearchRowPartitionReader implements PartitionReader { + private static final String DISTANCE_COLUMN = "_distance"; + private static final String FTS_SCORE_COLUMN = "_score"; + private static final String HYBRID_SCORE_COLUMN = "_relevance_score"; + private static final String ROW_ID_COLUMN = "_rowid"; + + private final LanceHybridSearchInputPartition inputPartition; + private List rows; + private int rowIndex = -1; + + public LanceHybridSearchRowPartitionReader(LanceHybridSearchInputPartition inputPartition) { + this.inputPartition = inputPartition; + } + + @Override + public boolean next() throws IOException { + if (rows == null) { + rows = executeHybridSearch(); + } + rowIndex += 1; + return rowIndex < rows.size(); + } + + @Override + public InternalRow get() { + return rows.get(rowIndex); + } + + @Override + public void close() {} + + private List executeHybridSearch() throws IOException { + LanceHybridSearchQuery query = inputPartition.getQuery(); + List vectorRows = + readSide(query.getVectorQuery(), query.getVectorSchema(), DISTANCE_COLUMN); + List fullTextRows = + readSide(query.getFullTextQuery(), query.getFullTextSchema(), FTS_SCORE_COLUMN); + + Map mergedRows = new HashMap<>(); + float rrfK = query.getRrfK(); + for (int i = 0; i < vectorRows.size(); i++) { + SideRow row = vectorRows.get(i); + mergedRows.computeIfAbsent(row.rowId, HybridRow::new).addVector(row, i, rrfK); + } + for (int i = 0; i < fullTextRows.size(); i++) { + SideRow row = fullTextRows.get(i); + mergedRows.computeIfAbsent(row.rowId, HybridRow::new).addFullText(row, i, rrfK); + } + + List sortedRows = new ArrayList<>(mergedRows.values()); + sortedRows.sort( + (left, right) -> { + int relevanceCompare = Double.compare(right.relevance, left.relevance); + if (relevanceCompare != 0) { + return relevanceCompare; + } + int rankCompare = Integer.compare(left.bestRank(), right.bestRank()); + if (rankCompare != 0) { + return rankCompare; + } + return Long.compare(left.rowId, right.rowId); + }); + + int fromIndex = Math.min(query.getOffset(), sortedRows.size()); + int toIndex = Math.min(fromIndex + query.getK(), sortedRows.size()); + List resultRows = new ArrayList<>(toIndex - fromIndex); + for (int i = fromIndex; i < toIndex; i++) { + resultRows.add(toInternalRow(sortedRows.get(i), inputPartition.getSchema())); + } + return resultRows; + } + + private List readSide(LanceSearchQuery query, StructType schema, String metricColumn) + throws IOException { + LanceSearchColumnarPartitionReader reader = + new LanceSearchColumnarPartitionReader(new LanceSearchInputPartition(schema, query)); + try { + return readSideRows(reader, schema, metricColumn); + } finally { + reader.close(); + } + } + + private List readSideRows( + LanceSearchColumnarPartitionReader reader, StructType schema, String metricColumn) + throws IOException { + int rowIdIndex = fieldIndex(schema, ROW_ID_COLUMN); + int metricIndex = fieldIndex(schema, metricColumn); + StructField[] fields = schema.fields(); + List valueFields = new ArrayList<>(); + for (int i = 0; i < fields.length; i++) { + String fieldName = fields[i].name(); + if (!fieldName.equalsIgnoreCase(ROW_ID_COLUMN) && !fieldName.equalsIgnoreCase(metricColumn)) { + valueFields.add(new SideValueField(i, fieldName)); + } + } + + List sideRows = new ArrayList<>(); + while (reader.next()) { + ColumnarBatch batch = reader.get(); + Iterator rowIterator = batch.rowIterator(); + while (rowIterator.hasNext()) { + InternalRow copiedRow = rowIterator.next().copy(); + long rowId = copiedRow.getLong(rowIdIndex); + Float metric = copiedRow.isNullAt(metricIndex) ? null : copiedRow.getFloat(metricIndex); + Map values = new HashMap<>(); + for (SideValueField valueField : valueFields) { + if (copiedRow.isNullAt(valueField.index)) { + values.put(valueField.name, null); + } else { + values.put( + valueField.name, + copiedRow.get(valueField.index, fields[valueField.index].dataType())); + } + } + sideRows.add(new SideRow(rowId, metric, values)); + } + } + return sideRows; + } + + private InternalRow toInternalRow(HybridRow row, StructType outputSchema) { + StructField[] fields = outputSchema.fields(); + Object[] values = new Object[fields.length]; + for (int i = 0; i < fields.length; i++) { + String fieldName = fields[i].name(); + if (fieldName.equalsIgnoreCase(DISTANCE_COLUMN)) { + values[i] = row.distance; + } else if (fieldName.equalsIgnoreCase(FTS_SCORE_COLUMN)) { + values[i] = row.score; + } else if (fieldName.equalsIgnoreCase(HYBRID_SCORE_COLUMN)) { + values[i] = (float) row.relevance; + } else if (fieldName.equalsIgnoreCase(ROW_ID_COLUMN)) { + values[i] = row.rowId; + } else { + values[i] = row.values.get(fieldName); + } + } + return new GenericInternalRow(values); + } + + private int fieldIndex(StructType schema, String name) { + StructField[] fields = schema.fields(); + for (int i = 0; i < fields.length; i++) { + if (fields[i].name().equalsIgnoreCase(name)) { + return i; + } + } + throw new IllegalStateException( + "Expected field '" + name + "' in schema " + schema.treeString()); + } + + private static double rrfScore(int rank, float rrfK) { + return 1.0d / (((double) rank) + (double) rrfK); + } + + private static final class SideValueField { + private final int index; + private final String name; + + private SideValueField(int index, String name) { + this.index = index; + this.name = name; + } + } + + private static final class SideRow { + private final long rowId; + private final Float metric; + private final Map values; + + private SideRow(long rowId, Float metric, Map values) { + this.rowId = rowId; + this.metric = metric; + this.values = values; + } + } + + private static final class HybridRow { + private final long rowId; + private final Map values = new HashMap<>(); + private Float distance; + private Float score; + private double relevance; + private int vectorRank = Integer.MAX_VALUE; + private int fullTextRank = Integer.MAX_VALUE; + + private HybridRow(long rowId) { + this.rowId = rowId; + } + + private void addVector(SideRow row, int rank, float rrfK) { + mergeValues(row); + distance = row.metric; + vectorRank = Math.min(vectorRank, rank); + relevance += rrfScore(rank, rrfK); + } + + private void addFullText(SideRow row, int rank, float rrfK) { + mergeValues(row); + score = row.metric; + fullTextRank = Math.min(fullTextRank, rank); + relevance += rrfScore(rank, rrfK); + } + + private void mergeValues(SideRow row) { + for (Map.Entry entry : row.values.entrySet()) { + values.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + + private int bestRank() { + return Math.min(vectorRank, fullTextRank); + } + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScan.java new file mode 100644 index 000000000..07d65fdeb --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScan.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; + +public class LanceHybridSearchScan implements Scan, Batch, Serializable { + private static final long serialVersionUID = 3798127398172391237L; + + private final StructType schema; + private final LanceHybridSearchQuery query; + + public LanceHybridSearchScan(StructType schema, LanceHybridSearchQuery query) { + this.schema = schema; + this.query = query; + } + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public String description() { + return "LanceHybridSearchScan"; + } + + @Override + public Batch toBatch() { + return this; + } + + @Override + public InputPartition[] planInputPartitions() { + return new InputPartition[] {new LanceHybridSearchInputPartition(schema, query)}; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new LanceHybridSearchPartitionReaderFactory(); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScanBuilder.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScanBuilder.java new file mode 100644 index 000000000..6c9fe4f5e --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchScanBuilder.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; + +public class LanceHybridSearchScanBuilder implements ScanBuilder { + private final StructType schema; + private final LanceHybridSearchQuery query; + + public LanceHybridSearchScanBuilder(StructType schema, LanceHybridSearchQuery query) { + this.schema = schema; + this.query = query; + } + + @Override + public Scan build() { + return new LanceHybridSearchScan(schema, query); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchTable.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchTable.java new file mode 100644 index 000000000..f80959198 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceHybridSearchTable.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import java.util.Collections; +import java.util.Set; + +public class LanceHybridSearchTable implements SupportsRead { + private final String name; + private final StructType schema; + private final LanceHybridSearchQuery query; + + public LanceHybridSearchTable(String name, StructType schema, LanceHybridSearchQuery query) { + this.name = name; + this.schema = schema; + this.query = query; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new LanceHybridSearchScanBuilder(schema, query); + } + + @Override + public String name() { + return name; + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Set capabilities() { + return Collections.singleton(TableCapability.BATCH_READ); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchColumnarPartitionReader.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchColumnarPartitionReader.java new file mode 100644 index 000000000..80c30a344 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchColumnarPartitionReader.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.lance.namespace.LanceNamespace; +import org.lance.spark.LanceRuntime; +import org.lance.spark.vectorized.LanceArrowColumnVector; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowFileReader; +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.io.Closeable; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class LanceSearchColumnarPartitionReader implements PartitionReader { + private final LanceSearchInputPartition inputPartition; + private ArrowFileReader arrowReader; + private ColumnarBatch currentBatch; + private boolean finished; + + public LanceSearchColumnarPartitionReader(LanceSearchInputPartition inputPartition) { + this.inputPartition = inputPartition; + } + + @Override + public boolean next() throws IOException { + if (finished) { + return false; + } + if (arrowReader == null) { + openArrowReader(); + } + if (arrowReader.loadNextBatch()) { + currentBatch = toColumnarBatch(arrowReader.getVectorSchemaRoot(), inputPartition.getSchema()); + return true; + } + finished = true; + return false; + } + + @Override + public ColumnarBatch get() { + return currentBatch; + } + + @Override + public void close() throws IOException { + try { + if (currentBatch != null) { + currentBatch.close(); + } + } finally { + if (arrowReader != null) { + arrowReader.close(); + } + } + } + + private void openArrowReader() throws IOException { + LanceSearchQuery query = inputPartition.getQuery(); + LanceNamespace namespace = + LanceRuntime.getOrCreateNamespace(query.getNamespaceImpl(), query.getNamespaceProperties()); + if (namespace == null) { + throw new IOException("Lance namespace is required for search"); + } + try { + byte[] bytes = namespace.queryTable(query.toQueryTableRequest()); + arrowReader = + new ArrowFileReader( + new ByteArrayReadableSeekableByteChannel(bytes), LanceRuntime.allocator()); + } finally { + if (namespace instanceof Closeable) { + ((Closeable) namespace).close(); + } + } + } + + private ColumnarBatch toColumnarBatch(VectorSchemaRoot root, StructType schema) { + Map actualFields = new HashMap<>(); + for (FieldVector vector : root.getFieldVectors()) { + actualFields.put(vector.getField().getName(), vector); + } + + StructField[] fields = schema.fields(); + ColumnVector[] vectors = new ColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + String fieldName = fields[i].name(); + FieldVector vector = actualFields.get(fieldName); + if (vector == null) { + throw new IllegalStateException( + "Lance search did not return expected field '" + fieldName + "'"); + } + vectors[i] = new LanceArrowColumnVector(vector, false); + } + return new ColumnarBatch(vectors, root.getRowCount()); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchInputPartition.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchInputPartition.java new file mode 100644 index 000000000..546c371ea --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchInputPartition.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.types.StructType; + +public class LanceSearchInputPartition implements InputPartition { + private static final long serialVersionUID = -38612098237192389L; + + private final StructType schema; + private final LanceSearchQuery query; + + public LanceSearchInputPartition(StructType schema, LanceSearchQuery query) { + this.schema = schema; + this.query = query; + } + + public StructType getSchema() { + return schema; + } + + public LanceSearchQuery getQuery() { + return query; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchPartitionReaderFactory.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchPartitionReaderFactory.java new file mode 100644 index 000000000..fb8c3d047 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchPartitionReaderFactory.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +public class LanceSearchPartitionReaderFactory implements PartitionReaderFactory { + private static final long serialVersionUID = -812739812739817239L; + + @Override + public PartitionReader createReader(InputPartition partition) { + return new LanceSearchRowPartitionReader( + new LanceSearchColumnarPartitionReader(asSearchPartition(partition))); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + return new LanceSearchColumnarPartitionReader(asSearchPartition(partition)); + } + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + private LanceSearchInputPartition asSearchPartition(InputPartition partition) { + if (!(partition instanceof LanceSearchInputPartition)) { + throw new IllegalArgumentException( + "Unknown InputPartition type. Expecting LanceSearchInputPartition"); + } + return (LanceSearchInputPartition) partition; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchQuery.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchQuery.java new file mode 100644 index 000000000..5c61f5c74 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchQuery.java @@ -0,0 +1,348 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.lance.namespace.model.QueryTableRequest; +import org.lance.namespace.model.QueryTableRequestColumns; +import org.lance.namespace.model.QueryTableRequestFullTextQuery; +import org.lance.namespace.model.QueryTableRequestVector; +import org.lance.namespace.model.StringFtsQuery; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class LanceSearchQuery implements Serializable { + private static final long serialVersionUID = 538912738912739128L; + + public enum SearchType { + VECTOR, + FULL_TEXT + } + + private final SearchType searchType; + private final List tableId; + private final String namespaceImpl; + private final Map namespaceProperties; + private final List outputColumns; + private final Integer k; + private final Integer offset; + private final Long version; + private final String filter; + private final Boolean withRowId; + private final List vector; + private final String vectorColumn; + private final String distanceType; + private final Integer nprobes; + private final Integer ef; + private final Integer refineFactor; + private final Float lowerBound; + private final Float upperBound; + private final Boolean bypassVectorIndex; + private final Boolean fastSearch; + private final Boolean prefilter; + private final String textQuery; + private final List searchColumns; + + private LanceSearchQuery(Builder builder) { + this.searchType = builder.searchType; + this.tableId = immutableList(builder.tableId); + this.namespaceImpl = builder.namespaceImpl; + this.namespaceProperties = immutableMap(builder.namespaceProperties); + this.outputColumns = immutableList(builder.outputColumns); + this.k = builder.k; + this.offset = builder.offset; + this.version = builder.version; + this.filter = builder.filter; + this.withRowId = builder.withRowId; + this.vector = immutableList(builder.vector); + this.vectorColumn = builder.vectorColumn; + this.distanceType = builder.distanceType; + this.nprobes = builder.nprobes; + this.ef = builder.ef; + this.refineFactor = builder.refineFactor; + this.lowerBound = builder.lowerBound; + this.upperBound = builder.upperBound; + this.bypassVectorIndex = builder.bypassVectorIndex; + this.fastSearch = builder.fastSearch; + this.prefilter = builder.prefilter; + this.textQuery = builder.textQuery; + this.searchColumns = immutableList(builder.searchColumns); + } + + public static Builder builder(SearchType searchType) { + return new Builder(searchType); + } + + public SearchType getSearchType() { + return searchType; + } + + public List getTableId() { + return tableId; + } + + public String getNamespaceImpl() { + return namespaceImpl; + } + + public Map getNamespaceProperties() { + return namespaceProperties; + } + + public QueryTableRequest toQueryTableRequest() { + QueryTableRequest request = new QueryTableRequest().id(tableId).k(k); + request.vector(new QueryTableRequestVector()); + + if (!outputColumns.isEmpty()) { + request.columns(new QueryTableRequestColumns().columnNames(outputColumns)); + } + if (offset != null) { + request.offset(offset); + } + if (version != null) { + request.version(version); + } + if (filter != null) { + request.filter(filter); + } + if (withRowId != null) { + request.withRowId(withRowId); + } + + if (searchType == SearchType.VECTOR) { + request.vector(new QueryTableRequestVector().singleVector(vector)); + if (vectorColumn != null) { + request.vectorColumn(vectorColumn); + } + if (distanceType != null) { + request.distanceType(distanceType); + } + if (nprobes != null) { + request.nprobes(nprobes); + } + if (ef != null) { + request.ef(ef); + } + if (refineFactor != null) { + request.refineFactor(refineFactor); + } + if (lowerBound != null) { + request.lowerBound(lowerBound); + } + if (upperBound != null) { + request.upperBound(upperBound); + } + if (bypassVectorIndex != null) { + request.bypassVectorIndex(bypassVectorIndex); + } + if (fastSearch != null) { + request.fastSearch(fastSearch); + } + if (prefilter != null) { + request.prefilter(prefilter); + } + } else { + StringFtsQuery stringQuery = new StringFtsQuery().query(textQuery); + if (!searchColumns.isEmpty()) { + stringQuery.columns(searchColumns); + } + request.fullTextQuery(new QueryTableRequestFullTextQuery().stringQuery(stringQuery)); + } + + return request; + } + + private static List immutableList(List values) { + if (values == null || values.isEmpty()) { + return Collections.emptyList(); + } + return Collections.unmodifiableList(new ArrayList<>(values)); + } + + private static Map immutableMap(Map values) { + if (values == null || values.isEmpty()) { + return Collections.emptyMap(); + } + return Collections.unmodifiableMap(new HashMap<>(values)); + } + + public static final class Builder { + private final SearchType searchType; + private List tableId = Collections.emptyList(); + private String namespaceImpl; + private Map namespaceProperties = Collections.emptyMap(); + private List outputColumns = Collections.emptyList(); + private Integer k = 10; + private Integer offset; + private Long version; + private String filter; + private Boolean withRowId; + private List vector = Collections.emptyList(); + private String vectorColumn; + private String distanceType; + private Integer nprobes; + private Integer ef; + private Integer refineFactor; + private Float lowerBound; + private Float upperBound; + private Boolean bypassVectorIndex; + private Boolean fastSearch; + private Boolean prefilter; + private String textQuery; + private List searchColumns = Collections.emptyList(); + + private Builder(SearchType searchType) { + this.searchType = searchType; + } + + public Builder tableId(List tableId) { + this.tableId = tableId; + return this; + } + + public Builder namespaceImpl(String namespaceImpl) { + this.namespaceImpl = namespaceImpl; + return this; + } + + public Builder namespaceProperties(Map namespaceProperties) { + this.namespaceProperties = namespaceProperties; + return this; + } + + public Builder outputColumns(List outputColumns) { + this.outputColumns = outputColumns; + return this; + } + + public Builder topK(Integer k) { + this.k = k; + return this; + } + + public Builder offset(Integer offset) { + this.offset = offset; + return this; + } + + public Builder version(Long version) { + this.version = version; + return this; + } + + public Builder filter(String filter) { + this.filter = filter; + return this; + } + + public Builder withRowId(Boolean withRowId) { + this.withRowId = withRowId; + return this; + } + + public Builder vector(List vector) { + this.vector = vector; + return this; + } + + public Builder vectorColumn(String vectorColumn) { + this.vectorColumn = vectorColumn; + return this; + } + + public Builder distanceType(String distanceType) { + this.distanceType = distanceType; + return this; + } + + public Builder nprobes(Integer nprobes) { + this.nprobes = nprobes; + return this; + } + + public Builder ef(Integer ef) { + this.ef = ef; + return this; + } + + public Builder refineFactor(Integer refineFactor) { + this.refineFactor = refineFactor; + return this; + } + + public Builder lowerBound(Float lowerBound) { + this.lowerBound = lowerBound; + return this; + } + + public Builder upperBound(Float upperBound) { + this.upperBound = upperBound; + return this; + } + + public Builder bypassVectorIndex(Boolean bypassVectorIndex) { + this.bypassVectorIndex = bypassVectorIndex; + return this; + } + + public Builder fastSearch(Boolean fastSearch) { + this.fastSearch = fastSearch; + return this; + } + + public Builder prefilter(Boolean prefilter) { + this.prefilter = prefilter; + return this; + } + + public Builder textQuery(String textQuery) { + this.textQuery = textQuery; + return this; + } + + public Builder searchColumns(List searchColumns) { + this.searchColumns = searchColumns; + return this; + } + + public LanceSearchQuery build() { + if (searchType == null) { + throw new IllegalArgumentException("search type is required"); + } + if (tableId == null || tableId.isEmpty()) { + throw new IllegalArgumentException("table id is required"); + } + if (namespaceImpl == null || namespaceImpl.isEmpty()) { + throw new IllegalArgumentException("namespace implementation is required"); + } + if (k == null || k <= 0) { + throw new IllegalArgumentException("k must be positive"); + } + if (offset != null && offset < 0) { + throw new IllegalArgumentException("offset must be non-negative"); + } + if (searchType == SearchType.VECTOR && (vector == null || vector.isEmpty())) { + throw new IllegalArgumentException("query_vector is required"); + } + if (searchType == SearchType.FULL_TEXT && (textQuery == null || textQuery.isEmpty())) { + throw new IllegalArgumentException("query is required"); + } + return new LanceSearchQuery(this); + } + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchRowPartitionReader.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchRowPartitionReader.java new file mode 100644 index 000000000..cb5a7e40d --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchRowPartitionReader.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.io.IOException; +import java.util.Iterator; + +public class LanceSearchRowPartitionReader implements PartitionReader { + private final LanceSearchColumnarPartitionReader reader; + private Iterator currentRows; + private InternalRow currentRecord; + + public LanceSearchRowPartitionReader(LanceSearchColumnarPartitionReader reader) { + this.reader = reader; + } + + @Override + public boolean next() throws IOException { + if (currentRows != null && currentRows.hasNext()) { + currentRecord = currentRows.next(); + return true; + } + if (reader.next()) { + ColumnarBatch currentBatch = reader.get(); + currentRows = currentBatch.rowIterator(); + if (currentRows != null && currentRows.hasNext()) { + currentRecord = currentRows.next(); + return true; + } + } + return false; + } + + @Override + public InternalRow get() { + return currentRecord; + } + + @Override + public void close() throws IOException { + reader.close(); + } + + @Override + public CustomTaskMetric[] currentMetricsValues() { + return reader.currentMetricsValues(); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScan.java new file mode 100644 index 000000000..75ad8d1c7 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScan.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; + +public class LanceSearchScan implements Scan, Batch, Serializable { + private static final long serialVersionUID = -120398471239847123L; + + private final StructType schema; + private final LanceSearchQuery query; + + public LanceSearchScan(StructType schema, LanceSearchQuery query) { + this.schema = schema; + this.query = query; + } + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public String description() { + return "LanceSearchScan"; + } + + @Override + public Batch toBatch() { + return this; + } + + @Override + public InputPartition[] planInputPartitions() { + return new InputPartition[] {new LanceSearchInputPartition(schema, query)}; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new LanceSearchPartitionReaderFactory(); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScanBuilder.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScanBuilder.java new file mode 100644 index 000000000..f29260d83 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchScanBuilder.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; + +public class LanceSearchScanBuilder implements ScanBuilder { + private final StructType schema; + private final LanceSearchQuery query; + + public LanceSearchScanBuilder(StructType schema, LanceSearchQuery query) { + this.schema = schema; + this.query = query; + } + + @Override + public Scan build() { + return new LanceSearchScan(schema, query); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchTable.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchTable.java new file mode 100644 index 000000000..c34a65165 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/search/LanceSearchTable.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import java.util.Collections; +import java.util.Set; + +public class LanceSearchTable implements SupportsRead { + private final String name; + private final StructType schema; + private final LanceSearchQuery query; + + public LanceSearchTable(String name, StructType schema, LanceSearchQuery query) { + this.name = name; + this.schema = schema; + this.query = query; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new LanceSearchScanBuilder(schema, query); + } + + @Override + public String name() { + return name; + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Set capabilities() { + return Collections.singleton(TableCapability.BATCH_READ); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/QueryUtils.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/QueryUtils.java deleted file mode 100644 index abf19ad1c..000000000 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/QueryUtils.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.lance.spark.utils; - -import org.lance.ipc.Query; - -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; -import com.fasterxml.jackson.databind.module.SimpleModule; - -import java.io.IOException; -import java.util.Optional; - -public class QueryUtils { - private static final ObjectMapper MAPPER = new ObjectMapper(); - - static { - SimpleModule module = new SimpleModule(); - module.addSerializer(Query.class, new QuerySerializer()); - MAPPER.registerModule(module); - MAPPER.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - MAPPER.addMixIn(Query.class, QueryMixin.class); - MAPPER.addMixIn(Query.Builder.class, QueryBuilderMixin.class); - } - - private QueryUtils() {} - - public static String queryToString(Query query) { - if (query == null) { - return null; - } - try { - return MAPPER.writeValueAsString(query); - } catch (JsonProcessingException e) { - throw new RuntimeException("Failed to serialize query", e); - } - } - - public static Query stringToQuery(String json) { - if (json == null) { - return null; - } - try { - return MAPPER.readValue(json, Query.class); - } catch (JsonProcessingException e) { - throw new RuntimeException("Failed to deserialize query", e); - } - } - - private static class QuerySerializer extends JsonSerializer { - @Override - public void serialize(Query value, JsonGenerator gen, SerializerProvider serializers) - throws IOException { - gen.writeStartObject(); - if (value.getColumn() != null) { - gen.writeStringField("column", value.getColumn()); - } - gen.writeNumberField("k", value.getK()); - if (value.getKey() != null) { - gen.writeFieldName("key"); - float[] key = value.getKey(); - gen.writeStartArray(); - for (float f : key) { - gen.writeNumber(f); - } - gen.writeEndArray(); - } - gen.writeNumberField("minimumNprobes", value.getMinimumNprobes()); - - writeOptional(gen, "maximumNprobes", value.getMaximumNprobes()); - writeOptional(gen, "ef", value.getEf()); - writeOptional(gen, "refineFactor", value.getRefineFactor()); - - gen.writeBooleanField("useIndex", value.isUseIndex()); - writeOptional(gen, "distanceType", value.getDistanceType()); - gen.writeEndObject(); - } - - private void writeOptional(JsonGenerator gen, String fieldName, Optional opt) - throws IOException { - if (opt != null && opt.isPresent()) { - gen.writeObjectField(fieldName, opt.get()); - } - } - } - - @JsonDeserialize(builder = Query.Builder.class) - private abstract static class QueryMixin {} - - @JsonPOJOBuilder(withPrefix = "set") - private abstract static class QueryBuilderMixin {} -} diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/search/LanceSearchTableFunctions.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/search/LanceSearchTableFunctions.scala new file mode 100644 index 000000000..ce117c8eb --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/search/LanceSearchTableFunctions.scala @@ -0,0 +1,564 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.lance.spark.LanceDataset +import org.lance.spark.search.LanceSearchQuery.SearchType + +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +object LanceSearchTableFunctions { + private val DefaultK: Integer = Integer.valueOf(10) + private val TableArg = "table" + private val QueryVectorArg = "query_vector" + private val QueryArg = "query" + private val SearchQueryArg = "search_query" + private val ColumnsArg = "columns" + private val SearchColumnsArg = "search_columns" + private val DistanceMetricColumn = "_distance" + private val FtsScoreColumn = "_score" + private val HybridScoreColumn = "_relevance_score" + private val RowIdColumn = "_rowid" + private val NamedArgumentExpressionClass = + "org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression" + + def vectorSearch(args: Seq[Expression]): LogicalPlan = { + val parsed = parseArgs("VECTOR_SEARCH", args, Seq(TableArg, QueryVectorArg, "k")) + val tableName = requiredString(parsed, TableArg) + val queryVector = requiredFloatArray(parsed, QueryVectorArg) + val k = optionalInt(parsed, "num_results") + .orElse(optionalInt(parsed, "limit")) + .orElse(optionalInt(parsed, "k")) + .getOrElse(DefaultK) + val outputColumns = optionalStringArray(parsed, ColumnsArg).getOrElse(Seq.empty) + val filter = optionalString(parsed, "filter").orNull + val withRowId = effectiveWithRowId(parsed, outputColumns) + val offset = optionalInt(parsed, "offset") + val version = optionalLong(parsed, "version") + val requestK = offset + .map(value => Integer.valueOf(Math.addExact(k.intValue(), value.intValue()))) + .getOrElse(k) + val resolved = resolveLanceTable(tableName, version) + val schemaColumns = + requestOutputColumns(resolved.table.schema(), outputColumns, DistanceMetricColumn) + val requestColumns = namespaceOutputColumns(schemaColumns) + val schema = + outputSchema(resolved.table.schema(), schemaColumns, DistanceMetricColumn, withRowId) + + val builder = LanceSearchQuery + .builder(SearchType.VECTOR) + .tableId(resolved.table.readOptions().getTableId) + .namespaceImpl(resolved.table.getNamespaceImpl) + .namespaceProperties(resolved.table.getNamespaceProperties) + .outputColumns(requestColumns.asJava) + .vector(queryVector.asJava) + .topK(requestK) + .vectorColumn(optionalString(parsed, "vector_column").orNull) + .distanceType(optionalString(parsed, "distance_type").orNull) + .filter(filter) + .offset(offset.orNull) + .version(version.orNull) + .withRowId(withRowId.orNull) + .nprobes(optionalInt(parsed, "nprobes").orNull) + .ef(optionalInt(parsed, "ef").orNull) + .refineFactor(optionalInt(parsed, "refine_factor").orNull) + .lowerBound(optionalFloat(parsed, "lower_bound").orNull) + .upperBound(optionalFloat(parsed, "upper_bound").orNull) + .bypassVectorIndex(optionalBoolean(parsed, "bypass_vector_index").orNull) + .fastSearch(optionalBoolean(parsed, "fast_search").orNull) + .prefilter(optionalBoolean(parsed, "prefilter").orNull) + + relation("VECTOR_SEARCH", schema, builder.build(), resolved) + } + + def search(args: Seq[Expression]): LogicalPlan = { + val parsed = parseArgs("SEARCH", args, Seq(TableArg, QueryArg, "k")) + val tableName = requiredString(parsed, TableArg) + val queryText = optionalString(parsed, QueryArg) + .orElse(optionalString(parsed, SearchQueryArg)) + .getOrElse(throw new IllegalArgumentException("SEARCH requires query")) + val k = optionalInt(parsed, "num_results") + .orElse(optionalInt(parsed, "limit")) + .orElse(optionalInt(parsed, "k")) + .getOrElse(DefaultK) + val outputColumns = optionalStringArray(parsed, ColumnsArg).getOrElse(Seq.empty) + val filter = optionalString(parsed, "filter").orNull + val withRowId = effectiveWithRowId(parsed, outputColumns) + val version = optionalLong(parsed, "version") + val resolved = resolveLanceTable(tableName, version) + val schemaColumns = requestOutputColumns(resolved.table.schema(), outputColumns, FtsScoreColumn) + val requestColumns = namespaceOutputColumns(schemaColumns) + val schema = outputSchema(resolved.table.schema(), schemaColumns, FtsScoreColumn, withRowId) + + val builder = LanceSearchQuery + .builder(SearchType.FULL_TEXT) + .tableId(resolved.table.readOptions().getTableId) + .namespaceImpl(resolved.table.getNamespaceImpl) + .namespaceProperties(resolved.table.getNamespaceProperties) + .outputColumns(requestColumns.asJava) + .textQuery(queryText) + .searchColumns(optionalStringArray(parsed, SearchColumnsArg).getOrElse(Seq.empty).asJava) + .topK(k) + .filter(filter) + .offset(optionalInt(parsed, "offset").orNull) + .version(version.orNull) + .withRowId(withRowId.orNull) + + relation("SEARCH", schema, builder.build(), resolved) + } + + def hybridSearch(args: Seq[Expression]): LogicalPlan = { + val parsed = parseArgs("HYBRID_SEARCH", args, Seq(TableArg, QueryVectorArg, QueryArg, "k")) + val tableName = requiredString(parsed, TableArg) + val queryVector = requiredFloatArray(parsed, QueryVectorArg) + val queryText = optionalString(parsed, QueryArg) + .orElse(optionalString(parsed, SearchQueryArg)) + .getOrElse(throw new IllegalArgumentException("HYBRID_SEARCH requires query")) + val k = optionalInt(parsed, "num_results") + .orElse(optionalInt(parsed, "limit")) + .orElse(optionalInt(parsed, "k")) + .getOrElse(DefaultK) + val offset = optionalInt(parsed, "offset").getOrElse(Integer.valueOf(0)) + val candidateK = optionalInt(parsed, "candidates") + .orElse(optionalInt(parsed, "num_candidates")) + .orElse(optionalInt(parsed, "candidate_count")) + .map(value => Integer.valueOf(math.max(value.intValue(), k.intValue() + offset.intValue()))) + .getOrElse(Integer.valueOf(k.intValue() + offset.intValue())) + val outputColumns = optionalStringArray(parsed, ColumnsArg).getOrElse(Seq.empty) + val filter = optionalString(parsed, "filter").orNull + val withRowId = effectiveWithRowId(parsed, outputColumns) + val rrfK = optionalFloat(parsed, "rrf_k").getOrElse(java.lang.Float.valueOf(60.0f)) + val version = optionalLong(parsed, "version") + + val resolved = resolveLanceTable(tableName, version) + val vectorSchemaColumns = + hybridSideOutputColumns(resolved.table.schema(), outputColumns, DistanceMetricColumn) + val ftsSchemaColumns = + hybridSideOutputColumns(resolved.table.schema(), outputColumns, FtsScoreColumn) + val vectorSchema = + outputSchema( + resolved.table.schema(), + vectorSchemaColumns, + DistanceMetricColumn, + Some(java.lang.Boolean.TRUE)) + val ftsSchema = + outputSchema( + resolved.table.schema(), + ftsSchemaColumns, + FtsScoreColumn, + Some(java.lang.Boolean.TRUE)) + val schema = hybridOutputSchema(resolved.table.schema(), outputColumns, withRowId) + + val vectorRequestColumns = namespaceOutputColumns(vectorSchemaColumns) + val ftsRequestColumns = namespaceOutputColumns(ftsSchemaColumns) + + val vectorQuery = LanceSearchQuery + .builder(SearchType.VECTOR) + .tableId(resolved.table.readOptions().getTableId) + .namespaceImpl(resolved.table.getNamespaceImpl) + .namespaceProperties(resolved.table.getNamespaceProperties) + .outputColumns(vectorRequestColumns.asJava) + .vector(queryVector.asJava) + .topK(candidateK) + .vectorColumn(optionalString(parsed, "vector_column").orNull) + .distanceType(optionalString(parsed, "distance_type").orNull) + .filter(filter) + .version(version.orNull) + .withRowId(java.lang.Boolean.TRUE) + .nprobes(optionalInt(parsed, "nprobes").orNull) + .ef(optionalInt(parsed, "ef").orNull) + .refineFactor(optionalInt(parsed, "refine_factor").orNull) + .lowerBound(optionalFloat(parsed, "lower_bound").orNull) + .upperBound(optionalFloat(parsed, "upper_bound").orNull) + .bypassVectorIndex(optionalBoolean(parsed, "bypass_vector_index").orNull) + .fastSearch(optionalBoolean(parsed, "fast_search").orNull) + .prefilter(optionalBoolean(parsed, "prefilter").orNull) + .build() + + val ftsQuery = LanceSearchQuery + .builder(SearchType.FULL_TEXT) + .tableId(resolved.table.readOptions().getTableId) + .namespaceImpl(resolved.table.getNamespaceImpl) + .namespaceProperties(resolved.table.getNamespaceProperties) + .outputColumns(ftsRequestColumns.asJava) + .textQuery(queryText) + .searchColumns(optionalStringArray(parsed, SearchColumnsArg).getOrElse(Seq.empty).asJava) + .topK(candidateK) + .filter(filter) + .version(version.orNull) + .withRowId(java.lang.Boolean.TRUE) + .build() + + val hybridQuery = + new LanceHybridSearchQuery(vectorQuery, ftsQuery, vectorSchema, ftsSchema, k, offset, rrfK) + hybridRelation("HYBRID_SEARCH", schema, hybridQuery, resolved) + } + + private def relation( + functionName: String, + schema: StructType, + query: LanceSearchQuery, + resolved: ResolvedLanceTable): LogicalPlan = { + val table = new LanceSearchTable(functionName, schema, query) + DataSourceV2Relation.create( + table, + Some(resolved.catalog), + Some(resolved.identifier), + CaseInsensitiveStringMap.empty()) + } + + private def hybridRelation( + functionName: String, + schema: StructType, + query: LanceHybridSearchQuery, + resolved: ResolvedLanceTable): LogicalPlan = { + val table = new LanceHybridSearchTable(functionName, schema, query) + DataSourceV2Relation.create( + table, + Some(resolved.catalog), + Some(resolved.identifier), + CaseInsensitiveStringMap.empty()) + } + + private def parseArgs( + functionName: String, + args: Seq[Expression], + positionalNames: Seq[String]): ParsedArgs = { + val named = scala.collection.mutable.LinkedHashMap.empty[String, Expression] + val positional = ArrayBuffer.empty[Expression] + + args.foreach { expr => + namedArgument(expr) match { + case Some((key, value)) => + named.put(normalizeName(key), value) + case None => + positional += expr + } + } + + if (named.nonEmpty && positional.nonEmpty) { + throw new IllegalArgumentException( + s"$functionName does not support mixing named and positional arguments") + } + if (named.nonEmpty) { + ParsedArgs(named.toMap) + } else { + if (positional.size > positionalNames.size) { + throw new IllegalArgumentException(s"$functionName received too many positional arguments") + } + ParsedArgs(positionalNames.zip(positional).map { case (name, expr) => name -> expr }.toMap) + } + } + + private def requiredString(parsed: ParsedArgs, name: String): String = + optionalString(parsed, name).getOrElse(throw new IllegalArgumentException(s"$name is required")) + + private def requiredFloatArray(parsed: ParsedArgs, name: String): Seq[java.lang.Float] = + optionalFloatArray(parsed, name).getOrElse( + throw new IllegalArgumentException(s"$name is required")) + + private def optionalString(parsed: ParsedArgs, name: String): Option[String] = + parsed.get(name).map { + case attr: UnresolvedAttribute => attr.name + case expr => literalValue(expr) match { + case value: org.apache.spark.unsafe.types.UTF8String => value.toString + case value: String => value + case null => null + case other => other.toString + } + } + + private def optionalStringArray(parsed: ParsedArgs, name: String): Option[Seq[String]] = + parsed.get(name).map(expr => + literalArray(expr).map { + case value: org.apache.spark.unsafe.types.UTF8String => value.toString + case value: String => value + case other => other.toString + }) + + private def optionalFloatArray(parsed: ParsedArgs, name: String): Option[Seq[java.lang.Float]] = + parsed.get(name).map(expr => literalArray(expr).map(toFloat)) + + private def optionalInt(parsed: ParsedArgs, name: String): Option[Integer] = + parsed.get(name).map(value => Integer.valueOf(toNumber(literalValue(value)).intValue())) + + private def optionalLong(parsed: ParsedArgs, name: String): Option[java.lang.Long] = + parsed.get(name).map(value => java.lang.Long.valueOf(toNumber(literalValue(value)).longValue())) + + private def optionalFloat(parsed: ParsedArgs, name: String): Option[java.lang.Float] = + parsed.get(name).map(value => toFloat(literalValue(value))) + + private def optionalBoolean(parsed: ParsedArgs, name: String): Option[java.lang.Boolean] = + parsed.get(name).map(value => + java.lang.Boolean.valueOf(literalValue(value).asInstanceOf[Boolean])) + + private def literalArray(expr: Expression): Seq[Any] = expr match { + case array: CreateArray => array.children.map(literalValue) + case literal: Literal if literal.value == null => Seq.empty + case other => + literalValue(other) match { + case values: Seq[_] => values + case values: Array[_] => values.toSeq + case value => + throw new IllegalArgumentException(s"Expected array literal, got $value") + } + } + + private def literalValue(expr: Expression): Any = expr match { + case literal: Literal => literal.value + case array: CreateArray => array.children.map(literalValue) + case other if other.foldable => other.eval(InternalRow.empty) + case other => + throw new IllegalArgumentException(s"Argument must be a foldable literal: ${other.sql}") + } + + private def toNumber(value: Any): Number = value match { + case number: Number => number + case decimal: org.apache.spark.sql.types.Decimal => decimal.toJavaBigDecimal + case decimal: scala.math.BigDecimal => decimal.bigDecimal + case other => throw new IllegalArgumentException(s"Expected numeric literal, got $other") + } + + private def toFloat(value: Any): java.lang.Float = + java.lang.Float.valueOf(toNumber(value).floatValue()) + + private def namedArgument(expr: Expression): Option[(String, Expression)] = { + if (expr.getClass.getName != NamedArgumentExpressionClass) { + None + } else { + val key = expr.getClass.getMethod("key").invoke(expr).asInstanceOf[String] + val value = expr.getClass.getMethod("value").invoke(expr).asInstanceOf[Expression] + Some((key, value)) + } + } + + private def normalizeName(name: String): String = + name.toLowerCase(Locale.ROOT) + + private def normalizeOutputColumns(columns: Seq[String]): Seq[String] = + if (columns.exists(_ == "*")) { + Seq.empty + } else { + columns + } + + private def requestOutputColumns( + baseSchema: StructType, + columns: Seq[String], + metricName: String): Seq[String] = { + val normalizedColumns = normalizeOutputColumns(columns) + if (normalizedColumns.isEmpty) { + normalizedColumns + } else { + val resolvedColumns = normalizedColumns.map { column => + if (column.equalsIgnoreCase(metricName)) { + metricName + } else if (column.equalsIgnoreCase(RowIdColumn)) { + RowIdColumn + } else { + findField(baseSchema, column).name + } + } + if (resolvedColumns.exists(_.equalsIgnoreCase(metricName))) { + resolvedColumns + } else { + resolvedColumns :+ metricName + } + } + } + + private def hybridSideOutputColumns( + baseSchema: StructType, + columns: Seq[String], + metricName: String): Seq[String] = { + val normalizedColumns = normalizeOutputColumns(columns) + if (normalizedColumns.isEmpty) { + normalizedColumns + } else { + val resolvedColumns = normalizedColumns.flatMap { column => + if (isHybridMetricColumn(column) || column.equalsIgnoreCase(RowIdColumn)) { + None + } else { + Some(findField(baseSchema, column).name) + } + } + if (resolvedColumns.exists(_.equalsIgnoreCase(metricName))) { + resolvedColumns + } else { + resolvedColumns :+ metricName + } + } + } + + private def namespaceOutputColumns(columns: Seq[String]): Seq[String] = + normalizeOutputColumns(columns).filterNot(_.equalsIgnoreCase(RowIdColumn)) + + private def effectiveWithRowId( + parsed: ParsedArgs, + outputColumns: Seq[String]): Option[java.lang.Boolean] = + if (normalizeOutputColumns(outputColumns).exists(_.equalsIgnoreCase(RowIdColumn))) { + Some(java.lang.Boolean.TRUE) + } else { + optionalBoolean(parsed, "with_row_id") + } + + private def outputSchema( + baseSchema: StructType, + columns: Seq[String], + metricName: String, + withRowId: Option[java.lang.Boolean]): StructType = { + val normalizedColumns = normalizeOutputColumns(columns) + val fields = + if (normalizedColumns.isEmpty) { + baseSchema.fields.toSeq + } else { + normalizedColumns.map { column => + if (column.equalsIgnoreCase(metricName)) { + StructField(metricName, DataTypes.FloatType, nullable = true) + } else if (column.equalsIgnoreCase(RowIdColumn)) { + StructField(RowIdColumn, DataTypes.LongType, nullable = true) + } else { + findField(baseSchema, column) + } + } + } + val result = ArrayBuffer(fields: _*) + if (!result.exists(_.name.equalsIgnoreCase(metricName))) { + result += StructField(metricName, DataTypes.FloatType, nullable = true) + } + if (withRowId.contains(java.lang.Boolean.TRUE) && + !result.exists(_.name.equalsIgnoreCase(RowIdColumn))) { + result += StructField(RowIdColumn, DataTypes.LongType, nullable = true) + } + new StructType(result.toArray) + } + + private def hybridOutputSchema( + baseSchema: StructType, + columns: Seq[String], + withRowId: Option[java.lang.Boolean]): StructType = { + val normalizedColumns = normalizeOutputColumns(columns) + val fields = + if (normalizedColumns.isEmpty) { + baseSchema.fields.toSeq + } else { + normalizedColumns.map { column => + if (column.equalsIgnoreCase(DistanceMetricColumn)) { + StructField(DistanceMetricColumn, DataTypes.FloatType, nullable = true) + } else if (column.equalsIgnoreCase(FtsScoreColumn)) { + StructField(FtsScoreColumn, DataTypes.FloatType, nullable = true) + } else if (column.equalsIgnoreCase(HybridScoreColumn)) { + StructField(HybridScoreColumn, DataTypes.FloatType, nullable = false) + } else if (column.equalsIgnoreCase(RowIdColumn)) { + StructField(RowIdColumn, DataTypes.LongType, nullable = true) + } else { + findField(baseSchema, column) + } + } + } + val result = ArrayBuffer(fields: _*) + if (!result.exists(_.name.equalsIgnoreCase(DistanceMetricColumn))) { + result += StructField(DistanceMetricColumn, DataTypes.FloatType, nullable = true) + } + if (!result.exists(_.name.equalsIgnoreCase(FtsScoreColumn))) { + result += StructField(FtsScoreColumn, DataTypes.FloatType, nullable = true) + } + if (!result.exists(_.name.equalsIgnoreCase(HybridScoreColumn))) { + result += StructField(HybridScoreColumn, DataTypes.FloatType, nullable = false) + } + if (withRowId.contains(java.lang.Boolean.TRUE) && + !result.exists(_.name.equalsIgnoreCase(RowIdColumn))) { + result += StructField(RowIdColumn, DataTypes.LongType, nullable = true) + } + new StructType(result.toArray) + } + + private def isHybridMetricColumn(column: String): Boolean = + column.equalsIgnoreCase(DistanceMetricColumn) || + column.equalsIgnoreCase(FtsScoreColumn) || + column.equalsIgnoreCase(HybridScoreColumn) + + private def findField(schema: StructType, column: String): StructField = + schema.fields + .find(_.name.equalsIgnoreCase(column)) + .getOrElse(throw new IllegalArgumentException(s"Unknown column '$column'")) + + private def resolveLanceTable( + tableName: String, + version: Option[java.lang.Long]): ResolvedLanceTable = { + val spark = SparkSession.active + val parts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName).toArray + if (parts.isEmpty) { + throw new IllegalArgumentException("table is required") + } + val catalogManager = spark.sessionState.catalogManager + val (catalog, identifier) = + if (parts.length >= 2 && isConfiguredCatalog(spark, parts.head)) { + val catalog = catalogManager.catalog(parts.head) + val identParts = parts.tail + (catalog, Identifier.of(identParts.dropRight(1), identParts.last)) + } else { + val namespace = + if (parts.length > 1) parts.dropRight(1) else catalogManager.currentNamespace + (catalogManager.currentCatalog, Identifier.of(namespace, parts.last)) + } + + val tableCatalog = catalog match { + case tableCatalog: TableCatalog => tableCatalog + case other => + throw new IllegalArgumentException(s"Catalog '${other.name()}' is not a table catalog") + } + + val table = + version + .map(value => tableCatalog.loadTable(identifier, value.toString)) + .getOrElse(tableCatalog.loadTable(identifier)) + table match { + case table: LanceDataset => + if (table.getNamespaceImpl == null || table.readOptions().getTableId == null) { + throw new IllegalArgumentException( + "Lance search table functions require namespace-backed Lance tables") + } + ResolvedLanceTable(tableCatalog, identifier, table) + case other => + throw new IllegalArgumentException( + s"Table '$tableName' is not a Lance table: ${other.getClass.getName}") + } + } + + private def isConfiguredCatalog(spark: SparkSession, name: String): Boolean = + spark.sessionState.catalogManager.isCatalogRegistered(name) || + spark.conf.getOption(s"spark.sql.catalog.$name").isDefined + + private case class ParsedArgs(args: Map[String, Expression]) { + def get(name: String): Option[Expression] = args.get(normalizeName(name)) + } + + private case class ResolvedLanceTable( + catalog: TableCatalog, + identifier: Identifier, + table: LanceDataset) +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseFloat16VectorTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseFloat16VectorTest.java index 2dfd14c70..8d714579e 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseFloat16VectorTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseFloat16VectorTest.java @@ -13,10 +13,7 @@ */ package org.lance.spark; -import org.lance.index.DistanceType; -import org.lance.ipc.Query; import org.lance.spark.utils.Float16Utils; -import org.lance.spark.utils.QueryUtils; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -57,6 +54,8 @@ void setup() { "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") .config("spark.sql.catalog." + catalogName + ".impl", "dir") .config("spark.sql.catalog." + catalogName + "." + "root", tempDir.toString()) + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") .getOrCreate(); spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default"); } @@ -519,52 +518,37 @@ public void testFloat16VectorSearchKnn() { Assumptions.assumeTrue( Float16Utils.isFloat2VectorAvailable(), "Float16 requires Arrow 18+ (Spark 4.0+)"); - // Create a float16 vector dataset via DataSource API so the URI is deterministic. - String datasetUri = tempDir.toString() + "/float16_knn_dataset"; - - // Build rows with known vectors so we can predict KNN results. - // Vector 0: [0, 0, 0, 0] — closest to query [0, 0, 0, 0] - // Vector 1: [1, 1, 1, 1] — L2 distance = 4 - // Vector 2: [10, 10, 10, 10] — L2 distance = 400 - // Vector 3: [100, 100, 100, 100] — L2 distance = 40000 - List writeRows = new ArrayList<>(); - writeRows.add(RowFactory.create(0, new float[] {0.0f, 0.0f, 0.0f, 0.0f})); - writeRows.add(RowFactory.create(1, new float[] {1.0f, 1.0f, 1.0f, 1.0f})); - writeRows.add(RowFactory.create(2, new float[] {10.0f, 10.0f, 10.0f, 10.0f})); - writeRows.add(RowFactory.create(3, new float[] {100.0f, 100.0f, 100.0f, 100.0f})); - - Metadata vecMetadata = - new MetadataBuilder() - .putLong("arrow.fixed-size-list.size", 4) - .putString("arrow.float16", "true") - .build(); - StructType schema = - new StructType( - new StructField[] { - DataTypes.createStructField("id", DataTypes.IntegerType, false), - DataTypes.createStructField( - "vec", DataTypes.createArrayType(DataTypes.FloatType, false), false, vecMetadata) - }); - - Dataset df = spark.createDataFrame(writeRows, schema); - df.write().format(LanceDataSource.name).save(datasetUri); - - // Build KNN query: find 2 nearest neighbors to [0, 0, 0, 0] - Query.Builder builder = new Query.Builder(); - builder.setK(2); - builder.setColumn("vec"); - builder.setKey(new float[] {0.0f, 0.0f, 0.0f, 0.0f}); - builder.setUseIndex(false); // brute-force scan (no index needed) - builder.setDistanceType(DistanceType.L2); + String tableName = "float16_knn_" + System.currentTimeMillis(); + String fullName = catalogName + ".default." + tableName; + spark.sql( + "CREATE TABLE " + + fullName + + " (id INT NOT NULL, vec ARRAY NOT NULL) USING lance " + + "TBLPROPERTIES (" + + "'vec.arrow.fixed-size-list.size' = '4', " + + "'vec.arrow.float16' = 'true'" + + ")"); + spark.sql( + "INSERT INTO " + + fullName + + " VALUES " + + "(0, array(0.0, 0.0, 0.0, 0.0)), " + + "(1, array(1.0, 1.0, 1.0, 1.0)), " + + "(2, array(10.0, 10.0, 10.0, 10.0)), " + + "(3, array(100.0, 100.0, 100.0, 100.0))"); - // Read via DataSource API with vector search Dataset result = - spark - .read() - .format(LanceDataSource.name) - .option(LanceSparkReadOptions.CONFIG_NEAREST, QueryUtils.queryToString(builder.build())) - .option(LanceSparkReadOptions.CONFIG_DATASET_URI, datasetUri) - .load(); + spark.sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + fullName + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "vector_column => 'vec', " + + "num_results => 2, " + + "distance_type => 'l2', " + + "bypass_vector_index => true, " + + "columns => array('id', 'vec'))"); List rows = result.collectAsList(); // K=2, so we expect 2 results diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsJsonTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsJsonTest.java deleted file mode 100644 index 83cc6ebcc..000000000 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsJsonTest.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.lance.spark; - -import org.lance.index.DistanceType; -import org.lance.ipc.Query; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.HashMap; -import java.util.Map; - -public class LanceSparkReadOptionsJsonTest { - - @Test - public void testNearestJsonSerialization() { - Query.Builder builder = new Query.Builder(); - builder.setK(10); - builder.setColumn("vector_col"); - builder.setRefineFactor(2); - builder.setKey(new float[] {1.0f, 2.0f, 3.0f}); - builder.setMinimumNprobes(5); - builder.setMaximumNprobes(20); - builder.setEf(100); - builder.setDistanceType(DistanceType.L2); - builder.setUseIndex(true); - - Query query = builder.build(); - - LanceSparkReadOptions options = - LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").nearest(query).build(); - - String json = options.getNearestJson(); - Assertions.assertNotNull(json); - System.out.println("Serialized JSON: " + json); - - // Test deserialization via fromOptions - Map properties = new HashMap<>(); - properties.put("path", "s3://bucket/path"); - properties.put("nearest", json); - - LanceSparkReadOptions deserializedOptions = LanceSparkReadOptions.from(properties); - Query deserializedQuery = deserializedOptions.getNearest(); - - Assertions.assertNotNull(deserializedQuery); - Assertions.assertEquals(query.getK(), deserializedQuery.getK()); - Assertions.assertEquals(query.getColumn(), deserializedQuery.getColumn()); - - // Check RefineFactor (Optional) - Assertions.assertTrue(deserializedQuery.getRefineFactor().isPresent()); - Assertions.assertEquals(Integer.valueOf(2), deserializedQuery.getRefineFactor().get()); - - Assertions.assertArrayEquals(query.getKey(), deserializedQuery.getKey()); - - // Check new fields - Assertions.assertEquals(query.getMinimumNprobes(), deserializedQuery.getMinimumNprobes()); - - Assertions.assertTrue(deserializedQuery.getMaximumNprobes().isPresent()); - Assertions.assertEquals(Integer.valueOf(20), deserializedQuery.getMaximumNprobes().get()); - - Assertions.assertTrue(deserializedQuery.getEf().isPresent()); - Assertions.assertEquals(Integer.valueOf(100), deserializedQuery.getEf().get()); - - Assertions.assertTrue(deserializedQuery.getDistanceType().isPresent()); - Assertions.assertEquals(DistanceType.L2, deserializedQuery.getDistanceType().get()); - - Assertions.assertEquals(query.isUseIndex(), deserializedQuery.isUseIndex()); - } - - @Test - public void testNearestJsonStringInput() { - // We use "set" prefix in Mixin configuration, but Jackson maps "k" to "setK". - // The JSON should use property names "k", "key", "column". - String json = "{\"column\":\"vector_col\",\"k\":10,\"refineFactor\":2,\"key\":[1.0,2.0,3.0]}"; - - LanceSparkReadOptions options = - LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").nearest(json).build(); - - Query query = options.getNearest(); - Assertions.assertNotNull(query); - Assertions.assertEquals(10, query.getK()); - Assertions.assertEquals("vector_col", query.getColumn()); - Assertions.assertArrayEquals(new float[] {1.0f, 2.0f, 3.0f}, query.getKey()); - } -} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java index 6567db776..5ab77f1f7 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java @@ -13,8 +13,6 @@ */ package org.lance.spark; -import org.lance.ipc.Query; - import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -29,87 +27,6 @@ public class LanceSparkReadOptionsSerializationTest { - @Test - public void testJavaSerialization() throws IOException, ClassNotFoundException { - String json = "{\"column\":\"vector_col\",\"k\":10,\"key\":[1.0,2.0,3.0]}"; - - LanceSparkReadOptions options = - LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").nearest(json).build(); - - Query originalQuery = options.getNearest(); - Assertions.assertNotNull(originalQuery); - - // Serialize - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(options); - oos.close(); - - // Deserialize - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - ObjectInputStream ois = new ObjectInputStream(bais); - LanceSparkReadOptions deserializedOptions = (LanceSparkReadOptions) ois.readObject(); - - Query deserializedQuery = deserializedOptions.getNearest(); - - Assertions.assertNotNull( - deserializedQuery, "Nearest query should not be null after deserialization"); - Assertions.assertEquals(originalQuery.getK(), deserializedQuery.getK()); - Assertions.assertEquals(originalQuery.getColumn(), deserializedQuery.getColumn()); - Assertions.assertArrayEquals(originalQuery.getKey(), deserializedQuery.getKey()); - } - - @Test - public void testUseIndexSerialization() throws IOException, ClassNotFoundException { - // Case 1: useIndex is explicitly set to false - String jsonFalse = - "{\"column\":\"vector_col\",\"k\":10,\"key\":[1.0,2.0,3.0],\"useIndex\":false}"; - LanceSparkReadOptions optionsFalse = - LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").nearest(jsonFalse).build(); - - Query queryFalse = optionsFalse.getNearest(); - Assertions.assertFalse(queryFalse.isUseIndex()); - - // Serialize - ByteArrayOutputStream baosFalse = new ByteArrayOutputStream(); - ObjectOutputStream oosFalse = new ObjectOutputStream(baosFalse); - oosFalse.writeObject(optionsFalse); - oosFalse.close(); - - // Deserialize - ByteArrayInputStream baisFalse = new ByteArrayInputStream(baosFalse.toByteArray()); - ObjectInputStream oisFalse = new ObjectInputStream(baisFalse); - LanceSparkReadOptions deserializedOptionsFalse = (LanceSparkReadOptions) oisFalse.readObject(); - - Assertions.assertFalse( - deserializedOptionsFalse.getNearest().isUseIndex(), - "useIndex should remain false after serialization/deserialization"); - - // Case 2: useIndex is explicitly set to true - String jsonTrue = - "{\"column\":\"vector_col\",\"k\":10,\"key\":[1.0,2.0,3.0],\"useIndex\":true}"; - LanceSparkReadOptions optionsTrue = - LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").nearest(jsonTrue).build(); - - Query queryTrue = optionsTrue.getNearest(); - Assertions.assertTrue(queryTrue.isUseIndex()); - - // Serialize - ByteArrayOutputStream baosTrue = new ByteArrayOutputStream(); - ObjectOutputStream oosTrue = new ObjectOutputStream(baosTrue); - oosTrue.writeObject(optionsTrue); - oosTrue.close(); - - // Deserialize - ByteArrayInputStream baisTrue = new ByteArrayInputStream(baosTrue.toByteArray()); - ObjectInputStream oisTrue = new ObjectInputStream(baisTrue); - LanceSparkReadOptions deserializedOptionsTrue = (LanceSparkReadOptions) oisTrue.readObject(); - - Assertions.assertTrue( - deserializedOptionsTrue.getNearest().isUseIndex(), - "useIndex should remain true after serialization/deserialization"); - } - @Test public void testExecutorCredentialRefreshDefaultsToTrue() { LanceSparkReadOptions options = @@ -136,6 +53,20 @@ public void testExecutorCredentialRefreshParsedFromOptions() { Assertions.assertTrue(optionsTrue.isExecutorCredentialRefresh()); } + @Test + public void testDeprecatedNearestReadOptionFailsFast() { + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + LanceSparkReadOptions.from( + Collections.singletonMap("nearest", "{\"column\":\"vector\"}"), + "s3://bucket/path")); + + Assertions.assertTrue(exception.getMessage().contains("nearest")); + Assertions.assertTrue(exception.getMessage().contains("VECTOR_SEARCH")); + } + @Test public void testExecutorCredentialRefreshSurvivesSerialization() throws IOException, ClassNotFoundException { @@ -175,11 +106,6 @@ public void testExecutorCredentialRefreshPreservedByWithVersion() { "withVersion() must propagate the executor_credential_refresh flag"); } - /** - * Catalog-level config (set via {@code --conf spark.sql.catalog..}) is the only route - * available to SQL DML (DELETE / UPDATE / MERGE INTO), which has no per-statement {@code - * .option(...)} attach point. This test guards the catalog-conf path. - */ @Test public void testExecutorCredentialRefreshFromCatalogDefaults() { Map catalogOpts = new HashMap<>(); @@ -198,20 +124,12 @@ public void testExecutorCredentialRefreshFromCatalogDefaults() { + "so it takes effect for SELECT without .option(...) and for SQL DML"); } - /** - * Spark's scan-time options (via {@code spark.read.option(...)}) go through a second {@code - * fromOptions(mergedMap)} rebuild in {@code LanceDataset.newScanBuilder}. Per-read settings must - * win over catalog-level defaults. - */ @Test public void testPerReadOptionOverridesCatalogDefaults() { Map catalogOpts = new HashMap<>(); catalogOpts.put(LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "false"); LanceSparkCatalogConfig catalogConfig = LanceSparkCatalogConfig.from(catalogOpts); - // Simulate the rebuild path in LanceDataset.newScanBuilder: the builder starts by applying - // the catalog defaults, then fromOptions() replays against the merged (catalog + per-read) - // map where the per-read value wins. Map merged = new HashMap<>(catalogConfig.getStorageOptions()); merged.put(LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "true"); diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java deleted file mode 100644 index 6c116ec6c..000000000 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/BaseSparkConnectorReadWithVectorSearchTest.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.lance.spark.read; - -import org.lance.index.DistanceType; -import org.lance.ipc.Query; -import org.lance.spark.LanceDataSource; -import org.lance.spark.LanceSparkReadOptions; -import org.lance.spark.TestUtils; -import org.lance.spark.utils.QueryUtils; - -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import static org.junit.jupiter.api.Assertions.*; - -/* - *The test logic is same with org.lance.VectorSearchTest.test_knn - */ - -public abstract class BaseSparkConnectorReadWithVectorSearchTest { - private static SparkSession spark; - private static String dbPath; - private static Dataset data; - - @BeforeAll - static void setup() { - - Query.Builder builder = new Query.Builder(); - float[] key = new float[32]; - for (int i = 0; i < 32; i++) { - key[i] = (float) (i + 32); - } - builder.setK(1); - builder.setColumn("vec"); - builder.setKey(key); - builder.setUseIndex(true); - builder.setDistanceType(DistanceType.L2); - - spark = - SparkSession.builder() - .appName("spark-lance-connector-test") - .master("local") - .config("spark.sql.catalog.lance", "org.lance.spark.LanceNamespaceSparkCatalog") - .getOrCreate(); - dbPath = TestUtils.TestTable1Config.dbPath; - data = - spark - .read() - .format(LanceDataSource.name) - .option(LanceSparkReadOptions.CONFIG_NEAREST, QueryUtils.queryToString(builder.build())) - .option( - LanceSparkReadOptions.CONFIG_DATASET_URI, - TestUtils.getDatasetUri(dbPath, "test_dataset5")) - .load(); - data.createOrReplaceTempView("test_dataset5"); - } - - @AfterAll - static void tearDown() { - if (spark != null) { - spark.stop(); - } - } - - @Test - public void validateData() { - Set expectedI = new HashSet<>(Arrays.asList(1, 81, 161, 241, 321)); - Set actualI = new HashSet<>(); - List rows = data.collectAsList(); - for (int i = 0; i < rows.size(); i++) { - actualI.add(rows.get(i).getInt(0)); - } - assertEquals(expectedI, actualI, "Unexpected values in 'i' column"); - } -} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchRestNamespaceSmokeTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchRestNamespaceSmokeTest.java new file mode 100644 index 000000000..f9de549b2 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchRestNamespaceSmokeTest.java @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class BaseSparkSearchRestNamespaceSmokeTest { + private static final String CATALOG_NAME = "lance_rest_search"; + private SparkSession spark; + + @BeforeEach + void setup() { + String uri = System.getenv("LANCE_SPARK_REST_URI"); + String apiKey = System.getenv("LANCE_SPARK_REST_API_KEY"); + String database = System.getenv("LANCE_SPARK_REST_DATABASE"); + Assumptions.assumeTrue(uri != null && apiKey != null && database != null); + + spark = + SparkSession.builder() + .appName("lance-search-rest-namespace-smoke-test") + .master("local[2]") + .config( + "spark.sql.catalog." + CATALOG_NAME, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + CATALOG_NAME + ".impl", "rest") + .config("spark.sql.catalog." + CATALOG_NAME + ".uri", uri) + .config("spark.sql.catalog." + CATALOG_NAME + ".headers.x-api-key", apiKey) + .config("spark.sql.catalog." + CATALOG_NAME + ".headers.x-lancedb-database", database) + .getOrCreate(); + spark.sql("CREATE NAMESPACE IF NOT EXISTS " + CATALOG_NAME + ".default"); + } + + @AfterEach + void tearDown() throws IOException { + if (spark != null) { + spark.close(); + } + } + + @Test + public void testVectorAndFullTextSearchViaRestNamespace() { + String vectorTable = fullTableName("vector_search"); + spark.sql( + "CREATE TABLE " + + vectorTable + + " (id INT NOT NULL, vector ARRAY NOT NULL) USING lance " + + "TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4')"); + spark.sql( + "INSERT INTO " + + vectorTable + + " VALUES " + + "(0, array(0.0, 0.0, 0.0, 0.0)), " + + "(1, array(1.0, 1.0, 1.0, 1.0)), " + + "(2, array(10.0, 10.0, 10.0, 10.0))"); + + List vectorRows = + spark + .sql( + "SELECT id, _distance FROM VECTOR_SEARCH('" + + vectorTable + + "', array(0.0, 0.0, 0.0, 0.0), 2) ORDER BY _distance, id") + .collectAsList(); + assertEquals(2, vectorRows.size()); + assertEquals(0, vectorRows.get(0).getInt(0)); + assertEquals(1, vectorRows.get(1).getInt(0)); + + String ftsTable = fullTableName("fts_search"); + spark.sql("CREATE TABLE " + ftsTable + " (id INT NOT NULL, body STRING) USING lance"); + spark.sql( + "INSERT INTO " + + ftsTable + + " VALUES " + + "(1, 'lance vector search'), " + + "(2, 'spark connector table function'), " + + "(3, 'lance full text search')"); + spark.sql( + "ALTER TABLE " + + ftsTable + + " CREATE INDEX body_fts USING fts (body) WITH (" + + "base_tokenizer='simple', " + + "language='English', " + + "max_token_length=40, " + + "lower_case=true, " + + "stem=false, " + + "remove_stop_words=false, " + + "ascii_folding=false, " + + "with_position=true)"); + + List ftsRows = + spark + .sql("SELECT id, _score FROM SEARCH('" + ftsTable + "', 'lance', 10) ORDER BY id") + .collectAsList(); + List ids = ftsRows.stream().map(row -> row.getInt(0)).collect(Collectors.toList()); + assertEquals(java.util.Arrays.asList(1, 3), ids); + assertTrue(ftsRows.get(0).getFloat(1) > 0.0f); + + String hybridTable = fullTableName("hybrid_search"); + spark.sql( + "CREATE TABLE " + + hybridTable + + " (id INT NOT NULL, body STRING, vector ARRAY NOT NULL) USING lance " + + "TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4')"); + spark.sql( + "INSERT INTO " + + hybridTable + + " VALUES " + + "(1, 'lance vector search', array(0.0, 0.0, 0.0, 0.0)), " + + "(2, 'spark connector table function', array(1.0, 1.0, 1.0, 1.0)), " + + "(3, 'lance full text search', array(10.0, 10.0, 10.0, 10.0))"); + spark.sql( + "ALTER TABLE " + + hybridTable + + " CREATE INDEX body_fts USING fts (body) WITH (" + + "base_tokenizer='simple', " + + "language='English', " + + "max_token_length=40, " + + "lower_case=true, " + + "stem=false, " + + "remove_stop_words=false, " + + "ascii_folding=false, " + + "with_position=true)"); + + List hybridRows = + spark + .sql( + "SELECT id, _distance, _score, _relevance_score FROM HYBRID_SEARCH(" + + "table => '" + + hybridTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "query => 'lance', " + + "columns => array('id'), " + + "num_results => 3, " + + "candidates => 3, " + + "rrf_k => 1.0) " + + "ORDER BY _relevance_score DESC, id") + .collectAsList(); + List hybridIds = + hybridRows.stream().map(row -> row.getInt(0)).collect(Collectors.toList()); + assertEquals(java.util.Arrays.asList(1, 3, 2), hybridIds); + assertEquals(0.0f, hybridRows.get(0).getFloat(1), 0.001f); + assertTrue(hybridRows.get(0).getFloat(2) > 0.0f); + assertTrue(hybridRows.get(2).isNullAt(2)); + } + + private String fullTableName(String prefix) { + return CATALOG_NAME + + ".default." + + prefix + + "_" + + UUID.randomUUID().toString().replace("-", ""); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchTableFunctionTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchTableFunctionTest.java new file mode 100644 index 000000000..0cb09fc9f --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/search/BaseSparkSearchTableFunctionTest.java @@ -0,0 +1,464 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.search; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class BaseSparkSearchTableFunctionTest { + private static final String CATALOG_NAME = "lance_search"; + private SparkSession spark; + + @TempDir Path tempDir; + + @BeforeEach + void setup() { + spark = + SparkSession.builder() + .appName("lance-search-table-function-test") + .master("local[2]") + .config( + "spark.sql.catalog." + CATALOG_NAME, "org.lance.spark.LanceNamespaceSparkCatalog") + .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") + .config("spark.sql.catalog." + CATALOG_NAME + ".impl", "dir") + .config("spark.sql.catalog." + CATALOG_NAME + ".root", tempDir.toString()) + .getOrCreate(); + spark.sql("CREATE NAMESPACE " + CATALOG_NAME + ".default"); + } + + @AfterEach + void tearDown() throws IOException { + if (spark != null) { + spark.close(); + } + } + + @Test + public void testVectorSearchTableFunction() { + String fullName = createVectorTable(); + + Dataset result = + spark.sql( + "SELECT id, _distance FROM VECTOR_SEARCH('" + + fullName + + "', array(0.0, 0.0, 0.0, 0.0), 2) ORDER BY _distance, id"); + + List rows = result.collectAsList(); + assertEquals(2, rows.size()); + assertEquals(0, rows.get(0).getInt(0)); + assertEquals(1, rows.get(1).getInt(0)); + assertEquals(0.0f, rows.get(0).getFloat(1), 0.001f); + assertTrue(rows.get(1).getFloat(1) > rows.get(0).getFloat(1)); + } + + @Test + public void testSearchTableFunction() { + String fullName = createFtsTable(); + + Dataset result = + spark.sql( + "SELECT id, body, _score FROM SEARCH('" + fullName + "', 'lance', 10) ORDER BY id"); + + List rows = result.collectAsList(); + List ids = rows.stream().map(row -> row.getInt(0)).collect(Collectors.toList()); + assertEquals(java.util.Arrays.asList(1, 3), ids); + assertTrue(rows.get(0).getString(1).contains("lance")); + assertTrue(rows.get(0).getFloat(2) > 0.0f); + } + + @Test + public void testVectorSearchOffsetReturnsRequestedCount() { + Assumptions.assumeTrue(supportsNamedArguments()); + String fullName = createVectorTable(); + + List rows = + spark + .sql( + "SELECT id, _distance FROM VECTOR_SEARCH(" + + "table => '" + + fullName + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "columns => array('id'), " + + "num_results => 1, " + + "offset => 1) ORDER BY _distance, id") + .collectAsList(); + + assertEquals(1, rows.size()); + assertEquals(1, rows.get(0).getInt(0)); + } + + @Test + public void testVectorSearchVersionUsesVersionedSchema() { + Assumptions.assumeTrue(supportsNamedArguments()); + String fullName = createVectorTable(); + + spark.sql( + "CREATE TEMPORARY VIEW version_schema_view AS " + + "SELECT _rowaddr, _fragid, concat('extra_', id) AS extra FROM " + + fullName); + spark.sql("ALTER TABLE " + fullName + " ADD COLUMNS extra FROM version_schema_view"); + + Dataset result = + spark.sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + fullName + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "num_results => 1, " + + "version => 2)"); + + assertEquals( + java.util.Arrays.asList("id", "vector", "_distance"), + java.util.Arrays.asList(result.columns())); + List rows = result.collectAsList(); + assertEquals(1, rows.size()); + assertEquals(0, rows.get(0).getInt(0)); + } + + @Test + public void testHybridSearchTableFunction() { + String fullName = createHybridTable(); + + Dataset result = + spark.sql( + "SELECT id, body, _distance, _score, _relevance_score FROM HYBRID_SEARCH('" + + fullName + + "', array(0.0, 0.0, 0.0, 0.0), 'lance', 3) " + + "ORDER BY _relevance_score DESC, id"); + + List rows = result.collectAsList(); + List ids = rows.stream().map(row -> row.getInt(0)).collect(Collectors.toList()); + assertEquals(java.util.Arrays.asList(1, 3, 2), ids); + assertEquals(0.0f, rows.get(0).getFloat(2), 0.001f); + assertTrue(rows.get(0).getFloat(3) > 0.0f); + assertTrue(rows.get(0).getFloat(4) > rows.get(1).getFloat(4)); + assertTrue(rows.get(1).getFloat(3) > 0.0f); + assertTrue(rows.get(1).getFloat(2) > rows.get(0).getFloat(2)); + assertTrue(rows.get(2).isNullAt(3)); + assertTrue(rows.get(2).getFloat(2) > rows.get(0).getFloat(2)); + + Dataset defaultResult = + spark.sql( + "SELECT * FROM HYBRID_SEARCH('" + + fullName + + "', array(0.0, 0.0, 0.0, 0.0), 'lance', 1) " + + "ORDER BY _relevance_score DESC, id"); + assertEquals( + java.util.Arrays.asList("id", "body", "vector", "_distance", "_score", "_relevance_score"), + java.util.Arrays.asList(defaultResult.columns())); + Row defaultRow = defaultResult.collectAsList().get(0); + assertEquals(java.util.Arrays.asList(0.0f, 0.0f, 0.0f, 0.0f), defaultRow.getList(2)); + } + + @Test + public void testVectorSearchRequiresQueryVector() { + Assumptions.assumeTrue(supportsNamedArguments()); + String fullName = createVectorTable(); + + Exception exception = + assertThrows( + Exception.class, + () -> + spark + .sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + fullName + + "', " + + "vector_column => 'vec')") + .collectAsList()); + assertTrue(getDeepMessage(exception).contains("query_vector is required")); + } + + @Test + public void testNamedArguments() { + Assumptions.assumeTrue(supportsNamedArguments()); + String vectorTable = createVectorTable(); + String ftsTable = createFtsTable(); + String hybridTable = createHybridTable(); + + List vectorRows = + spark + .sql( + "SELECT id, _distance FROM VECTOR_SEARCH(" + + "table => '" + + vectorTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "vector_column => 'vector', " + + "num_results => 2, " + + "distance_type => 'l2', " + + "bypass_vector_index => true, " + + "columns => array('id')) ORDER BY _distance, id") + .collectAsList(); + assertEquals(2, vectorRows.size()); + assertEquals(0, vectorRows.get(0).getInt(0)); + + List metricOnlyRows = + spark + .sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + vectorTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "columns => array('_distance'), " + + "filter => 'id >= 0', " + + "num_results => 1)") + .collectAsList(); + assertEquals(1, metricOnlyRows.size()); + assertEquals(1, metricOnlyRows.get(0).size()); + assertEquals(0.0f, metricOnlyRows.get(0).getFloat(0), 0.001f); + + Dataset vectorRowIdOnly = + spark.sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + vectorTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "columns => array('_rowid'), " + + "num_results => 1)"); + assertEquals( + java.util.Arrays.asList("_rowid", "_distance"), + java.util.Arrays.asList(vectorRowIdOnly.columns())); + Row vectorRowIdOnlyRow = vectorRowIdOnly.collectAsList().get(0); + assertTrue(vectorRowIdOnlyRow.getLong(0) >= 0); + assertEquals(0.0f, vectorRowIdOnlyRow.getFloat(1), 0.001f); + + Dataset vectorWithRowId = + spark.sql( + "SELECT * FROM VECTOR_SEARCH(" + + "table => '" + + vectorTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "columns => array('ID'), " + + "with_row_id => true, " + + "num_results => 1)"); + assertEquals( + java.util.Arrays.asList("id", "_distance", "_rowid"), + java.util.Arrays.asList(vectorWithRowId.columns())); + Row vectorRowId = vectorWithRowId.collectAsList().get(0); + assertEquals(0, vectorRowId.getInt(0)); + assertEquals(0.0f, vectorRowId.getFloat(1), 0.001f); + assertTrue(vectorRowId.getLong(2) >= 0); + + List searchRows = + spark + .sql( + "SELECT id, body, _score FROM SEARCH(" + + "table => '" + + ftsTable + + "', " + + "query => 'lance', " + + "search_columns => array('body'), " + + "columns => array('id', 'body'), " + + "limit => 10) ORDER BY id") + .collectAsList(); + assertEquals(2, searchRows.size()); + assertEquals(1, searchRows.get(0).getInt(0)); + + List scoreOnlyRows = + spark + .sql( + "SELECT * FROM SEARCH(" + + "table => '" + + ftsTable + + "', " + + "query => 'lance', " + + "columns => array('_score'), " + + "filter => 'id >= 0', " + + "limit => 1)") + .collectAsList(); + assertEquals(1, scoreOnlyRows.size()); + assertEquals(1, scoreOnlyRows.get(0).size()); + assertTrue(scoreOnlyRows.get(0).getFloat(0) > 0.0f); + + Dataset searchRowIdOnly = + spark.sql( + "SELECT * FROM SEARCH(" + + "table => '" + + ftsTable + + "', " + + "query => 'lance', " + + "columns => array('_rowid'), " + + "limit => 1)"); + assertEquals( + java.util.Arrays.asList("_rowid", "_score"), + java.util.Arrays.asList(searchRowIdOnly.columns())); + Row searchRowIdOnlyRow = searchRowIdOnly.collectAsList().get(0); + assertTrue(searchRowIdOnlyRow.getLong(0) >= 0); + assertTrue(searchRowIdOnlyRow.getFloat(1) > 0.0f); + + Dataset searchWithRowId = + spark.sql( + "SELECT * FROM SEARCH(" + + "table => '" + + ftsTable + + "', " + + "query => 'lance', " + + "columns => array('ID'), " + + "with_row_id => true, " + + "limit => 1)"); + assertEquals( + java.util.Arrays.asList("id", "_score", "_rowid"), + java.util.Arrays.asList(searchWithRowId.columns())); + Row searchRowId = searchWithRowId.collectAsList().get(0); + assertTrue(searchRowId.getInt(0) == 1 || searchRowId.getInt(0) == 3); + assertTrue(searchRowId.getFloat(1) > 0.0f); + assertTrue(searchRowId.getLong(2) >= 0); + + Dataset hybridWithRowId = + spark.sql( + "SELECT * FROM HYBRID_SEARCH(" + + "table => '" + + hybridTable + + "', " + + "query_vector => array(0.0, 0.0, 0.0, 0.0), " + + "query => 'lance', " + + "vector_column => 'vector', " + + "search_columns => array('body'), " + + "columns => array('ID'), " + + "num_results => 2, " + + "candidates => 3, " + + "rrf_k => 1.0, " + + "with_row_id => true) " + + "ORDER BY _relevance_score DESC, id"); + assertEquals( + java.util.Arrays.asList("id", "_distance", "_score", "_relevance_score", "_rowid"), + java.util.Arrays.asList(hybridWithRowId.columns())); + List hybridRows = hybridWithRowId.collectAsList(); + assertEquals(2, hybridRows.size()); + assertEquals(1, hybridRows.get(0).getInt(0)); + assertEquals(0.0f, hybridRows.get(0).getFloat(1), 0.001f); + assertTrue(hybridRows.get(0).getFloat(2) > 0.0f); + assertTrue(hybridRows.get(0).getFloat(3) > hybridRows.get(1).getFloat(3)); + assertTrue(hybridRows.get(0).getLong(4) >= 0); + } + + private String createVectorTable() { + String fullName = fullTableName("vector_search"); + spark.sql( + "CREATE TABLE " + + fullName + + " (id INT NOT NULL, vector ARRAY NOT NULL) USING lance " + + "TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4')"); + spark.sql( + "INSERT INTO " + + fullName + + " VALUES " + + "(0, array(0.0, 0.0, 0.0, 0.0)), " + + "(1, array(1.0, 1.0, 1.0, 1.0)), " + + "(2, array(10.0, 10.0, 10.0, 10.0))"); + return fullName; + } + + private String createFtsTable() { + String fullName = fullTableName("fts_search"); + spark.sql("CREATE TABLE " + fullName + " (id INT NOT NULL, body STRING) USING lance"); + spark.sql( + "INSERT INTO " + + fullName + + " VALUES " + + "(1, 'lance vector search'), " + + "(2, 'spark connector table function'), " + + "(3, 'lance full text search')"); + spark.sql( + "ALTER TABLE " + + fullName + + " CREATE INDEX body_fts USING fts (body) WITH (" + + "base_tokenizer='simple', " + + "language='English', " + + "max_token_length=40, " + + "lower_case=true, " + + "stem=false, " + + "remove_stop_words=false, " + + "ascii_folding=false, " + + "with_position=true)"); + return fullName; + } + + private String createHybridTable() { + String fullName = fullTableName("hybrid_search"); + spark.sql( + "CREATE TABLE " + + fullName + + " (id INT NOT NULL, body STRING, vector ARRAY NOT NULL) USING lance " + + "TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4')"); + spark.sql( + "INSERT INTO " + + fullName + + " VALUES " + + "(1, 'lance vector search', array(0.0, 0.0, 0.0, 0.0)), " + + "(2, 'spark connector table function', array(1.0, 1.0, 1.0, 1.0)), " + + "(3, 'lance full text search', array(10.0, 10.0, 10.0, 10.0))"); + spark.sql( + "ALTER TABLE " + + fullName + + " CREATE INDEX body_fts USING fts (body) WITH (" + + "base_tokenizer='simple', " + + "language='English', " + + "max_token_length=40, " + + "lower_case=true, " + + "stem=false, " + + "remove_stop_words=false, " + + "ascii_folding=false, " + + "with_position=true)"); + return fullName; + } + + private String fullTableName(String prefix) { + return CATALOG_NAME + + ".default." + + prefix + + "_" + + UUID.randomUUID().toString().replace("-", ""); + } + + private String getDeepMessage(Throwable throwable) { + StringBuilder builder = new StringBuilder(); + Throwable current = throwable; + while (current != null) { + if (current.getMessage() != null) { + builder.append(current.getMessage()).append('\n'); + } + current = current.getCause(); + } + return builder.toString(); + } + + private boolean supportsNamedArguments() { + return !spark.version().startsWith("3.4."); + } +}