diff --git a/.github/workflows/spark-namespace-insert.yml b/.github/workflows/spark-namespace-insert.yml new file mode 100644 index 000000000..6441597db --- /dev/null +++ b/.github/workflows/spark-namespace-insert.yml @@ -0,0 +1,178 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Spark Namespace Insert Docker + +on: + pull_request: + types: + - opened + - synchronize + - ready_for_review + - reopened + paths: + - ".github/workflows/spark-namespace-insert.yml" + - "Makefile" + - "docker/**" + - "integration-tests/**" + - "lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java" + - "lance-spark-base_2.12/src/main/java/org/lance/spark/write/**" + - "lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkWriteOptionsTest.java" + - "pom.xml" + - "*/pom.xml" + workflow_dispatch: + inputs: + spark-version: + description: "Spark version to test" + required: true + default: "3.5" + scala-version: + description: "Scala version to test" + required: true + default: "2.13" + backends: + description: "Comma-separated test backends: local or local,rest-dir" + required: true + default: "local,rest-dir" + rest-uri: + description: "Optional REST namespace URI. If omitted, tests start a local REST directory namespace." 
+ required: false + default: "" + rest-database: + description: "Optional database header value for an external REST namespace" + required: false + default: "" + docker-run-args: + description: "Extra docker run args for docker-test" + required: false + default: "" + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + SPARK_VERSION: ${{ github.event.inputs['spark-version'] || '3.5' }} + SCALA_VERSION: ${{ github.event.inputs['scala-version'] || '2.13' }} + NAMESPACE_INSERT_TEST_BACKENDS: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.backends || 'local,rest-dir' }} + NAMESPACE_INSERT_PYTEST_CMD: >- + pytest /home/lance/tests/test_lance_spark.py::TestDMLNamespaceInsert + -v --timeout=180 + +jobs: + namespace-insert-docker-test: + name: Namespace Insert Docker Test + runs-on: ubuntu-24.04 + timeout-minutes: 90 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + - name: Set up Java + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: 17 + cache: "maven" + - name: Resolve Docker build args + id: docker-args + run: | + make print-docker-build-args SPARK_VERSION=${SPARK_VERSION} SCALA_VERSION=${SCALA_VERSION} >> $GITHUB_OUTPUT + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build test-base image (cached) + uses: docker/build-push-action@v6 + with: + context: docker + file: docker/Dockerfile.test-base + load: true + tags: lance-spark-test-base:${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + build-args: | + SPARK_DOWNLOAD_VERSION=${{ steps.docker-args.outputs.spark-download-version }} + SPARK_MAJOR_VERSION=${{ env.SPARK_VERSION }} + SCALA_VERSION=${{ env.SCALA_VERSION }} + PY4J_VERSION=${{ steps.docker-args.outputs.py4j-version }} + SPARK_SCALA_SUFFIX=${{ steps.docker-args.outputs.spark-scala-suffix }} 
+ cache-from: type=gha,scope=namespace-insert-test-base-${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + cache-to: type=gha,mode=max,scope=namespace-insert-test-base-${{ env.SPARK_VERSION }}_${{ env.SCALA_VERSION }} + - name: Build bundle + run: make bundle SPARK_VERSION=${SPARK_VERSION} SCALA_VERSION=${SCALA_VERSION} + - name: Build test image + run: | + make docker-build-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + LANCE_NAMESPACE_IMPL_VERSION=${{ steps.docker-args.outputs.lance-namespace-impl-version }} + - name: Run directory namespace insert tests + if: ${{ contains(env.NAMESPACE_INSERT_TEST_BACKENDS, 'local') }} + run: | + make docker-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + TEST_BACKENDS=local \ + PYTEST_CMD="${NAMESPACE_INSERT_PYTEST_CMD}" + - name: Resolve REST namespace URI + id: rest + if: ${{ contains(env.NAMESPACE_INSERT_TEST_BACKENDS, 'rest-dir') }} + env: + INPUT_REST_URI: ${{ github.event.inputs['rest-uri'] }} + INPUT_DOCKER_RUN_ARGS: ${{ github.event.inputs['docker-run-args'] }} + run: | + rest_uri="${INPUT_REST_URI}" + docker_run_args="${INPUT_DOCKER_RUN_ARGS}" + start_rest_dir="false" + rest_dir_root="" + rest_dir_port="" + + if [ -z "${rest_uri}" ]; then + rest_dir_port="10024" + rest_dir_root="/home/lance/rest-data" + rest_uri="http://127.0.0.1:${rest_dir_port}" + start_rest_dir="true" + fi + + echo "uri=${rest_uri}" >> "$GITHUB_OUTPUT" + echo "start_rest_dir=${start_rest_dir}" >> "$GITHUB_OUTPUT" + echo "rest_dir_root=${rest_dir_root}" >> "$GITHUB_OUTPUT" + echo "rest_dir_port=${rest_dir_port}" >> "$GITHUB_OUTPUT" + { + echo "docker_run_args<<DOCKER_RUN_ARGS_EOF" + echo "${docker_run_args}" + echo "DOCKER_RUN_ARGS_EOF" + } >> "$GITHUB_OUTPUT" + - name: Run REST directory namespace insert tests + if: ${{ contains(env.NAMESPACE_INSERT_TEST_BACKENDS, 'rest-dir') }} + env: + LANCE_SPARK_REST_URI: ${{ steps.rest.outputs.uri }} + LANCE_SPARK_REST_API_KEY: ${{ secrets.LANCE_SPARK_REST_API_KEY }} + LANCE_SPARK_REST_DATABASE: ${{
github.event.inputs['rest-database'] }} + LANCE_SPARK_START_REST_DIR: ${{ steps.rest.outputs.start_rest_dir }} + LANCE_SPARK_REST_DIR_ROOT: ${{ steps.rest.outputs.rest_dir_root }} + LANCE_SPARK_REST_DIR_PORT: ${{ steps.rest.outputs.rest_dir_port }} + DOCKER_RUN_ARGS: ${{ steps.rest.outputs.docker_run_args }} + run: | + make docker-test \ + SPARK_VERSION=${SPARK_VERSION} \ + SCALA_VERSION=${SCALA_VERSION} \ + TEST_BACKENDS=rest-dir \ + LANCE_SPARK_REST_URI="${LANCE_SPARK_REST_URI}" \ + LANCE_SPARK_REST_API_KEY="${LANCE_SPARK_REST_API_KEY}" \ + LANCE_SPARK_REST_DATABASE="${LANCE_SPARK_REST_DATABASE}" \ + LANCE_SPARK_START_REST_DIR="${LANCE_SPARK_START_REST_DIR}" \ + LANCE_SPARK_REST_DIR_ROOT="${LANCE_SPARK_REST_DIR_ROOT}" \ + LANCE_SPARK_REST_DIR_PORT="${LANCE_SPARK_REST_DIR_PORT}" \ + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS}" \ + PYTEST_CMD="${NAMESPACE_INSERT_PYTEST_CMD}" diff --git a/docs/src/config.md b/docs/src/config.md index 2c5727347..8ab40ac55 100644 --- a/docs/src/config.md +++ b/docs/src/config.md @@ -71,6 +71,23 @@ and namespace-specific options: | `spark.sql.catalog.{name}.parent` | String | ✗ | Parent prefix for multi-level namespaces. See [Note on Namespace Levels](#note-on-namespace-levels). | | `spark.sql.catalog.{name}.parent_delimiter` | String | ✗ | Delimiter for parent prefix (default: `.`). See [Note on Namespace Levels](#note-on-namespace-levels). | +## Write Options + +Write options can be set on DataFrame writes. Catalog-level values are also used as defaults when +they are present in the Spark catalog configuration. + +| Option | Type | Default | Description | +|--------------------------------|---------|---------|----------------------------------------------------------------------------------------------------------| +| `use_namespace_insert` | Boolean | `false` | Use the Lance Namespace insert API for eligible append writes to namespace-backed tables. 
| +| `namespace_insert_parallelism` | Integer | `0` | Number of writer tasks to request from Spark for namespace insert writes. `0` preserves Spark's plan. For sharded tables Spark uses the sharding distribution; for unsharded tables Spark repartitions by the first output column. | +| `batch_size` | Integer | `8192` | Maximum rows per Arrow batch/request before flushing. | +| `max_batch_bytes` | Long | `268435456` | Maximum approximate bytes per Arrow batch/request before flushing. | + +Namespace insert writes are intended for append ingestion through a namespace implementation, +including REST namespaces. Each insert request is committed by the namespace as it runs, so this mode +does not provide the same Spark driver-side atomic commit behavior as the default writer. If a Spark +task or driver fails after some requests complete, those rows may already be visible. + ## Example Namespace Implementations ### Directory Namespace diff --git a/docs/src/operations/dml/insert-into.md b/docs/src/operations/dml/insert-into.md index c6ee49041..2d76e4897 100644 --- a/docs/src/operations/dml/insert-into.md +++ b/docs/src/operations/dml/insert-into.md @@ -59,6 +59,50 @@ Add data to existing Lance tables using SQL or DataFrames. newDF.write().mode("append").saveAsTable("users"); ``` +## Namespace Insert Writes + +For namespace-backed tables, append writes can use the Lance Namespace insert API. This is useful +when the namespace implementation can execute ingestion close to the table storage, such as a REST +namespace service. 
+ +=== "Python" + ```python + df.writeTo("users") \ + .option("use_namespace_insert", "true") \ + .option("namespace_insert_parallelism", "8") \ + .option("batch_size", "4096") \ + .append() + ``` + +=== "Scala" + ```scala + df.writeTo("users") + .option("use_namespace_insert", "true") + .option("namespace_insert_parallelism", "8") + .option("batch_size", "4096") + .append() + ``` + +`use_namespace_insert` applies to append writes to existing namespace-backed tables. Create, +replace, overwrite, path-based writes, and schema backfill operations use the default writer. + +### Expected Behavior + +For users, namespace insert writes look like a normal DataFrame append with two optional write +options, `use_namespace_insert` and `namespace_insert_parallelism`. Existing `INSERT INTO` statements and `.append()` calls keep using the default Spark writer +unless `use_namespace_insert` is set to `true`. When enabled, Spark still plans executor-side writer tasks, but +each task sends Arrow batches to the configured Lance namespace instead of committing Lance fragments +directly from the driver. This lets directory and REST namespaces handle ingestion through the same +namespace API. + +When `namespace_insert_parallelism` is greater than `0`, Spark creates that many writer tasks. For +sharded tables Spark uses the table sharding distribution; for unsharded tables Spark repartitions +by the first output column. + +Each namespace insert request is committed as it runs. If a Spark task or driver fails after some +requests complete, those rows may already be visible. Use the default writer when you need Spark's +driver-side atomic commit behavior.
+ ## Insert with Column Specification === "SQL" diff --git a/integration-tests/test_lance_spark.py b/integration-tests/test_lance_spark.py index a2f30e4b1..54884fcc3 100644 --- a/integration-tests/test_lance_spark.py +++ b/integration-tests/test_lance_spark.py @@ -16,7 +16,16 @@ import time import pytest from packaging.version import Version -from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType, BinaryType +from pyspark.sql.types import ( + ArrayType, + BinaryType, + DoubleType, + FloatType, + IntegerType, + StringType, + StructField, + StructType, +) SPARK_VERSION = Version(os.environ.get("SPARK_VERSION", "3.5")) @@ -123,6 +132,12 @@ def _require_sql_search_backend(spark): pytest.skip("SQL search table functions are covered on local dir and rest-dir backends") +def _require_namespace_insert_backend(spark): + backend = getattr(spark, "_lance_backend", None) + if backend not in ("local", "rest-dir"): + pytest.skip("Namespace insert writes are covered on local dir and rest-dir backends") + + # ============================================================================= # DDL (Data Definition Language) Tests # ============================================================================= @@ -1989,6 +2004,91 @@ def test_insert_append_data(self, spark): assert count == 4 +@pytest.mark.rest_dir_compatible +class TestDMLNamespaceInsert: + """Test append writes through the Lance namespace insert API.""" + + def test_namespace_insert_append_data(self, spark): + """Test DataFrame append through namespace insert with multiple writer tasks.""" + _require_namespace_insert_backend(spark) + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT, + name STRING, + value DOUBLE + ) + """) + + df = ( + spark.range(0, 24) + .repartition(6) + .selectExpr( + "CAST(id AS INT) AS id", + "concat('name-', id) AS name", + "CAST(id * 1.5 AS DOUBLE) AS value", + ) + ) + + ( + df.writeTo("default.test_table") + .option("use_namespace_insert", "true") 
+ .option("namespace_insert_parallelism", "3") + .option("batch_size", "5") + .append() + ) + + rows = spark.sql(""" + SELECT COUNT(*) AS count_rows, SUM(id) AS sum_id, SUM(value) AS sum_value + FROM default.test_table + """).collect() + + assert rows[0].count_rows == 24 + assert rows[0].sum_id == sum(range(24)) + assert rows[0].sum_value == pytest.approx(sum(i * 1.5 for i in range(24))) + + def test_namespace_insert_append_vector_data(self, spark): + """Test namespace insert preserves fixed-size-list vector writes.""" + _require_namespace_insert_backend(spark) + + spark.sql(""" + CREATE TABLE default.test_table ( + id INT, + vector ARRAY + ) USING lance + TBLPROPERTIES ('vector.arrow.fixed-size-list.size' = '4') + """) + + schema = StructType([ + StructField("id", IntegerType(), True), + StructField("vector", ArrayType(FloatType()), True), + ]) + df = spark.createDataFrame( + [ + (1, [1.0, 0.0, 0.0, 0.0]), + (2, [0.0, 1.0, 0.0, 0.0]), + (3, [0.0, 0.0, 1.0, 0.0]), + (4, [0.0, 0.0, 0.0, 1.0]), + ], + schema, + ).repartition(2) + + ( + df.writeTo("default.test_table") + .option("use_namespace_insert", "true") + .option("namespace_insert_parallelism", "2") + .append() + ) + + rows = spark.sql(""" + SELECT id, vector FROM default.test_table ORDER BY id + """).collect() + + assert [row.id for row in rows] == [1, 2, 3, 4] + assert rows[0].vector == [1.0, 0.0, 0.0, 0.0] + assert rows[3].vector == [0.0, 0.0, 0.0, 1.0] + + @requires_update_or_merge class TestDMLUpdate: """Test DML UPDATE SET operations.""" diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java index 0e0f6db06..d250e8a23 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkWriteOptions.java @@ -58,9 +58,13 @@ public class LanceSparkWriteOptions implements Serializable { public static 
final String CONFIG_USE_LARGE_VAR_TYPES = "use_large_var_types"; public static final String CONFIG_MAX_BATCH_BYTES = "max_batch_bytes"; public static final String CONFIG_BLOB_PACK_FILE_SIZE_THRESHOLD = "blob_pack_file_size_threshold"; + public static final String CONFIG_USE_NAMESPACE_INSERT = "use_namespace_insert"; + public static final String CONFIG_NAMESPACE_INSERT_PARALLELISM = "namespace_insert_parallelism"; private static final WriteMode DEFAULT_WRITE_MODE = WriteMode.APPEND; private static final boolean DEFAULT_USE_QUEUED_WRITE_BUFFER = false; + private static final boolean DEFAULT_USE_NAMESPACE_INSERT = false; + private static final int DEFAULT_NAMESPACE_INSERT_PARALLELISM = 0; private static final int DEFAULT_QUEUE_DEPTH = 8; // Changed from 512 to 8192 for better write performance consistency with read path private static final int DEFAULT_BATCH_SIZE = 8192; @@ -85,6 +89,8 @@ public class LanceSparkWriteOptions implements Serializable { private final long maxBatchBytes; // Boxed: null means "unset" so lance-core uses its own default (1 GiB as of 6.0.0-beta.1). private final Long blobPackFileSizeThreshold; + private final boolean useNamespaceInsert; + private final int namespaceInsertParallelism; private final Map storageOptions; /** The namespace for credential vending. Transient as LanceNamespace is not serializable. 
*/ @@ -110,6 +116,8 @@ private LanceSparkWriteOptions(Builder builder) { this.useLargeVarTypes = builder.useLargeVarTypes; this.maxBatchBytes = builder.maxBatchBytes; this.blobPackFileSizeThreshold = builder.blobPackFileSizeThreshold; + this.useNamespaceInsert = builder.useNamespaceInsert; + this.namespaceInsertParallelism = builder.namespaceInsertParallelism; this.storageOptions = new HashMap<>(builder.storageOptions); this.namespace = builder.namespace; this.tableId = builder.tableId; @@ -198,6 +206,14 @@ public Long getBlobPackFileSizeThreshold() { return blobPackFileSizeThreshold; } + public boolean isUseNamespaceInsert() { + return useNamespaceInsert; + } + + public int getNamespaceInsertParallelism() { + return namespaceInsertParallelism; + } + public Map getStorageOptions() { return storageOptions; } @@ -230,6 +246,8 @@ public Builder toBuilder() { .enableStableRowIds(enableStableRowIds) .useLargeVarTypes(useLargeVarTypes) .blobPackFileSizeThreshold(blobPackFileSizeThreshold) + .useNamespaceInsert(useNamespaceInsert) + .namespaceInsertParallelism(namespaceInsertParallelism) .storageOptions(storageOptions) .namespace(namespace) .tableId(tableId) @@ -315,6 +333,8 @@ public boolean equals(Object o) { && batchSize == that.batchSize && useLargeVarTypes == that.useLargeVarTypes && maxBatchBytes == that.maxBatchBytes + && useNamespaceInsert == that.useNamespaceInsert + && namespaceInsertParallelism == that.namespaceInsertParallelism && Objects.equals(datasetUri, that.datasetUri) && writeMode == that.writeMode && Objects.equals(maxRowsPerFile, that.maxRowsPerFile) @@ -344,6 +364,8 @@ public int hashCode() { useLargeVarTypes, maxBatchBytes, blobPackFileSizeThreshold, + useNamespaceInsert, + namespaceInsertParallelism, storageOptions, tableId, version); @@ -364,6 +386,8 @@ public static class Builder { private boolean useLargeVarTypes = DEFAULT_USE_LARGE_VAR_TYPES; private long maxBatchBytes = DEFAULT_MAX_BATCH_BYTES; private Long blobPackFileSizeThreshold; + private 
boolean useNamespaceInsert = DEFAULT_USE_NAMESPACE_INSERT; + private int namespaceInsertParallelism = DEFAULT_NAMESPACE_INSERT_PARALLELISM; private Map storageOptions = new HashMap<>(); private LanceNamespace namespace; private List tableId; @@ -440,6 +464,18 @@ public Builder blobPackFileSizeThreshold(Long blobPackFileSizeThreshold) { return this; } + public Builder useNamespaceInsert(boolean useNamespaceInsert) { + this.useNamespaceInsert = useNamespaceInsert; + return this; + } + + public Builder namespaceInsertParallelism(int namespaceInsertParallelism) { + Preconditions.checkArgument( + namespaceInsertParallelism >= 0, "namespace_insert_parallelism must be non-negative"); + this.namespaceInsertParallelism = namespaceInsertParallelism; + return this; + } + public Builder storageOptions(Map storageOptions) { this.storageOptions = new HashMap<>(storageOptions); return this; @@ -469,6 +505,8 @@ public Builder version(Long version) { */ public Builder fromOptions(Map options) { this.storageOptions = new HashMap<>(options); + this.storageOptions.remove(CONFIG_USE_NAMESPACE_INSERT); + this.storageOptions.remove(CONFIG_NAMESPACE_INSERT_PARALLELISM); if (options.containsKey(CONFIG_WRITE_MODE)) { this.writeMode = WriteMode.valueOf(options.get(CONFIG_WRITE_MODE).toUpperCase()); } @@ -512,6 +550,15 @@ public Builder fromOptions(Map options) { Preconditions.checkArgument(parsed > 0, "blob_pack_file_size_threshold must be positive"); this.blobPackFileSizeThreshold = parsed; } + if (options.containsKey(CONFIG_USE_NAMESPACE_INSERT)) { + this.useNamespaceInsert = Boolean.parseBoolean(options.get(CONFIG_USE_NAMESPACE_INSERT)); + } + if (options.containsKey(CONFIG_NAMESPACE_INSERT_PARALLELISM)) { + int parsedParallelism = Integer.parseInt(options.get(CONFIG_NAMESPACE_INSERT_PARALLELISM)); + Preconditions.checkArgument( + parsedParallelism >= 0, "namespace_insert_parallelism must be non-negative"); + this.namespaceInsertParallelism = parsedParallelism; + } return this; } diff 
--git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/AbstractBackfillWriter.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/AbstractBackfillWriter.java index 92214dd7c..53ceb1df2 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/AbstractBackfillWriter.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/AbstractBackfillWriter.java @@ -104,7 +104,8 @@ public void write(InternalRow record) throws IOException { VectorSchemaRoot.create( LanceArrowUtils.toArrowSchema(writerSchema, "UTC", false), allocator); org.lance.spark.arrow.LanceArrowWriter writer = - org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(data, writerSchema); + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.createWithoutResolver( + data, writerSchema); return new FragmentBuffer(data, writer); }); diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/NamespaceInsertBatchWrite.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/NamespaceInsertBatchWrite.java new file mode 100644 index 000000000..2215651d7 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/NamespaceInsertBatchWrite.java @@ -0,0 +1,481 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.write; + +import org.lance.Dataset; +import org.lance.memwal.MemWalIndexDetails; +import org.lance.memwal.ShardingField; +import org.lance.memwal.ShardingSpec; +import org.lance.namespace.LanceNamespace; +import org.lance.namespace.model.InsertIntoTableRequest; +import org.lance.spark.LanceRuntime; +import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.sharding.SparkLanceShardingUtils; +import org.lance.spark.utils.BlobReferenceResolver; +import org.lance.spark.utils.BlobSourceContext; +import org.lance.spark.utils.Utils; + +import com.google.common.base.Preconditions; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.DataWriter; +import org.apache.spark.sql.connector.write.DataWriterFactory; +import org.apache.spark.sql.connector.write.PhysicalWriteInfo; +import org.apache.spark.sql.connector.write.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.LanceArrowUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** Writes Spark rows through the Lance Namespace insert API. 
*/ +public class NamespaceInsertBatchWrite implements BatchWrite { + private static final Logger LOG = LoggerFactory.getLogger(NamespaceInsertBatchWrite.class); + + private final StructType schema; + private final LanceSparkWriteOptions writeOptions; + private final Map initialStorageOptions; + private final String namespaceImpl; + private final Map namespaceProperties; + private final List tableId; + private final ShardingSpec shardingSpec; + private final Map blobSourceContexts; + + public NamespaceInsertBatchWrite( + StructType schema, + LanceSparkWriteOptions writeOptions, + Map initialStorageOptions, + String namespaceImpl, + Map namespaceProperties, + List tableId, + ShardingSpec shardingSpec, + Map blobSourceContexts) { + this.schema = schema; + this.writeOptions = writeOptions; + this.initialStorageOptions = initialStorageOptions; + this.namespaceImpl = namespaceImpl; + this.namespaceProperties = namespaceProperties; + this.tableId = tableId; + this.shardingSpec = shardingSpec; + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; + } + + @Override + public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { + return new WriterFactory( + schema, + writeOptions, + initialStorageOptions, + namespaceImpl, + namespaceProperties, + tableId, + shardingSpec, + blobSourceContexts); + } + + @Override + public boolean useCommitCoordinator() { + return false; + } + + @Override + public void commit(WriterCommitMessage[] messages) { + // The namespace insert API commits every request from executor tasks. + } + + @Override + public void abort(WriterCommitMessage[] messages) { + // No compensating transaction is available from the current namespace insert API. 
+ } + + @Override + public String toString() { + return String.format("NamespaceInsertBatchWrite(tableId=%s)", tableId); + } + + static final class TaskCommit implements WriterCommitMessage { + private final long rowsInserted; + private final int insertRequests; + + TaskCommit(long rowsInserted, int insertRequests) { + this.rowsInserted = rowsInserted; + this.insertRequests = insertRequests; + } + + long rowsInserted() { + return rowsInserted; + } + + int insertRequests() { + return insertRequests; + } + } + + static class WriterFactory implements DataWriterFactory { + private final StructType schema; + private final LanceSparkWriteOptions writeOptions; + private final Map initialStorageOptions; + private final String namespaceImpl; + private final Map namespaceProperties; + private final List tableId; + private final ShardingSpecSnapshot shardingSpec; + private final Map blobSourceContexts; + + WriterFactory( + StructType schema, + LanceSparkWriteOptions writeOptions, + Map initialStorageOptions, + String namespaceImpl, + Map namespaceProperties, + List tableId, + ShardingSpec shardingSpec, + Map blobSourceContexts) { + this.schema = schema; + this.writeOptions = writeOptions; + this.initialStorageOptions = initialStorageOptions; + this.namespaceImpl = namespaceImpl; + this.namespaceProperties = namespaceProperties; + this.tableId = tableId; + this.shardingSpec = + SparkLanceShardingUtils.isEmpty(shardingSpec) + ? null + : ShardingSpecSnapshot.from(shardingSpec); + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; + } + + @Override + public DataWriter createWriter(int partitionId, long taskId) { + Preconditions.checkArgument( + namespaceImpl != null && tableId != null, + "namespace insert writes require a namespace-backed table"); + + ShardingBatchKeyEvaluator shardingKeyEvaluator = + shardingSpec == null + ? 
null + : new ShardingBatchKeyEvaluator(schema, writeOptions, shardingBinding()); + BlobReferenceResolver resolver = new BlobReferenceResolver(blobSourceContexts); + LanceNamespace namespace = + LanceRuntime.getOrCreateNamespace(namespaceImpl, namespaceProperties); + return new NamespaceInsertDataWriter( + schema, writeOptions, namespace, tableId, shardingKeyEvaluator, resolver); + } + + private ShardingBatchKeyEvaluator.ShardingBinding shardingBinding() { + try (Dataset dataset = + Utils.openDatasetBuilder( + writeOptions.toBuilder() + .storageOptions( + LanceRuntime.mergeStorageOptions( + writeOptions.getStorageOptions(), initialStorageOptions)) + .build()) + .build()) { + Optional details = dataset.memWalIndexDetails(); + if (details.isPresent() && !details.get().shardingSpecs().isEmpty()) { + return new ShardingBatchKeyEvaluator.ShardingBinding( + details.get().shardingSpecs().get(0), dataset.getLanceSchema()); + } + } catch (RuntimeException e) { + if (shardingSpec.hasSourceIds()) { + throw e; + } + LOG.warn("Falling back to in-memory sharding metadata for namespace insert write", e); + } + return new ShardingBatchKeyEvaluator.ShardingBinding(shardingSpec.toShardingSpec(), null); + } + } + + static class NamespaceInsertDataWriter implements DataWriter { + private final StructType sparkSchema; + private final LanceSparkWriteOptions writeOptions; + private final LanceNamespace namespace; + private final List tableId; + private final ShardingBatchKeyEvaluator shardingKeyEvaluator; + private final BlobReferenceResolver blobResolver; + private final Schema arrowSchema; + + private BufferAllocator batchAllocator; + private VectorSchemaRoot root; + private org.lance.spark.arrow.LanceArrowWriter arrowWriter; + private int rowCount; + private long rowsInserted; + private int insertRequests; + private Object lastKey; + private boolean hasRowsInCurrentBatch; + + NamespaceInsertDataWriter( + StructType sparkSchema, + LanceSparkWriteOptions writeOptions, + LanceNamespace 
namespace, + List tableId, + ShardingBatchKeyEvaluator shardingKeyEvaluator, + BlobReferenceResolver blobResolver) { + this.sparkSchema = sparkSchema; + this.writeOptions = writeOptions; + this.namespace = Objects.requireNonNull(namespace, "namespace"); + /* Copy the caller's table id list so later changes to it do not affect this writer. */ this.tableId = new ArrayList<>(Objects.requireNonNull(tableId, "tableId")); + this.shardingKeyEvaluator = shardingKeyEvaluator; + this.blobResolver = blobResolver; + this.arrowSchema = + LanceArrowUtils.toArrowSchema( + sparkSchema, "UTC", false, writeOptions.isUseLargeVarTypes()); + /* Allocate the first batch eagerly so write() always finds a live arrowWriter. */ allocateBatch(); + } + + /** Buffers one row. When a sharding evaluator is configured the row is handed to it, and it calls back into writePartitionedRow; otherwise the row is appended directly. */ @Override + public void write(InternalRow record) throws IOException { + if (shardingKeyEvaluator != null) { + shardingKeyEvaluator.write(record, this::writePartitionedRow); + return; + } + writeRow(record); + } + + /** Callback for rows coming from the sharding evaluator. Flushes the pending batch whenever the supplied key differs from the key of the rows already buffered, so a single insert request never mixes keys. */ private void writePartitionedRow(InternalRow row, Object key) throws IOException { + if (!hasRowsInCurrentBatch) { + lastKey = key; + } else if (!Objects.equals(key, lastKey)) { + flush(); + lastKey = key; + } + writeRow(row); + /* writeRow may itself have flushed, leaving the batch empty again. */ hasRowsInCurrentBatch = rowCount > 0; + } + + /** Appends the row to the current batch, then flushes once the row count reaches getBatchSize() or the estimated size (allocator memory plus the bytes the writer reports as buffered) reaches getMaxBatchBytes(). */ private void writeRow(InternalRow row) throws IOException { + /* Guard before appending; the check after the append normally flushes first. */ if (rowCount >= writeOptions.getBatchSize()) { + flush(); + } + arrowWriter.write(row); + rowCount++; + + long currentBatchBytes = + batchAllocator.getAllocatedMemory() + + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.estimatedBufferedBytes( + arrowWriter); + if (rowCount >= writeOptions.getBatchSize() + || (rowCount > 0 && currentBatchBytes >= writeOptions.getMaxBatchBytes())) { + flush(); + } + } + + /** Sends the buffered rows as one append-mode insertIntoTable request. Does nothing when no rows are buffered. A RuntimeException from the namespace is wrapped in an IOException. */ private void flush() throws IOException { + if (rowCount == 0) { + return; + } + + arrowWriter.finish(); + root.setRowCount(rowCount); + /* serializeCurrentBatch() also closes the batch, so root and arrowWriter are null from here until allocateBatch() runs. */ byte[] requestData = serializeCurrentBatch(); + InsertIntoTableRequest request = new InsertIntoTableRequest().id(tableId).mode("append"); + try { + namespace.insertIntoTable(request, requestData); + } catch (RuntimeException e) { + throw new IOException("Failed to insert rows through Lance namespace", e); + } + + rowsInserted
+= rowCount; + insertRequests++; + hasRowsInCurrentBatch = false; + /* Start a fresh batch for the next rows; allocateBatch() also resets rowCount to 0. */ allocateBatch(); + } + + /** Writes the current root to an in-memory Arrow stream (start, one batch, end) and returns the bytes. The batch is closed in the finally block whether or not serialization succeeds. */ private byte[] serializeCurrentBatch() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, out)) { + streamWriter.start(); + streamWriter.writeBatch(); + streamWriter.end(); + } finally { + closeBatch(); + } + return out.toByteArray(); + } + + /** Closes any existing batch, then creates a new child allocator, VectorSchemaRoot and Arrow writer, and resets rowCount to 0. */ private void allocateBatch() { + closeBatch(); + batchAllocator = + LanceRuntime.allocator() + .newChildAllocator("namespace-insert-write-batch", 0, Long.MAX_VALUE); + root = VectorSchemaRoot.create(arrowSchema, batchAllocator); + arrowWriter = + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.createWithResolver( + root, sparkSchema, blobResolver); + rowCount = 0; + } + + /** Releases the current root and its allocator if present and nulls the references, so repeated calls are harmless. The root is closed before the allocator that backs it. */ private void closeBatch() { + if (root != null) { + root.close(); + root = null; + arrowWriter = null; + } + if (batchAllocator != null) { + batchAllocator.close(); + batchAllocator = null; + } + } + + /** Flushes the sharding evaluator (presumably draining rows it still holds - confirm against ShardingBatchKeyEvaluator) and the pending batch, then reports this task's inserted row count and request count. */ @Override + public WriterCommitMessage commit() throws IOException { + if (shardingKeyEvaluator != null) { + shardingKeyEvaluator.flush(this::writePartitionedRow); + } + flush(); + return new TaskCommit(rowsInserted, insertRequests); + } + + /** Only releases resources; nothing here undoes insert requests that already succeeded. */ @Override + public void abort() throws IOException { + close(); + } + + /** Closes the sharding evaluator, the Arrow batch, the blob resolver and the namespace in that order. Every step is attempted; the first failure is kept and later ones are attached to it as suppressed exceptions. */ @Override + public void close() throws IOException { + IOException failure = null; + try { + if (shardingKeyEvaluator != null) { + shardingKeyEvaluator.close(); + } + } catch (RuntimeException e) { + failure = new IOException("Failed to close sharding evaluator", e); + } + try { + closeBatch(); + } catch (RuntimeException e) { + if (failure == null) { + failure = new IOException("Failed to close Arrow batch", e); + } else { + failure.addSuppressed(e); + } + } + try { + blobResolver.close(); + } catch (RuntimeException e) { + if (failure == null) { + failure = new IOException("Failed to close blob resolver", e); + } else { + failure.addSuppressed(e); + } +
} + try { + if (namespace instanceof Closeable) { + ((Closeable) namespace).close(); + } + } catch (IOException e) { + if (failure == null) { + failure = e; + } else { + failure.addSuppressed(e); + } + } + if (failure != null) { + throw failure; + } + } + } + + private static final class ShardingSpecSnapshot implements Serializable { + private static final long serialVersionUID = 1L; + + private final int specId; + private final List fields; + + private ShardingSpecSnapshot(int specId, List fields) { + this.specId = specId; + this.fields = fields; + } + + private static ShardingSpecSnapshot from(ShardingSpec spec) { + List fields = new ArrayList<>(); + for (ShardingField field : spec.fields()) { + fields.add(ShardingFieldSnapshot.from(field)); + } + return new ShardingSpecSnapshot(spec.specId(), fields); + } + + private ShardingSpec toShardingSpec() { + List restored = new ArrayList<>(); + for (ShardingFieldSnapshot field : fields) { + restored.add(field.toShardingField()); + } + return new ShardingSpec(specId, restored); + } + + private boolean hasSourceIds() { + for (ShardingFieldSnapshot field : fields) { + if (!field.sourceIds.isEmpty()) { + return true; + } + } + return false; + } + } + + private static final class ShardingFieldSnapshot implements Serializable { + private static final long serialVersionUID = 1L; + + private final String fieldId; + private final List sourceIds; + private final String transform; + private final String expression; + private final String resultType; + private final Map parameters; + + private ShardingFieldSnapshot( + String fieldId, + List sourceIds, + String transform, + String expression, + String resultType, + Map parameters) { + this.fieldId = fieldId; + this.sourceIds = sourceIds; + this.transform = transform; + this.expression = expression; + this.resultType = resultType; + this.parameters = parameters; + } + + private static ShardingFieldSnapshot from(ShardingField field) { + return new ShardingFieldSnapshot( + 
field.fieldId(), + new ArrayList<>(field.sourceIds()), + field.transform().orElse(null), + field.expression().orElse(null), + field.resultType(), + new HashMap<>(field.parameters())); + } + + private ShardingField toShardingField() { + return new ShardingField(fieldId, sourceIds, transform, expression, resultType, parameters); + } + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java index de6c4a769..005b710da 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java @@ -236,7 +236,8 @@ private void allocateNewBatch() { throw e; } currentArrowWriter = - org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(currentBatch, sparkSchema, resolver); + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.createWithResolver( + currentBatch, sparkSchema, resolver); currentBatchRowCount.set(0); } @@ -248,7 +249,9 @@ private boolean isBatchFullByBytes() { // Include bytes buffered outside the vector (unresolved blob references resolve to far larger // bytes on finish); sizing only the ~200-byte references would let the batch grow until // resolution OOMs the executor. 
- return currentBatchAllocator.getAllocatedMemory() + currentArrowWriter.estimatedBufferedBytes() + return currentBatchAllocator.getAllocatedMemory() + + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.estimatedBufferedBytes( + currentArrowWriter) >= maxBatchBytes; } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java index 7b7e6f0d4..243a41768 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java @@ -191,7 +191,8 @@ public void write(InternalRow row) { // batch grow until resolution OOMs the executor. currentBatchBytes = (this.allocator.getAllocatedMemory() - batchStartBytes) - + arrowWriter.estimatedBufferedBytes(); + + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.estimatedBufferedBytes( + arrowWriter); count++; if (isBatchFull()) { @@ -232,7 +233,8 @@ public void prepareLoadNextBatch() throws IOException { } root.setRowCount(0); arrowWriter = - org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema, resolver); + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.createWithResolver( + root, sparkSchema, resolver); lock.lock(); try { count = 0; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ShardingBatchKeyEvaluator.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ShardingBatchKeyEvaluator.java index affd4865c..f73c3003d 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ShardingBatchKeyEvaluator.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ShardingBatchKeyEvaluator.java @@ -143,7 +143,9 @@ private void resetBatch() { private void allocateBatch() { root = VectorSchemaRoot.create(arrowSchema, allocator); root.allocateNew(); - arrowWriter = 
org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema); + arrowWriter = + org.lance.spark.arrow.LanceArrowWriteBridge$.MODULE$.createWithoutResolver( + root, sparkSchema); } private void closeBatch() { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java index 7e5f6bd30..a890bbae7 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.connector.distributions.Distribution; import org.apache.spark.sql.connector.distributions.Distributions; +import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.SortOrder; import org.apache.spark.sql.connector.write.BatchWrite; @@ -127,6 +128,10 @@ private ShardingSpec shardingSpec() { public Distribution requiredDistribution() { ShardingSpec spec = shardingSpec(); if (SparkLanceShardingUtils.isEmpty(spec)) { + if (useNamespaceInsertParallelism()) { + NamedReference firstColumn = Expressions.column(schema.fields()[0].name()); + return Distributions.clustered(new NamedReference[] {firstColumn}); + } return Distributions.unspecified(); } NamedReference[] refs = @@ -147,9 +152,28 @@ public SortOrder[] requiredOrdering() { .toArray(SortOrder[]::new); } + @Override + public int requiredNumPartitions() { + if (useNamespaceInsertParallelism()) { + return writeOptions.getNamespaceInsertParallelism(); + } + return RequiresDistributionAndOrdering.super.requiredNumPartitions(); + } + @Override public BatchWrite toBatch() { ShardingSpec spec = shardingSpec(); + if (useNamespaceInsertBatchWrite()) { + return new NamespaceInsertBatchWrite( + schema, + writeOptions, + initialStorageOptions, + namespaceImpl, + namespaceProperties, + tableId, + 
spec, + blobSourceContexts); + } return new LanceBatchWrite( schema, writeOptions, @@ -164,6 +188,21 @@ public BatchWrite toBatch() { blobSourceContexts); } + private boolean useNamespaceInsertBatchWrite() { + return writeOptions.isUseNamespaceInsert() + && stagedCommit == null + && !overwrite + && !writeOptions.isOverwrite() + && namespaceImpl != null + && tableId != null; + } + + private boolean useNamespaceInsertParallelism() { + return useNamespaceInsertBatchWrite() + && writeOptions.getNamespaceInsertParallelism() > 0 + && schema.fields().length > 0; + } + @Override public StreamingWrite toStreaming() { throw new UnsupportedOperationException(); @@ -242,22 +281,7 @@ public Write build() { LanceSparkWriteOptions options = !overwrite ? writeOptions - : LanceSparkWriteOptions.builder() - .storageOptions(writeOptions.getStorageOptions()) - .namespace(writeOptions.getNamespace()) - .tableId(writeOptions.getTableId()) - .batchSize(writeOptions.getBatchSize()) - .datasetUri(writeOptions.getDatasetUri()) - .fileFormatVersion(writeOptions.getFileFormatVersion()) - .maxBytesPerFile(writeOptions.getMaxBytesPerFile()) - .maxRowsPerFile(writeOptions.getMaxRowsPerFile()) - .maxRowsPerGroup(writeOptions.getMaxRowsPerGroup()) - .queueDepth(writeOptions.getQueueDepth()) - .useQueuedWriteBuffer(writeOptions.isUseQueuedWriteBuffer()) - .useLargeVarTypes(writeOptions.isUseLargeVarTypes()) - .enableStableRowIds(writeOptions.getEnableStableRowIds()) - .writeMode(WriteParams.WriteMode.OVERWRITE) - .build(); + : writeOptions.toBuilder().writeMode(WriteParams.WriteMode.OVERWRITE).build(); return new SparkWrite( schema, diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index 33580df1b..63be0d0b8 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ 
b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -155,6 +155,21 @@ object LanceArrowWriter { } } +object LanceArrowWriteBridge { + // Java write paths use this bridge to avoid Scala overload/accessor differences across + // Spark/Scala cross-builds. + def createWithResolver( + root: VectorSchemaRoot, + sparkSchema: StructType, + resolver: BlobReferenceResolver): LanceArrowWriter = + LanceArrowWriter.create(root, sparkSchema, resolver) + + def createWithoutResolver(root: VectorSchemaRoot, sparkSchema: StructType): LanceArrowWriter = + LanceArrowWriter.create(root, sparkSchema) + + def estimatedBufferedBytes(writer: LanceArrowWriter): Long = writer.estimatedBufferedBytes +} + /** * Writer that converts Spark InternalRow data to Arrow format. * Copied from Spark's ArrowWriter to support custom field writers for FixedSizeList. diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkWriteOptionsTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkWriteOptionsTest.java index ac2d68341..ad760e42d 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkWriteOptionsTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkWriteOptionsTest.java @@ -23,6 +23,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; /** Tests for {@link LanceSparkWriteOptions}. 
*/ @@ -150,4 +151,46 @@ public void testFromOptionsWithAllWriteSettings() { assertTrue(writeOptions.getEnableStableRowIds()); assertEquals(Long.valueOf(2147483648L), writeOptions.getBlobPackFileSizeThreshold()); } + + @Test + public void testNamespaceInsertOptionsParsedFromOptions() { + final Map options = new HashMap<>(); + options.put("path", TEMP_URL); + options.put("use_namespace_insert", "true"); + options.put("namespace_insert_parallelism", "4"); + + final LanceSparkWriteOptions writeOptions = + LanceSparkWriteOptions.builder().datasetUri(TEMP_URL).fromOptions(options).build(); + + assertTrue(writeOptions.isUseNamespaceInsert()); + assertEquals(4, writeOptions.getNamespaceInsertParallelism()); + assertFalse(writeOptions.getStorageOptions().containsKey("use_namespace_insert")); + assertFalse(writeOptions.getStorageOptions().containsKey("namespace_insert_parallelism")); + } + + @Test + public void testNamespaceInsertOptionsCopiedByToBuilder() { + final LanceSparkWriteOptions writeOptions = + LanceSparkWriteOptions.builder() + .datasetUri(TEMP_URL) + .useNamespaceInsert(true) + .namespaceInsertParallelism(8) + .build() + .toBuilder() + .build(); + + assertTrue(writeOptions.isUseNamespaceInsert()); + assertEquals(8, writeOptions.getNamespaceInsertParallelism()); + } + + @Test + public void testNamespaceInsertParallelismMustBeNonNegative() { + final Map options = new HashMap<>(); + options.put("path", TEMP_URL); + options.put("namespace_insert_parallelism", "-1"); + + assertThrows( + IllegalArgumentException.class, + () -> LanceSparkWriteOptions.builder().datasetUri(TEMP_URL).fromOptions(options).build()); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java index 39e2e79af..fdb30f926 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java +++ 
b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java @@ -16,11 +16,19 @@ import org.lance.Dataset; import org.lance.WriteParams; import org.lance.memwal.InitializeMemWalParams; +import org.lance.namespace.DirectoryNamespace; +import org.lance.namespace.model.CreateNamespaceRequest; +import org.lance.namespace.model.CreateTableRequest; +import org.lance.namespace.model.CreateTableResponse; import org.lance.spark.LanceSparkWriteOptions; import org.lance.spark.TestUtils; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -36,9 +44,12 @@ import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.io.TempDir; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; +import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -72,6 +83,44 @@ private String createIdentityShardedDataset(String name) { return datasetUri; } + private byte[] createNamespaceTableData(BufferAllocator allocator) throws IOException { + try (VectorSchemaRoot root = VectorSchemaRoot.create(ARROW_SCHEMA, allocator)) { + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + idVector.allocateNew(2); + nameVector.allocateNew(2); + idVector.set(0, 1); + idVector.set(1, 2); + nameVector.set(0, "Alice".getBytes()); + nameVector.set(1, "Bob".getBytes()); + idVector.setValueCount(2); + nameVector.setValueCount(2); + root.setRowCount(2); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try 
(ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + return out.toByteArray(); + } + } + + private String createNamespaceTable(String namespaceName, String tableName) throws IOException { + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + DirectoryNamespace namespace = new DirectoryNamespace()) { + namespace.initialize(Collections.singletonMap("root", tempDir.toString()), allocator); + namespace.createNamespace( + new CreateNamespaceRequest().id(Collections.singletonList(namespaceName))); + CreateTableResponse response = + namespace.createTable( + new CreateTableRequest().id(Arrays.asList(namespaceName, tableName)), + createNamespaceTableData(allocator)); + return response.getLocation(); + } + } + private SparkWrite.SparkWriteBuilder createBuilder(String datasetUri) { LanceSparkWriteOptions writeOptions = LanceSparkWriteOptions.from(datasetUri); return new SparkWrite.SparkWriteBuilder( @@ -176,6 +225,28 @@ private SparkWrite createWrite(String datasetUri) { return (SparkWrite) builder.build(); } + private SparkWrite createNamespaceInsertWrite( + String datasetUri, List tableId, int parallelism) { + LanceSparkWriteOptions writeOptions = + LanceSparkWriteOptions.builder() + .datasetUri(datasetUri) + .useNamespaceInsert(true) + .namespaceInsertParallelism(parallelism) + .build(); + SparkWrite.SparkWriteBuilder builder = + new SparkWrite.SparkWriteBuilder( + SPARK_SCHEMA, + writeOptions, + Collections.emptyMap(), + "dir", + Collections.singletonMap("root", tempDir.toString()), + tableId, + false, + null, + Collections.emptyMap()); + return (SparkWrite) builder.build(); + } + @Test public void testRequiredDistributionWithMemWalSharding(TestInfo testInfo) { String datasetUri = createIdentityShardedDataset(testInfo.getTestMethod().get().getName()); @@ -197,6 +268,21 @@ public void testRequiredDistributionWithoutSharding(TestInfo testInfo) { assertFalse(dist instanceof 
ClusteredDistribution); } + @Test + public void testNamespaceInsertParallelismWithoutShardingUsesClusteredDistribution( + TestInfo testInfo) throws IOException { + String namespaceName = "workspace"; + String tableName = testInfo.getTestMethod().get().getName(); + String datasetUri = createNamespaceTable(namespaceName, tableName); + SparkWrite write = + createNamespaceInsertWrite(datasetUri, Arrays.asList(namespaceName, tableName), 3); + + Distribution dist = write.requiredDistribution(); + assertInstanceOf(ClusteredDistribution.class, dist); + assertEquals(1, ((ClusteredDistribution) dist).clustering().length); + assertEquals(3, write.requiredNumPartitions()); + } + @Test public void testRequiredOrderingWithMemWalSharding(TestInfo testInfo) { String datasetUri = createIdentityShardedDataset(testInfo.getTestMethod().get().getName());