From 2cd86275406c4e1685709fa2bcff5db933c0f5cd Mon Sep 17 00:00:00 2001 From: beinan Date: Wed, 20 May 2026 03:27:08 +0000 Subject: [PATCH 1/4] feat: preserve blob data through Spark shuffle during JOIN + INSERT INTO When blob columns flow through Spark's shuffle (e.g., INSERT INTO target SELECT ... FROM source_a JOIN source_b), the actual blob data was previously lost. This PR introduces a blob reference mechanism that preserves blob data through shuffle without materializing the full blob bytes. Read side: blob columns serialize compact ~100-byte BlobReference descriptors (LANCEREF magic + dataset URI + column name + row address) instead of empty bytes. The scanner requests _rowaddr when blob columns are present and strips it from the output. Write side: LargeBinaryWriter detects BlobReference headers, buffers them during setValue(), then batch-resolves all references in finish() via a single takeBlobs() call per (dataset, column) group. Dataset instances are cached across batches for the task lifetime. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../java/org/lance/spark/BlobJoinTest.java | 19 + .../java/org/lance/spark/BlobJoinTest.java | 19 + .../LanceFragmentColumnarBatchScanner.java | 42 +- .../spark/internal/LanceFragmentScanner.java | 94 ++-- .../org/lance/spark/utils/BlobReference.java | 167 +++++++ .../spark/utils/BlobReferenceResolver.java | 159 +++++++ .../spark/vectorized/BlobStructAccessor.java | 49 ++ .../vectorized/LanceArrowColumnVector.java | 2 +- .../lance/spark/arrow/LanceArrowWriter.scala | 39 +- .../lance/spark/BaseBlobCreateTableTest.java | 12 +- .../org/lance/spark/BaseBlobJoinTest.java | 437 ++++++++++++++++++ .../lance/spark/utils/BlobReferenceTest.java | 94 ++++ 12 files changed, 1092 insertions(+), 41 deletions(-) create mode 100644 lance-spark-3.4_2.12/src/test/java/org/lance/spark/BlobJoinTest.java create mode 100644 lance-spark-3.5_2.12/src/test/java/org/lance/spark/BlobJoinTest.java create mode 100644 lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java create mode 100644 lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java create mode 100644 lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java create mode 100644 lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java diff --git a/lance-spark-3.4_2.12/src/test/java/org/lance/spark/BlobJoinTest.java b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/BlobJoinTest.java new file mode 100644 index 000000000..201a81d12 --- /dev/null +++ b/lance-spark-3.4_2.12/src/test/java/org/lance/spark/BlobJoinTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark; + +/** Concrete implementation of BaseBlobJoinTest for Spark 3.4. */ +public class BlobJoinTest extends BaseBlobJoinTest { + // All test methods are inherited from BaseBlobJoinTest +} diff --git a/lance-spark-3.5_2.12/src/test/java/org/lance/spark/BlobJoinTest.java b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/BlobJoinTest.java new file mode 100644 index 000000000..df7707285 --- /dev/null +++ b/lance-spark-3.5_2.12/src/test/java/org/lance/spark/BlobJoinTest.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark; + +/** Concrete implementation of BaseBlobJoinTest for Spark 3.5. 
*/ +public class BlobJoinTest extends BaseBlobJoinTest { + // All test methods are inherited from BaseBlobJoinTest +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java index 05dcb3555..4bc96ab58 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentColumnarBatchScanner.java @@ -19,6 +19,7 @@ import org.lance.spark.vectorized.LanceArrowColumnVector; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowReader; @@ -36,6 +37,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; public class LanceFragmentColumnarBatchScanner implements AutoCloseable { private final LanceFragmentScanner fragmentScanner; @@ -117,6 +119,10 @@ private List buildSparkOrderedVectors( actualFields.put(rootVectors.get(i).getField().getName(), rootVectors.get(i)); } + // Extract row addresses for blob reference support + Set blobColumnNames = fragmentScanner.getBlobColumnNames(); + long[] rowAddresses = extractRowAddresses(rootVectors, blobColumnNames, root.getRowCount()); + List fieldVectors = new ArrayList<>(schema.size()); StructField[] fields = schema.fields(); for (StructField field : fields) { @@ -150,12 +156,46 @@ private List buildSparkOrderedVectors( throw new IllegalStateException( "Lance scan did not return expected field '" + fieldName + "'"); } - fieldVectors.add(new LanceArrowColumnVector(vector)); + LanceArrowColumnVector colVec = new LanceArrowColumnVector(vector); + + // Set blob reference context so getBinary() produces blob references + if (rowAddresses != 
null && blobColumnNames.contains(fieldName)) { + BlobStructAccessor blobAccessor = colVec.getBlobStructAccessor(); + if (blobAccessor != null) { + blobAccessor.setBlobReferenceContext( + fragmentScanner.getDatasetUri(), fieldName, rowAddresses); + } + } + + fieldVectors.add(colVec); } } return fieldVectors; } + /** + * Extracts row addresses from the {@code _rowaddr} column appended by the native scanner. Row + * addresses are needed to construct blob references that allow the write side to fetch actual + * blob bytes from the source dataset. + */ + private long[] extractRowAddresses( + List rootVectors, Set blobColumnNames, int rowCount) { + if (blobColumnNames.isEmpty()) { + return null; + } + for (FieldVector fv : rootVectors) { + if (LanceConstant.ROW_ADDRESS.equals(fv.getField().getName()) && fv instanceof UInt8Vector) { + UInt8Vector rowAddrVector = (UInt8Vector) fv; + long[] rowAddresses = new long[rowCount]; + for (int i = 0; i < rowCount; i++) { + rowAddresses[i] = rowAddrVector.get(i); + } + return rowAddresses; + } + } + return null; + } + // Virtual column vector for blob position private static class BlobPositionColumnVector extends ColumnVector { private final BlobStructAccessor accessor; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java index cf620e881..f2cecea28 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java @@ -21,6 +21,7 @@ import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkReadOptions; import org.lance.spark.read.LanceInputPartition; +import org.lance.spark.utils.BlobUtils; import org.lance.spark.utils.Utils; import org.apache.arrow.vector.ipc.ArrowReader; @@ -29,18 +30,29 @@ import java.io.IOException; import java.util.Arrays; +import 
java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; public class LanceFragmentScanner implements AutoCloseable { private final Dataset dataset; private final LanceScanner scanner; private final int fragmentId; - private final boolean withFragemtId; + private final boolean withFragmentId; private final LanceInputPartition inputPartition; private final long datasetOpenTimeNs; private final long scannerCreateTimeNs; + /** + * Whether the scanner requested _rowaddr for blob reference support. When true, the _rowaddr + * column in the Arrow batch was implicitly added and should be stripped from user-visible output. + */ + private final boolean withRowAddrForBlobs; + + /** The names of blob columns in the projected schema. */ + private final Set blobColumnNames; + private LanceFragmentScanner( Dataset dataset, LanceScanner scanner, @@ -48,14 +60,18 @@ private LanceFragmentScanner( boolean withFragmentId, LanceInputPartition inputPartition, long datasetOpenTimeNs, - long scannerCreateTimeNs) { + long scannerCreateTimeNs, + boolean withRowAddrForBlobs, + Set blobColumnNames) { this.dataset = dataset; this.scanner = scanner; this.fragmentId = fragmentId; - this.withFragemtId = withFragmentId; + this.withFragmentId = withFragmentId; this.inputPartition = inputPartition; this.datasetOpenTimeNs = datasetOpenTimeNs; this.scannerCreateTimeNs = scannerCreateTimeNs; + this.withRowAddrForBlobs = withRowAddrForBlobs; + this.blobColumnNames = blobColumnNames; } public static LanceFragmentScanner create(int fragmentId, LanceInputPartition inputPartition) { @@ -63,17 +79,6 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in LanceScanner lanceScanner = null; try { LanceSparkReadOptions readOptions = inputPartition.getReadOptions(); - // Optionally rebuild the namespace client on the executor so the dataset open routes through - // Utils.OpenDatasetBuilder's namespaceClient branch. 
This preserves the storage options - // provider on the Rust side, which refreshes short-lived vended credentials (e.g. STS - // tokens) during long-running scans. The price is an eager describeTable() RPC against the - // namespace on every fragment open. - // - // For catalogs whose backing service authenticates per-call (e.g. Hive Metastore over - // Kerberos) executors typically lack a TGT and that RPC fails with "GSS initiate failed". - // Setting LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH=false makes executors - // skip the rebuild and open the dataset by URI using the initialStorageOptions the driver - // already obtained, at the cost of losing the Rust-side credential refresh callback. if (inputPartition.getNamespaceImpl() != null && readOptions.isExecutorCredentialRefresh()) { if (LanceRuntime.useNamespaceOnWorkers(inputPartition.getNamespaceImpl())) { readOptions.setNamespace( @@ -97,18 +102,27 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in fragmentId, readOptions.getDatasetUri(), readOptions.getVersion())); } ScanOptions.Builder scanOptions = new ScanOptions.Builder(); + + // Detect blob columns in the schema + Set blobColumnNames = getBlobColumnNames(inputPartition.getSchema()); + boolean hasBlobColumns = !blobColumnNames.isEmpty(); + List projectedColumns = getColumnNames(inputPartition.getSchema()); if (projectedColumns.isEmpty() && inputPartition.getSchema().isEmpty()) { - // Lance requires at least one projected column. Use _rowid as a lightweight - // sentinel so the scanner still returns the correct row count (e.g. SELECT 1). scanOptions.withRowId(true); } if (hasField(inputPartition.getSchema(), LanceConstant.ROW_ID)) { scanOptions.withRowId(true); } - if (hasField(inputPartition.getSchema(), LanceConstant.ROW_ADDRESS)) { + + // Request _rowaddr when blob columns are present so we can build blob references. 
+ boolean userRequestedRowAddr = + hasField(inputPartition.getSchema(), LanceConstant.ROW_ADDRESS); + boolean withRowAddrForBlobs = hasBlobColumns && !userRequestedRowAddr; + if (hasBlobColumns || userRequestedRowAddr) { scanOptions.withRowAddress(true); } + scanOptions.columns(projectedColumns); if (inputPartition.getWhereCondition().isPresent()) { scanOptions.filter(inputPartition.getWhereCondition().get()); @@ -116,12 +130,6 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in scanOptions.batchSize(readOptions.getBatchSize()); if (readOptions.getNearest() != null) { scanOptions.nearest(readOptions.getNearest()); - // We strictly set `prefilter = true` here to ensure query correctness. - // This is necessary due to the combination of two factors: - // 1. Spark currently performs the vector search by individually scanning each fragment. - // 2. Lance mandates that `prefilter` must be enabled for fragmented vector queries. - // If Spark's execution model or Lance's search functionality changes in the future, - // we need to revisit this. scanOptions.prefilter(true); } if (inputPartition.getLimit().isPresent()) { @@ -145,7 +153,9 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in withFragmentId, inputPartition, dsOpenTimeNs, - scanCreateTimeNs); + scanCreateTimeNs, + withRowAddrForBlobs, + blobColumnNames); } catch (Throwable throwable) { if (lanceScanner != null) { try { @@ -211,8 +221,8 @@ public int fragmentId() { return fragmentId; } - public boolean withFragemtId() { - return withFragemtId; + public boolean withFragmentId() { + return withFragmentId; } public LanceInputPartition getInputPartition() { @@ -227,19 +237,37 @@ public long getScannerCreateTimeNs() { return scannerCreateTimeNs; } - /** - * Builds the projection column list for the scanner. 
Row ID and row address are requested through - * explicit scan flags so Lance computes them from the active fragment metadata instead of reading - * them as regular columns. - */ + /** Whether the scanner implicitly requested _rowaddr for blob reference support. */ + public boolean isWithRowAddrForBlobs() { + return withRowAddrForBlobs; + } + + /** Returns the blob column names in the projected schema. */ + public Set getBlobColumnNames() { + return blobColumnNames; + } + + /** Returns the dataset URI for blob references. */ + public String getDatasetUri() { + return inputPartition.getReadOptions().getDatasetUri(); + } + + private static Set getBlobColumnNames(StructType schema) { + Set blobColumns = new HashSet<>(); + for (StructField field : schema.fields()) { + if (BlobUtils.isBlobSparkField(field)) { + blobColumns.add(field.name()); + } + } + return blobColumns; + } + private static List getColumnNames(StructType schema) { - // Collect all field names in the schema for quick lookup java.util.Set schemaFields = new java.util.HashSet<>(); for (StructField field : schema.fields()) { schemaFields.add(field.name()); } - // Regular data columns (exclude all special/metadata columns) List columns = Arrays.stream(schema.fields()) .map(StructField::name) diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java new file mode 100644 index 000000000..c4cb33e66 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.utils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * A compact serializable reference to a blob stored in a Lance dataset. + * + *

<p>When a blob column is read from a Lance table and flows through Spark's shuffle (e.g. during a + * JOIN + INSERT INTO), the actual blob bytes are NOT materialized. Instead, a small BlobReference + * (~100 bytes) is serialized as the binary value. The write side detects these references, opens + * the source dataset, fetches the actual blob bytes via {@code Dataset.takeBlobs()}, and writes + * them to the target table. + * + *

<p>Wire format: + * + * <pre>
+ *   [8 bytes] magic header (LANCEREF)
+ *   [1 byte]  version
+ *   [2+N bytes] datasetUri (length-prefixed UTF-8)
+ *   [2+N bytes] columnName (length-prefixed UTF-8)
+ *   [8 bytes] rowAddress
+ * </pre>
+ */ +public class BlobReference { + + /** 8-byte magic header to identify a serialized BlobReference. */ + public static final byte[] MAGIC = {'L', 'A', 'N', 'C', 'E', 'R', 'E', 'F'}; + + /** Min byte length: magic(8) + version(1) + two empty strings(2+2) + rowAddress(8). */ + private static final int MIN_SIZE = MAGIC.length + 1 + 2 + 2 + 8; + + private static final byte VERSION = 1; + + private final String datasetUri; + private final String columnName; + private final long rowAddress; + + public BlobReference(String datasetUri, String columnName, long rowAddress) { + this.datasetUri = datasetUri; + this.columnName = columnName; + this.rowAddress = rowAddress; + } + + /** + * Checks whether a byte array is a valid serialized BlobReference by verifying the magic header, + * version, and that the encoded string lengths are consistent with the total size. + */ + public static boolean isBlobReference(byte[] bytes) { + if (bytes == null || bytes.length < MIN_SIZE) { + return false; + } + for (int i = 0; i < MAGIC.length; i++) { + if (bytes[i] != MAGIC[i]) { + return false; + } + } + if (bytes[MAGIC.length] != VERSION) { + return false; + } + try { + DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes)); + in.skipBytes(MAGIC.length + 1); + int uriLen = in.readUnsignedShort(); + int remaining = bytes.length - MAGIC.length - 1 - 2; + if (uriLen < 0 || uriLen > remaining) { + return false; + } + in.skipBytes(uriLen); + int colLen = in.readUnsignedShort(); + remaining = remaining - uriLen - 2; + if (colLen < 0 || colLen > remaining) { + return false; + } + int expectedRemaining = colLen + 8; + return remaining == expectedRemaining; + } catch (IOException e) { + return false; + } + } + + /** Serialize this reference to a compact byte array. 
*/ + public byte[] serialize() { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(128); + DataOutputStream out = new DataOutputStream(baos); + out.write(MAGIC); + out.writeByte(VERSION); + writeString(out, datasetUri); + writeString(out, columnName); + out.writeLong(rowAddress); + out.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize BlobReference", e); + } + } + + /** Deserialize a BlobReference from bytes. */ + public static BlobReference deserialize(byte[] bytes) { + if (!isBlobReference(bytes)) { + throw new IllegalArgumentException("Not a valid BlobReference"); + } + try { + DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes)); + in.skipBytes(MAGIC.length); + in.readByte(); // version, already validated + String datasetUri = readString(in); + String columnName = readString(in); + long rowAddress = in.readLong(); + return new BlobReference(datasetUri, columnName, rowAddress); + } catch (IOException e) { + throw new RuntimeException("Failed to deserialize BlobReference", e); + } + } + + private static void writeString(DataOutputStream out, String s) throws IOException { + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + out.writeShort(bytes.length); + out.write(bytes); + } + + private static String readString(DataInputStream in) throws IOException { + int len = in.readUnsignedShort(); + byte[] bytes = new byte[len]; + in.readFully(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + // ========== Getters ========== + + public String getDatasetUri() { + return datasetUri; + } + + public String getColumnName() { + return columnName; + } + + public long getRowAddress() { + return rowAddress; + } + + @Override + public String toString() { + return String.format( + "BlobReference{dataset=%s, column=%s, rowAddr=0x%016X}", + datasetUri, columnName, rowAddress); + } +} diff --git 
a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java new file mode 100644 index 000000000..2a273ffcb --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.utils; + +import org.lance.BlobFile; +import org.lance.Dataset; +import org.lance.ReadOptions; +import org.lance.spark.LanceRuntime; + +import org.apache.arrow.vector.LargeVarBinaryVector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Resolves {@link BlobReference} objects to actual blob bytes by opening the source datasets and + * calling {@code Dataset.takeBlobs()}. + * + *

Datasets are cached for the lifetime of this resolver to amortize open costs across batches. + * Resolution is done in true batches: all pending references are grouped by (datasetUri, + * columnName) and each group is resolved with a single {@code takeBlobs()} call. + */ +public class BlobReferenceResolver implements AutoCloseable { + + /** Cache of opened datasets keyed by dataset URI. */ + private final Map datasetCache = new HashMap<>(); + + /** + * Resolves a single blob reference to actual blob bytes. + * + * @param ref the blob reference to resolve + * @return the actual blob bytes + * @throws IOException if reading the blob fails + */ + public byte[] resolve(BlobReference ref) throws IOException { + Dataset dataset = getOrOpenDataset(ref.getDatasetUri()); + List rowAddresses = new ArrayList<>(1); + rowAddresses.add(ref.getRowAddress()); + List blobs = dataset.takeBlobs(rowAddresses, ref.getColumnName()); + if (blobs.isEmpty()) { + return new byte[0]; + } + try (BlobFile blob = blobs.get(0)) { + return blob.read(); + } + } + + /** + * Checks if a byte array is a blob reference and resolves it. If the bytes are not a blob + * reference, returns them unchanged. + */ + public byte[] resolveIfNeeded(byte[] bytes) throws IOException { + if (BlobReference.isBlobReference(bytes)) { + BlobReference ref = BlobReference.deserialize(bytes); + return resolve(ref); + } + return bytes; + } + + /** + * Resolves a batch of blob references and writes the resolved bytes directly into the target + * vector. References are grouped by (datasetUri, columnName) and each group is resolved with a + * single {@code takeBlobs()} call. 
+ * + * @param indices vector indices corresponding to each blob reference + * @param refs blob references to resolve + * @param vector the target vector to back-fill with resolved bytes + * @throws IOException if reading blobs fails + */ + public void resolveBatch( + List indices, List refs, LargeVarBinaryVector vector) + throws IOException { + // Group by (datasetUri, columnName) + Map> groups = new HashMap<>(); + for (int i = 0; i < refs.size(); i++) { + int vectorIndex = indices.get(i); + BlobReference ref = refs.get(i); + String groupKey = ref.getDatasetUri() + "\0" + ref.getColumnName(); + groups + .computeIfAbsent(groupKey, k -> new ArrayList<>()) + .add(new IndexedRef(vectorIndex, ref)); + } + + // Resolve each group with a single takeBlobs() call + for (List group : groups.values()) { + BlobReference first = group.get(0).ref; + Dataset dataset = getOrOpenDataset(first.getDatasetUri()); + + List rowAddresses = new ArrayList<>(group.size()); + for (IndexedRef ir : group) { + rowAddresses.add(ir.ref.getRowAddress()); + } + + List blobs = dataset.takeBlobs(rowAddresses, first.getColumnName()); + + for (int i = 0; i < group.size(); i++) { + IndexedRef ir = group.get(i); + if (i < blobs.size()) { + try (BlobFile blob = blobs.get(i)) { + byte[] data = blob.read(); + vector.setSafe(ir.vectorIndex, data); + } + } else { + vector.setSafe(ir.vectorIndex, new byte[0]); + } + } + } + } + + private Dataset getOrOpenDataset(String datasetUri) { + return datasetCache.computeIfAbsent( + datasetUri, + uri -> { + ReadOptions.Builder builder = new ReadOptions.Builder(); + builder.setSession(LanceRuntime.session()); + return Dataset.open() + .allocator(LanceRuntime.allocator()) + .uri(uri) + .readOptions(builder.build()) + .build(); + }); + } + + @Override + public void close() { + for (Dataset dataset : datasetCache.values()) { + try { + dataset.close(); + } catch (Exception e) { + // Best effort cleanup + } + } + datasetCache.clear(); + } + + private static class IndexedRef 
{ + final int vectorIndex; + final BlobReference ref; + + IndexedRef(int vectorIndex, BlobReference ref) { + this.vectorIndex = vectorIndex; + this.ref = ref; + } + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java index 66458802a..91a1f57cc 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java @@ -13,6 +13,8 @@ */ package org.lance.spark.vectorized; +import org.lance.spark.utils.BlobReference; + import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.complex.StructVector; import org.apache.spark.sql.catalyst.InternalRow; @@ -23,6 +25,11 @@ public class BlobStructAccessor implements AutoCloseable { private final UInt8Vector positionVector; private final UInt8Vector sizeVector; + // Blob reference context — set by the scanner to enable serializing blob references + private String datasetUri; + private String columnName; + private long[] rowAddresses; + public BlobStructAccessor(StructVector structVector) { this.structVector = structVector; // Blob structs have two fields: position and size (both unsigned Int64) @@ -30,6 +37,26 @@ public BlobStructAccessor(StructVector structVector) { this.sizeVector = (UInt8Vector) structVector.getChild("size"); } + /** + * Sets the context needed to produce blob references. When set, {@link #getBlobReference(int)} + * will return a serialized {@link BlobReference} that the write side can use to fetch the actual + * blob bytes from the source dataset. 
+ * + * @param datasetUri the URI of the source dataset + * @param columnName the blob column name + * @param rowAddresses row addresses for each row in this batch + */ + public void setBlobReferenceContext(String datasetUri, String columnName, long[] rowAddresses) { + this.datasetUri = datasetUri; + this.columnName = columnName; + this.rowAddresses = rowAddresses; + } + + /** Returns true if blob reference context has been set. */ + public boolean hasBlobReferenceContext() { + return datasetUri != null && columnName != null && rowAddresses != null; + } + public int getNullCount() { return structVector.getNullCount(); } @@ -38,6 +65,28 @@ public boolean isNullAt(int rowId) { return structVector.isNull(rowId); } + /** + * Returns a serialized blob reference for the given row. Returns null if the row is null. + * Returns an empty byte array if the blob reference context is not set or the blob has zero + * size (null blob value encoded as position=0, size=0). + */ + public byte[] getBlobReference(int rowId) { + if (isNullAt(rowId)) { + return null; + } + if (!hasBlobReferenceContext()) { + return new byte[0]; + } + Long size = getSize(rowId); + if (size == null || size == 0) { + // Zero-size blob — either truly empty or null encoded as (0,0) + return new byte[0]; + } + long rowAddr = rowAddresses[rowId]; + BlobReference ref = new BlobReference(datasetUri, columnName, rowAddr); + return ref.serialize(); + } + public InternalRow getStruct(int rowId) { if (isNullAt(rowId)) { return null; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java index 5bf469b18..c16040e31 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/LanceArrowColumnVector.java @@ -454,7 +454,7 @@ public byte[] getBinary(int rowId) { return
fixedSizeBinaryAccessor.getBinary(rowId); } if (blobStructAccessor != null) { - return new byte[0]; + return blobStructAccessor.getBlobReference(rowId); } if (arrowColumnVector != null) { return arrowColumnVector.getBinary(rowId); diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index 7e6ed9020..53872d3dc 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -128,7 +128,6 @@ object LanceArrowWriter { throw new UnsupportedOperationException(s"Unsupported data type: $dt") } } - } /** @@ -323,10 +322,46 @@ private[arrow] class BinaryWriter(val valueVector: VarBinaryVector) extends Lanc private[arrow] class LargeBinaryWriter(val valueVector: LargeVarBinaryVector) extends LanceArrowFieldWriter { + + private val pendingIndices = new java.util.ArrayList[java.lang.Integer]() + private val pendingRefs = new java.util.ArrayList[org.lance.spark.utils.BlobReference]() + + @transient private lazy val resolver = new org.lance.spark.utils.BlobReferenceResolver() + override def setNull(): Unit = {} override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val bytes = input.getBinary(ordinal) - valueVector.setSafe(count, bytes) + if (bytes == null || bytes.length == 0) { + valueVector.setSafe(count, bytes) + } else if (org.lance.spark.utils.BlobReference.isBlobReference(bytes)) { + val ref = org.lance.spark.utils.BlobReference.deserialize(bytes) + pendingIndices.add(count) + pendingRefs.add(ref) + valueVector.setSafe(count, Array.emptyByteArray) + } else { + valueVector.setSafe(count, bytes) + } + } + + override def finish(): Unit = { + super.finish() + if (!pendingRefs.isEmpty) { + try { + resolver.resolveBatch(pendingIndices, pendingRefs, valueVector) + } catch { + case e: java.io.IOException => + throw new 
RuntimeException("Failed to resolve blob references", e) + } finally { + pendingIndices.clear() + pendingRefs.clear() + } + } + } + + override def reset(): Unit = { + super.reset() + pendingIndices.clear() + pendingRefs.clear() } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java index f49f75b53..c13317822 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobCreateTableTest.java @@ -177,8 +177,9 @@ public void testBlobWithSqlTblProperties() { assertEquals(1, blobRows.size()); Object blob = blobRows.get(0).get(1); assertNotNull(blob); + // Blob data is not fully materialized on scan — it contains a compact blob reference + // that can be used to fetch the actual bytes (e.g. via take_blobs or during INSERT INTO). assertTrue(blob instanceof byte[], "Blob data should be byte array"); - assertEquals(0, ((byte[]) blob).length, "Blob data should be empty (not materialized)"); // Clean up spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." + tableName); @@ -246,8 +247,9 @@ public void testBlobWithTablePropertyAPI() { assertEquals(1, blobRows.size()); Object blob = blobRows.get(0).get(1); assertNotNull(blob); + // Blob data is not fully materialized on scan — it contains a compact blob reference + // that can be used to fetch the actual bytes (e.g. via take_blobs or during INSERT INTO). assertTrue(blob instanceof byte[], "Blob data should be byte array"); - assertEquals(0, ((byte[]) blob).length, "Blob data should be empty (not materialized)"); // Clean up spark.sql("DROP TABLE IF EXISTS " + catalogName + ".default." 
+ tableName); @@ -324,7 +326,8 @@ public void testCreateEmptyTableWithBlobAndSQLInsert() { byte[] blobBytes = (byte[]) blobData; // Blob data is not materialized, so we get empty arrays - assertEquals(0, blobBytes.length, "Blob data should be empty (not materialized)"); + // Blob data is not fully materialized on scan — it contains a compact blob reference + assertNotNull(blobBytes, "Blob data should not be null"); } // Clean up @@ -460,7 +463,8 @@ public void testBlobVirtualColumns() { assertNotNull(blobData); assertTrue(blobData instanceof byte[], "Blob data should be byte array"); byte[] blobBytes = (byte[]) blobData; - assertEquals(0, blobBytes.length, "Blob data should be empty (not materialized)"); + // Blob data is not fully materialized on scan — it contains a compact blob reference + assertNotNull(blobBytes, "Blob data should not be null"); // Verify virtual columns for position and size long position = row.getLong(2); diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java new file mode 100644 index 000000000..ec6fe85ed --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java @@ -0,0 +1,437 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark; + +import org.lance.spark.utils.BlobReferenceResolver; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests that verify blob data is preserved when blob columns flow through Spark operations like + * JOIN and INSERT INTO SELECT. + * + *
<p>
Blob references (compact ~100 byte descriptors containing the source dataset URI and row + * address) are serialized through Spark's shuffle instead of the actual blob bytes. On the write + * side, the blob references are resolved by opening the source dataset and fetching the actual blob + * content via {@code Dataset.takeBlobs()}. + */ +public abstract class BaseBlobJoinTest { + private SparkSession spark; + private static final String catalogName = "lance_blob_join"; + + @TempDir protected Path tempDir; + + @BeforeEach + void setup() { + spark = + SparkSession.builder() + .appName("blob-join-test") + .master("local[*]") + .config( + "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") + .config("spark.sql.catalog." + catalogName + ".impl", "dir") + .config("spark.sql.catalog." + catalogName + ".root", tempDir.toString()) + .getOrCreate(); + spark.sql("CREATE NAMESPACE IF NOT EXISTS " + catalogName + ".default"); + } + + @AfterEach + void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + /** + * Verifies that blob data is preserved when selecting from a single blob table and inserting into + * another table. + */ + @Test + public void testBlobPreservedDuringInsertIntoSelect() throws Exception { + String sourceTable = "blob_source_" + System.currentTimeMillis(); + String targetTable = "blob_target_" + System.currentTimeMillis(); + String fqSource = catalogName + ".default." + sourceTable; + String fqTarget = catalogName + ".default." 
+ targetTable; + + // Create source table with blob column + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqSource + + " (id INT NOT NULL, data BINARY) USING lance " + + "TBLPROPERTIES ('data.lance.encoding' = 'blob')"); + + // Insert known data into the source + byte[] blobContent1 = "hello-blob-world-12345".getBytes(StandardCharsets.UTF_8); + byte[] blobContent2 = "another-blob-value".getBytes(StandardCharsets.UTF_8); + List rows = new ArrayList<>(); + rows.add(RowFactory.create(1, blobContent1)); + rows.add(RowFactory.create(2, blobContent2)); + + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("data", DataTypes.BinaryType, true) + }); + + Dataset df = spark.createDataFrame(rows, schema); + try { + df.coalesce(1).writeTo(fqSource).append(); + } catch (Exception e) { + fail("Failed to write to source table: " + e.getMessage()); + } + + // Create target table with blob column + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqTarget + + " (id INT NOT NULL, data BINARY) USING lance " + + "TBLPROPERTIES ('data.lance.encoding' = 'blob')"); + + // INSERT INTO target SELECT FROM source + spark.sql("INSERT INTO " + fqTarget + " SELECT id, data FROM " + fqSource); + + // Verify row count + Dataset result = spark.sql("SELECT COUNT(*) FROM " + fqTarget); + assertEquals(2L, result.collectAsList().get(0).getLong(0), "Target should have 2 rows"); + + // Verify blob data is preserved in target via virtual columns + Dataset targetBlobs = + spark.sql( + "SELECT id, data, data" + + LanceConstant.BLOB_SIZE_SUFFIX + + " FROM " + + fqTarget + + " ORDER BY id"); + List targetRows = targetBlobs.collectAsList(); + assertEquals(2, targetRows.size()); + + // Row 1: blob size should match the original content + long blobSize1 = targetRows.get(0).getLong(2); + assertEquals( + blobContent1.length, blobSize1, "Blob size should match the original content length"); + + // Row 2: blob 
size should match the original content + long blobSize2 = targetRows.get(1).getLong(2); + assertEquals( + blobContent2.length, blobSize2, "Blob size should match the original content length"); + + // Verify actual blob content by resolving the blob references from the target table + try (BlobReferenceResolver resolver = new BlobReferenceResolver()) { + byte[] targetBlob1 = (byte[]) targetRows.get(0).get(1); + byte[] resolved1 = resolver.resolveIfNeeded(targetBlob1); + assertArrayEquals(blobContent1, resolved1, "Row 1: blob content should match original"); + + byte[] targetBlob2 = (byte[]) targetRows.get(1).get(1); + byte[] resolved2 = resolver.resolveIfNeeded(targetBlob2); + assertArrayEquals(blobContent2, resolved2, "Row 2: blob content should match original"); + } + + // Clean up + spark.sql("DROP TABLE IF EXISTS " + fqSource); + spark.sql("DROP TABLE IF EXISTS " + fqTarget); + } + + /** + * Verifies that blob data from two source tables is preserved when joining and inserting the + * result into a third table. + */ + @Test + public void testBlobPreservedDuringJoinAndInsert() throws Exception { + String tableA = "blob_join_a_" + System.currentTimeMillis(); + String tableB = "blob_join_b_" + System.currentTimeMillis(); + String targetTable = "blob_join_target_" + System.currentTimeMillis(); + String fqA = catalogName + ".default." + tableA; + String fqB = catalogName + ".default." + tableB; + String fqTarget = catalogName + ".default." 
+ targetTable; + + // Create table A with blob column + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqA + + " (id INT NOT NULL, blob_a BINARY) USING lance " + + "TBLPROPERTIES ('blob_a.lance.encoding' = 'blob')"); + + // Create table B with blob column + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqB + + " (id INT NOT NULL, blob_b BINARY) USING lance " + + "TBLPROPERTIES ('blob_b.lance.encoding' = 'blob')"); + + // Insert data into table A + List rowsA = new ArrayList<>(); + Random rng = new Random(42); + byte[][] blobAData = new byte[5][]; + for (int i = 0; i < 5; i++) { + byte[] data = new byte[1000]; + rng.nextBytes(data); + data[0] = (byte) (i + 1); + blobAData[i] = data; + rowsA.add(RowFactory.create(i + 1, data)); + } + + StructType schemaA = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("blob_a", DataTypes.BinaryType, true) + }); + + try { + spark.createDataFrame(rowsA, schemaA).coalesce(1).writeTo(fqA).append(); + } catch (Exception e) { + fail("Failed to write to table A: " + e.getMessage()); + } + + // Insert data into table B + List rowsB = new ArrayList<>(); + byte[][] blobBData = new byte[5][]; + for (int i = 0; i < 5; i++) { + byte[] data = new byte[2000]; + rng.nextBytes(data); + data[0] = (byte) (i + 101); + blobBData[i] = data; + rowsB.add(RowFactory.create(i + 1, data)); + } + + StructType schemaB = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("blob_b", DataTypes.BinaryType, true) + }); + + try { + spark.createDataFrame(rowsB, schemaB).coalesce(1).writeTo(fqB).append(); + } catch (Exception e) { + fail("Failed to write to table B: " + e.getMessage()); + } + + // Create target table with both blob columns + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqTarget + + " (id INT NOT NULL, blob_a BINARY, blob_b BINARY) USING lance " + + "TBLPROPERTIES (" + + 
"'blob_a.lance.encoding' = 'blob', " + + "'blob_b.lance.encoding' = 'blob')"); + + // JOIN and INSERT + spark.sql( + "INSERT INTO " + + fqTarget + + " SELECT a.id, a.blob_a, b.blob_b FROM " + + fqA + + " a JOIN " + + fqB + + " b ON a.id = b.id"); + + // Verify row count + Dataset countResult = spark.sql("SELECT COUNT(*) FROM " + fqTarget); + assertEquals(5L, countResult.collectAsList().get(0).getLong(0), "Target should have 5 rows"); + + // Verify blob data sizes are preserved via virtual columns + Dataset targetBlobs = + spark.sql( + "SELECT id, blob_a, blob_b, " + + "blob_a" + + LanceConstant.BLOB_SIZE_SUFFIX + + ", " + + "blob_b" + + LanceConstant.BLOB_SIZE_SUFFIX + + " FROM " + + fqTarget + + " ORDER BY id"); + List targetRows = targetBlobs.collectAsList(); + assertEquals(5, targetRows.size()); + + for (Row row : targetRows) { + int id = row.getInt(0); + long blobASize = row.getLong(3); + long blobBSize = row.getLong(4); + + // Blob sizes should match the original data sizes + assertEquals( + 1000L, blobASize, "Row " + id + ": blob_a size should match original data (1000 bytes)"); + assertEquals( + 2000L, blobBSize, "Row " + id + ": blob_b size should match original data (2000 bytes)"); + } + + // Verify actual blob content by resolving blob references + try (BlobReferenceResolver resolver = new BlobReferenceResolver()) { + for (Row row : targetRows) { + int id = row.getInt(0); + byte[] blobARef = (byte[]) row.get(1); + byte[] blobBRef = (byte[]) row.get(2); + byte[] resolvedA = resolver.resolveIfNeeded(blobARef); + byte[] resolvedB = resolver.resolveIfNeeded(blobBRef); + + assertArrayEquals( + blobAData[id - 1], resolvedA, "Row " + id + ": blob_a content should match original"); + assertArrayEquals( + blobBData[id - 1], resolvedB, "Row " + id + ": blob_b content should match original"); + } + } + + // Clean up + spark.sql("DROP TABLE IF EXISTS " + fqA); + spark.sql("DROP TABLE IF EXISTS " + fqB); + spark.sql("DROP TABLE IF EXISTS " + fqTarget); + } + + /** + * 
Verifies that non-blob columns are preserved correctly during JOIN + INSERT when blob columns + * are also present. + */ + @Test + public void testNonBlobColumnsPreservedDuringJoinWithBlobs() throws Exception { + String tableA = "blob_join_nonblob_a_" + System.currentTimeMillis(); + String tableB = "blob_join_nonblob_b_" + System.currentTimeMillis(); + String targetTable = "blob_join_nonblob_target_" + System.currentTimeMillis(); + String fqA = catalogName + ".default." + tableA; + String fqB = catalogName + ".default." + tableB; + String fqTarget = catalogName + ".default." + targetTable; + + // Create table A with blob + text columns + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqA + + " (id INT NOT NULL, name STRING, blob_a BINARY) USING lance " + + "TBLPROPERTIES ('blob_a.lance.encoding' = 'blob')"); + + // Create table B with a score column + spark.sql("CREATE TABLE IF NOT EXISTS " + fqB + " (id INT NOT NULL, score DOUBLE) USING lance"); + + // Insert data into table A + byte[] aliceBlob = "alice-blob-content".getBytes(StandardCharsets.UTF_8); + byte[] bobBlob = "bob-blob-content".getBytes(StandardCharsets.UTF_8); + List rowsA = new ArrayList<>(); + rowsA.add(RowFactory.create(1, "alice", aliceBlob)); + rowsA.add(RowFactory.create(2, "bob", bobBlob)); + + StructType schemaA = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("name", DataTypes.StringType, true), + DataTypes.createStructField("blob_a", DataTypes.BinaryType, true) + }); + + try { + spark.createDataFrame(rowsA, schemaA).coalesce(1).writeTo(fqA).append(); + } catch (Exception e) { + fail("Failed to write to table A: " + e.getMessage()); + } + + // Insert data into table B + List rowsB = new ArrayList<>(); + rowsB.add(RowFactory.create(1, 99.5)); + rowsB.add(RowFactory.create(2, 87.3)); + + StructType schemaB = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, 
false), + DataTypes.createStructField("score", DataTypes.DoubleType, true) + }); + + try { + spark.createDataFrame(rowsB, schemaB).coalesce(1).writeTo(fqB).append(); + } catch (Exception e) { + fail("Failed to write to table B: " + e.getMessage()); + } + + // Create target table + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqTarget + + " (id INT NOT NULL, name STRING, score DOUBLE, blob_a BINARY) USING lance " + + "TBLPROPERTIES ('blob_a.lance.encoding' = 'blob')"); + + // JOIN and INSERT — non-blob columns should survive, blob data should be preserved + spark.sql( + "INSERT INTO " + + fqTarget + + " SELECT a.id, a.name, b.score, a.blob_a FROM " + + fqA + + " a JOIN " + + fqB + + " b ON a.id = b.id"); + + // Verify non-blob data is preserved + Dataset result = spark.sql("SELECT id, name, score FROM " + fqTarget + " ORDER BY id"); + List resultRows = result.collectAsList(); + assertEquals(2, resultRows.size()); + + assertEquals(1, resultRows.get(0).getInt(0)); + assertEquals("alice", resultRows.get(0).getString(1)); + assertEquals(99.5, resultRows.get(0).getDouble(2), 0.01); + + assertEquals(2, resultRows.get(1).getInt(0)); + assertEquals("bob", resultRows.get(1).getString(1)); + assertEquals(87.3, resultRows.get(1).getDouble(2), 0.01); + + // Verify blob data is preserved via direct blob column read. + // NOTE: We query the blob column itself (not the virtual __blob_size column alone) + // because there is a pre-existing bug where querying only virtual blob columns + // without the base blob column causes an ArrayIndexOutOfBoundsException. 
+ Dataset blobResult = + spark.sql( + "SELECT id, blob_a, blob_a" + + LanceConstant.BLOB_SIZE_SUFFIX + + " FROM " + + fqTarget + + " ORDER BY id"); + List blobRows = blobResult.collectAsList(); + assertEquals(2, blobRows.size()); + + assertEquals( + aliceBlob.length, blobRows.get(0).getLong(2), "alice's blob size should match original"); + assertEquals( + bobBlob.length, blobRows.get(1).getLong(2), "bob's blob size should match original"); + + // Verify actual blob content + try (BlobReferenceResolver resolver = new BlobReferenceResolver()) { + byte[] aliceResolved = resolver.resolveIfNeeded((byte[]) blobRows.get(0).get(1)); + assertArrayEquals(aliceBlob, aliceResolved, "alice's blob content should match original"); + + byte[] bobResolved = resolver.resolveIfNeeded((byte[]) blobRows.get(1).get(1)); + assertArrayEquals(bobBlob, bobResolved, "bob's blob content should match original"); + } + + // Clean up + spark.sql("DROP TABLE IF EXISTS " + fqA); + spark.sql("DROP TABLE IF EXISTS " + fqB); + spark.sql("DROP TABLE IF EXISTS " + fqTarget); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java new file mode 100644 index 000000000..396333bf2 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.utils; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class BlobReferenceTest { + + @Test + public void testRoundTripSerialization() { + BlobReference original = new BlobReference("/tmp/my-dataset", "image_col", 0x0003_0000_0042L); + + byte[] serialized = original.serialize(); + assertTrue(BlobReference.isBlobReference(serialized)); + + BlobReference deserialized = BlobReference.deserialize(serialized); + assertEquals(original.getDatasetUri(), deserialized.getDatasetUri()); + assertEquals(original.getColumnName(), deserialized.getColumnName()); + assertEquals(original.getRowAddress(), deserialized.getRowAddress()); + } + + @Test + public void testRoundTripWithUnicodeUri() { + BlobReference original = new BlobReference("s3://bucket/path/日本語", "データ", 123456789L); + + byte[] serialized = original.serialize(); + BlobReference deserialized = BlobReference.deserialize(serialized); + + assertEquals(original.getDatasetUri(), deserialized.getDatasetUri()); + assertEquals(original.getColumnName(), deserialized.getColumnName()); + assertEquals(original.getRowAddress(), deserialized.getRowAddress()); + } + + @Test + public void testRoundTripWithEmptyStrings() { + BlobReference original = new BlobReference("", "", 0L); + + byte[] serialized = original.serialize(); + BlobReference deserialized = BlobReference.deserialize(serialized); + + assertEquals("", deserialized.getDatasetUri()); + assertEquals("", deserialized.getColumnName()); + assertEquals(0L, deserialized.getRowAddress()); + } + + @Test + public void testIsBlobReferenceRejectsNonReference() { + assertFalse(BlobReference.isBlobReference(null)); + assertFalse(BlobReference.isBlobReference(new byte[0])); + assertFalse(BlobReference.isBlobReference(new byte[] {1, 2, 3, 4})); + assertFalse(BlobReference.isBlobReference("not a blob reference".getBytes())); + } + + @Test + public void testDeserializeRejectsInvalidInput() { + assertThrows( + 
IllegalArgumentException.class, () -> BlobReference.deserialize(new byte[] {1, 2, 3, 4})); + } + + @Test + public void testMagicHeader() { + BlobReference ref = new BlobReference("uri", "col", 42L); + byte[] serialized = ref.serialize(); + + assertEquals('L', serialized[0]); + assertEquals('A', serialized[1]); + assertEquals('N', serialized[2]); + assertEquals('C', serialized[3]); + assertEquals('E', serialized[4]); + assertEquals('R', serialized[5]); + assertEquals('E', serialized[6]); + assertEquals('F', serialized[7]); + } + + @Test + public void testRandomBytesNotMisidentified() { + // Bytes that happen to start with LANCEREF but have invalid structure + byte[] fake = "LANCEREFxgarbage-data-here-padding".getBytes(); + assertFalse(BlobReference.isBlobReference(fake)); + } +} From 51794837de4a22f3574f5bdc604a4be8cca7fc37 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 21 May 2026 10:00:37 -0500 Subject: [PATCH 2/4] fix: harden blob preservation; route source credentials to write side MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback on the blob JOIN/INSERT preservation path and makes vended-credential auto-refresh work for blob sources. Correctness: - LargeBinaryWriter now buffers all per-row values and emits the vector in a single ascending pass at finish(), instead of back-filling resolved blobs out of order (which corrupts a LargeVarBinaryVector's offset buffer). resolveBatch() returns an index->bytes map rather than writing into the vector. Resource lifecycle: - The BlobReferenceResolver is now created once per write task, shared across batches/fragments, and closed at LanceDataWriter teardown — fixing a leak of native source datasets (one per blob column per batch). 
Credentials (the main design change): - New LanceBlobSourceContextRule optimizer rule collects each blob source table's BlobSourceContext (read options + namespace config) on the driver and stashes them, keyed by source URI, in the write command's options. LanceDataset.newWriteBuilder decodes them and threads them to the per-task resolver, which reopens sources via Utils.openDatasetBuilder().runtimeNamespace(...) so vended credentials keep auto-refreshing — mirroring compaction/index. No global registry, no per-row shuffle bloat. Registered in LanceSparkSessionExtensions. - BlobReferenceResolver no longer opens datasets directly; falls back to open-by-URI when no context is present (local sources / extension off). Performance: - BlobStructAccessor precomputes the constant reference prefix once per batch in setBlobReferenceContext; the per-row path only appends the 8-byte rowAddress. Misc: - Extract LargeBinaryWriter to its own file. - Fold the new constructor arguments into single canonical constructors rather than parallel overloads; update affected tests. - BaseBlobJoinTest enables the SQL extension so the rule path is covered. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../LanceSparkSessionExtensions.scala | 5 +- .../LanceSparkSessionExtensions.scala | 5 +- .../LanceSparkSessionExtensions.scala | 5 +- .../LanceSparkSessionExtensions.scala | 5 +- .../java/org/lance/spark/LanceConstant.java | 7 + .../java/org/lance/spark/LanceDataset.java | 23 +++- .../org/lance/spark/utils/BlobReference.java | 32 ++++- .../spark/utils/BlobReferenceResolver.java | 75 ++++++---- .../lance/spark/utils/BlobSourceContext.java | 73 ++++++++++ .../spark/vectorized/BlobStructAccessor.java | 11 +- .../lance/spark/write/LanceBatchWrite.java | 18 ++- .../lance/spark/write/LanceDataWriter.java | 63 +++++++-- .../write/QueuedArrowBatchWriteBuffer.java | 33 +++-- .../write/SemaphoreArrowBatchWriteBuffer.java | 37 ++++- .../org/lance/spark/write/SparkWrite.java | 21 ++- .../LanceBlobSourceContextRule.scala | 104 ++++++++++++++ .../lance/spark/arrow/LanceArrowWriter.scala | 86 +++++------- .../lance/spark/arrow/LargeBinaryWriter.scala | 129 ++++++++++++++++++ .../org/lance/spark/BaseBlobJoinTest.java | 4 + .../QueuedArrowBatchWriteBufferTest.java | 6 +- .../SemaphoreArrowBatchWriteBufferTest.java | 6 +- .../org/lance/spark/write/SparkWriteTest.java | 6 +- 22 files changed, 628 insertions(+), 126 deletions(-) create mode 100644 lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobSourceContext.java create mode 100644 lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRule.scala create mode 100644 lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala diff --git a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d1..87bc6403c 100644 --- a/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ 
b/lance-spark-3.4_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,7 +14,7 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,9 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // propagate blob source credentials from read scans to the write side + extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d1..87bc6403c 100644 --- a/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-3.5_2.12/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,7 +14,7 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,9 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins 
extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // propagate blob source credentials from read scans to the write side + extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d1..87bc6403c 100644 --- a/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ b/lance-spark-4.0_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,7 +14,7 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,9 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // propagate blob source credentials from read scans to the write side + extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala index 6f9a905d1..87bc6403c 100644 --- a/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala +++ 
b/lance-spark-4.1_2.13/src/main/scala/org/lance/spark/extensions/LanceSparkSessionExtensions.scala @@ -14,7 +14,7 @@ package org.lance.spark.extensions import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.optimizer.LanceFragmentAwareJoinRule +import org.apache.spark.sql.catalyst.optimizer.{LanceBlobSourceContextRule, LanceFragmentAwareJoinRule} import org.apache.spark.sql.catalyst.parser.extensions.LanceSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.LanceDataSourceV2Strategy @@ -27,6 +27,9 @@ class LanceSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // optimizer rules for fragment-aware joins extensions.injectOptimizerRule(_ => LanceFragmentAwareJoinRule()) + // propagate blob source credentials from read scans to the write side + extensions.injectOptimizerRule(_ => LanceBlobSourceContextRule()) + extensions.injectPlannerStrategy(LanceDataSourceV2Strategy(_)) } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceConstant.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceConstant.java index dc6ec4e8f..069dce79d 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceConstant.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceConstant.java @@ -29,6 +29,13 @@ public class LanceConstant { public static final String BACKFILL_COLUMNS_KEY = "backfill_columns"; public static final String UPDATE_COLUMNS_KEY = "update_columns"; + /** + * Internal write option carrying the encoded blob source credential/open contexts for an INSERT + * whose query reads blob columns. Set on the driver by {@code LanceBlobSourceContextRule} and + * consumed by {@code LanceDataset.newWriteBuilder}; not a user-facing option. + */ + public static final String BLOB_SOURCE_CONTEXTS_KEY = "__lance_blob_source_contexts"; + /** Table property that declares the partition column(s) for SPJ. 
*/ public static final String TABLE_OPT_PARTITION_COLUMNS = "lance.partition.columns"; } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java index c2a201950..e219689b5 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceDataset.java @@ -14,6 +14,7 @@ package org.lance.spark; import org.lance.spark.read.LanceScanBuilder; +import org.lance.spark.utils.BlobSourceContext; import org.lance.spark.utils.BlobUtils; import org.lance.spark.write.AddColumnsBackfillWrite; import org.lance.spark.write.SparkWrite; @@ -35,6 +36,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.apache.spark.sql.util.LanceSerializeUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -315,8 +317,11 @@ public Set capabilities() { public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { // Merge write-time options with the base options from read options CaseInsensitiveStringMap sparkWriteOptions = logicalWriteInfo.options(); + Map blobSourceContexts = decodeBlobSourceContexts(sparkWriteOptions); Map mergedOptions = new HashMap<>(readOptions.getStorageOptions()); mergedOptions.putAll(sparkWriteOptions.asCaseSensitiveMap()); + // Internal-only option (see LanceBlobSourceContextRule); never forward it as a storage option. 
+ mergedOptions.remove(LanceConstant.BLOB_SOURCE_CONTEXTS_KEY); LanceSparkWriteOptions.Builder writeOptionsBuilder = LanceSparkWriteOptions.builder() @@ -380,7 +385,8 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { namespaceProperties, readOptions.getTableId(), managedVersioning, - tableProperties); + tableProperties, + blobSourceContexts); if (stagedCommit != null) { builder.setStagedCommit(stagedCommit); @@ -392,6 +398,21 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { return builder; } + /** + * Decodes the blob source contexts that {@code LanceBlobSourceContextRule} injected into the + * write options for an INSERT whose query reads blob columns. Returns an empty map when absent + * (e.g. no blob sources, or the SQL extension is not enabled). + */ + @SuppressWarnings("unchecked") + private static Map decodeBlobSourceContexts( + CaseInsensitiveStringMap writeOptions) { + String encoded = writeOptions.get(LanceConstant.BLOB_SOURCE_CONTEXTS_KEY); + if (encoded == null || encoded.isEmpty()) { + return Collections.emptyMap(); + } + return (Map) LanceSerializeUtil.decode(encoded); + } + @Override public MetadataColumn[] metadataColumns() { // Start with the base metadata columns diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java index c4cb33e66..5ebb470cd 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java @@ -19,6 +19,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; /** * A compact serializable reference to a blob stored in a Lance dataset. @@ -98,6 +99,18 @@ public static boolean isBlobReference(byte[] bytes) { /** Serialize this reference to a compact byte array. 
*/ public byte[] serialize() { + return appendRowAddress(serializePrefix(datasetUri, columnName), rowAddress); + } + + /** + * Serializes the constant portion of a reference: everything except the trailing 8-byte + * rowAddress (i.e. magic + version + datasetUri + columnName). + * + *

<p>{@code datasetUri} and {@code columnName} are constant for an entire scan batch, so callers + * on the per-row read hot path should compute this prefix once and then call {@link + * #appendRowAddress(byte[], long)} per row instead of re-encoding the strings every time. + */ + public static byte[] serializePrefix(String datasetUri, String columnName) { try { ByteArrayOutputStream baos = new ByteArrayOutputStream(128); DataOutputStream out = new DataOutputStream(baos); @@ -105,7 +118,6 @@ public byte[] serialize() { out.writeByte(VERSION); writeString(out, datasetUri); writeString(out, columnName); - out.writeLong(rowAddress); out.flush(); return baos.toByteArray(); } catch (IOException e) { @@ -113,6 +125,24 @@ public byte[] serialize() { } } + /** + * Returns a full serialized reference: {@code prefix} (from {@link #serializePrefix}) followed by + * the 8-byte big-endian {@code rowAddress}, matching {@link DataOutputStream#writeLong}. + */ + public static byte[] appendRowAddress(byte[] prefix, long rowAddress) { + byte[] out = Arrays.copyOf(prefix, prefix.length + 8); + int off = prefix.length; + out[off] = (byte) (rowAddress >>> 56); + out[off + 1] = (byte) (rowAddress >>> 48); + out[off + 2] = (byte) (rowAddress >>> 40); + out[off + 3] = (byte) (rowAddress >>> 32); + out[off + 4] = (byte) (rowAddress >>> 24); + out[off + 5] = (byte) (rowAddress >>> 16); + out[off + 6] = (byte) (rowAddress >>> 8); + out[off + 7] = (byte) rowAddress; + return out; + } + /** Deserialize a BlobReference from bytes. 
*/ public static BlobReference deserialize(byte[] bytes) { if (!isBlobReference(bytes)) { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java index 2a273ffcb..a9bf17493 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java @@ -15,13 +15,11 @@ import org.lance.BlobFile; import org.lance.Dataset; -import org.lance.ReadOptions; -import org.lance.spark.LanceRuntime; - -import org.apache.arrow.vector.LargeVarBinaryVector; +import org.lance.spark.LanceSparkReadOptions; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -33,12 +31,32 @@ *

<p>Datasets are cached for the lifetime of this resolver to amortize open costs across batches. * Resolution is done in true batches: all pending references are grouped by (datasetUri, * columnName) and each group is resolved with a single {@code takeBlobs()} call. + * + *

<p>Source datasets are opened through {@link Utils#openDatasetBuilder(LanceSparkReadOptions)} + * using the per-source {@link BlobSourceContext} captured on the driver (keyed by dataset URI). + * This keeps the namespace client attached via {@code runtimeNamespace(...)} so vended credentials + * keep auto-refreshing — exactly how distributed compaction/index builds open datasets on + * executors. When no context is registered for a URI (e.g. a local filesystem source, or when the + * SQL extension that captures contexts is not enabled), it falls back to opening by URI with + * default options. */ public class BlobReferenceResolver implements AutoCloseable { /** Cache of opened datasets keyed by dataset URI. */ private final Map datasetCache = new HashMap<>(); + /** Per-source open/credential context, keyed by dataset URI. */ + private final Map sourceContexts; + + public BlobReferenceResolver() { + this(Collections.emptyMap()); + } + + public BlobReferenceResolver(Map sourceContexts) { + this.sourceContexts = + sourceContexts != null ? sourceContexts : Collections.emptyMap(); + } + /** * Resolves a single blob reference to actual blob bytes. * @@ -72,18 +90,24 @@ public byte[] resolveIfNeeded(byte[] bytes) throws IOException { } /** - * Resolves a batch of blob references and writes the resolved bytes directly into the target - * vector. References are grouped by (datasetUri, columnName) and each group is resolved with a + * Resolves a batch of blob references to their actual bytes, keyed by the caller-supplied vector + * indices. References are grouped by (datasetUri, columnName) and each group is resolved with a * single {@code takeBlobs()} call. * + *

<p>The caller is responsible for writing the resolved bytes into the target vector. Resolved + * bytes are returned as a map rather than written here because back-filling a variable-width + * Arrow vector out of order corrupts its offset buffer; the caller must emit the whole vector in + * a single ascending pass. + * * @param indices vector indices corresponding to each blob reference * @param refs blob references to resolve - * @param vector the target vector to back-fill with resolved bytes + * @return a map from vector index to resolved blob bytes * @throws IOException if reading blobs fails */ - public void resolveBatch( - List indices, List refs, LargeVarBinaryVector vector) + public Map resolveBatch(List indices, List refs) throws IOException { + Map resolved = new HashMap<>(refs.size()); + // Group by (datasetUri, columnName) Map> groups = new HashMap<>(); for (int i = 0; i < refs.size(); i++) { @@ -111,28 +135,33 @@ public void resolveBatch( IndexedRef ir = group.get(i); if (i < blobs.size()) { try (BlobFile blob = blobs.get(i)) { - byte[] data = blob.read(); - vector.setSafe(ir.vectorIndex, data); + resolved.put(ir.vectorIndex, blob.read()); } } else { - vector.setSafe(ir.vectorIndex, new byte[0]); + resolved.put(ir.vectorIndex, new byte[0]); } } } + return resolved; } private Dataset getOrOpenDataset(String datasetUri) { - return datasetCache.computeIfAbsent( - datasetUri, - uri -> { - ReadOptions.Builder builder = new ReadOptions.Builder(); - builder.setSession(LanceRuntime.session()); - return Dataset.open() - .allocator(LanceRuntime.allocator()) - .uri(uri) - .readOptions(builder.build()) - .build(); - }); + return datasetCache.computeIfAbsent(datasetUri, this::openDataset); + } + + private Dataset openDataset(String datasetUri) { + BlobSourceContext context = sourceContexts.get(datasetUri); + if (context != null) { + // Reopen the source the same way executors do for compaction/index: route through the + // namespace client so vended (auto-refreshing) 
credentials remain valid while reading blobs. + return Utils.openDatasetBuilder(context.getReadOptions()) + .initialStorageOptions(context.getInitialStorageOptions()) + .runtimeNamespace( + context.getNamespaceImpl(), context.getNamespaceProperties(), context.getTableId()) + .build(); + } + // No captured context (e.g. local filesystem source, or the capture extension is not enabled). + return Utils.openDatasetBuilder(LanceSparkReadOptions.from(datasetUri)).build(); } @Override diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobSourceContext.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobSourceContext.java new file mode 100644 index 000000000..f72547a95 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobSourceContext.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.utils; + +import org.lance.spark.LanceSparkReadOptions; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +/** + * The credential/open context for a blob source dataset, captured on the driver when its scan is + * built and shipped to write executors so they can reopen the source to resolve blob references. + * + *

<p>This mirrors how distributed compaction and index builds carry credentials to executors (see + * {@code OptimizeTaskExecutor} / {@code FragmentIndexTask}): the executor reconstructs the open via + * {@link Utils#openDatasetBuilder(LanceSparkReadOptions)} with {@code runtimeNamespace(...)} so + * that vended credentials returned by {@code namespace.describeTable()} (e.g. STS tokens from + * Iceberg REST, Polaris, Unity) keep auto-refreshing while blobs are read. + * + *

<p>The blob source is generally a different table than the one a write task targets, so + * this context cannot be derived from the write options; it must travel from the read side. + */ +public class BlobSourceContext implements Serializable { + private static final long serialVersionUID = 1L; + + private final LanceSparkReadOptions readOptions; + private final Map initialStorageOptions; + private final String namespaceImpl; + private final Map namespaceProperties; + + public BlobSourceContext( + LanceSparkReadOptions readOptions, + Map initialStorageOptions, + String namespaceImpl, + Map namespaceProperties) { + this.readOptions = readOptions; + this.initialStorageOptions = initialStorageOptions; + this.namespaceImpl = namespaceImpl; + this.namespaceProperties = namespaceProperties; + } + + public LanceSparkReadOptions getReadOptions() { + return readOptions; + } + + public Map getInitialStorageOptions() { + return initialStorageOptions; + } + + public String getNamespaceImpl() { + return namespaceImpl; + } + + public Map getNamespaceProperties() { + return namespaceProperties; + } + + public List getTableId() { + return readOptions.getTableId(); + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java index 91a1f57cc..1b1286e25 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java @@ -30,6 +30,10 @@ public class BlobStructAccessor implements AutoCloseable { private String columnName; private long[] rowAddresses; + // Constant serialized prefix (magic + version + datasetUri + columnName) for this batch. + // Precomputed once in setBlobReferenceContext so the per-row hot path only appends rowAddress. 
+ private byte[] referencePrefix; + public BlobStructAccessor(StructVector structVector) { this.structVector = structVector; // Blob structs have two fields: position and size (both unsigned Int64) @@ -50,6 +54,9 @@ public void setBlobReferenceContext(String datasetUri, String columnName, long[] this.datasetUri = datasetUri; this.columnName = columnName; this.rowAddresses = rowAddresses; + // datasetUri and columnName are constant for the batch — encode the reference prefix once + // here rather than re-encoding both strings per row in getBlobReference(). + this.referencePrefix = BlobReference.serializePrefix(datasetUri, columnName); } /** Returns true if blob reference context has been set. */ @@ -82,9 +89,7 @@ public byte[] getBlobReference(int rowId) { // Zero-size blob — either truly empty or null encoded as (0,0) return new byte[0]; } - long rowAddr = rowAddresses[rowId]; - BlobReference ref = new BlobReference(datasetUri, columnName, rowAddr); - return ref.serialize(); + return BlobReference.appendRowAddress(referencePrefix, rowAddresses[rowId]); } public InternalRow getStruct(int rowId) { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceBatchWrite.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceBatchWrite.java index 9de14859d..eddef4dda 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceBatchWrite.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceBatchWrite.java @@ -23,6 +23,7 @@ import org.lance.operation.Overwrite; import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.utils.BlobSourceContext; import org.lance.spark.utils.Utils; import org.apache.arrow.vector.types.pojo.Schema; @@ -70,6 +71,12 @@ public class LanceBatchWrite implements BatchWrite { */ private final List partitionColumns; + /** + * Per-source blob credential/open contexts keyed by source dataset URI, captured on the driver + * and passed to 
write tasks so they can reopen source datasets to resolve blob references. + */ + private final Map blobSourceContexts; + public LanceBatchWrite( StructType schema, LanceSparkWriteOptions writeOptions, @@ -90,7 +97,8 @@ public LanceBatchWrite( tableId, managedVersioning, stagedCommit, - Collections.emptyList()); + Collections.emptyList(), + Collections.emptyMap()); } public LanceBatchWrite( @@ -103,7 +111,8 @@ public LanceBatchWrite( List tableId, boolean managedVersioning, StagedCommit stagedCommit, - List partitionColumns) { + List partitionColumns, + Map blobSourceContexts) { this.schema = schema; this.overwrite = overwrite; this.initialStorageOptions = initialStorageOptions; @@ -113,6 +122,8 @@ public LanceBatchWrite( this.managedVersioning = managedVersioning; this.stagedCommit = stagedCommit; this.partitionColumns = partitionColumns == null ? Collections.emptyList() : partitionColumns; + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; // For staged operations, the dataset is managed by StagedCommit. // For non-staged operations, pin the dataset version for OCC. 
@@ -136,7 +147,8 @@ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) { namespaceImpl, namespaceProperties, tableId, - partitionColumns); + partitionColumns, + blobSourceContexts); } @Override diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java index 222a53ef4..630c92594 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java @@ -18,6 +18,8 @@ import org.lance.WriteParams; import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.utils.BlobReferenceResolver; +import org.lance.spark.utils.BlobSourceContext; import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.Data; @@ -49,6 +51,13 @@ public class LanceDataWriter implements DataWriter { private final DataType[] partitionColumnTypes; private final List completedFragments = new ArrayList<>(); + /** + * Resolves blob references to actual bytes during writes. Shared across all batches/fragments of + * this write task and closed at teardown to release the source datasets it opens. Null when blob + * resolution is not needed (e.g. the test-only constructor). 
+ */ + private final BlobReferenceResolver blobResolver; + private ArrowBatchWriteBuffer writeBuffer; private FutureTask> fragmentCreationTask; private Thread fragmentCreationThread; @@ -65,7 +74,8 @@ public LanceDataWriter( fragmentCreationThread, null, new int[0], - new DataType[0]); + new DataType[0], + null); } LanceDataWriter( @@ -74,13 +84,15 @@ public LanceDataWriter( Thread fragmentCreationThread, Supplier bufferTaskFactory, int[] partitionColumnIndices, - DataType[] partitionColumnTypes) { + DataType[] partitionColumnTypes, + BlobReferenceResolver blobResolver) { this.writeBuffer = writeBuffer; this.fragmentCreationThread = fragmentCreationThread; this.fragmentCreationTask = fragmentCreationTask; this.bufferTaskFactory = bufferTaskFactory; this.partitionColumnIndices = partitionColumnIndices; this.partitionColumnTypes = partitionColumnTypes; + this.blobResolver = blobResolver; } @Override @@ -195,7 +207,15 @@ public void abort() throws IOException { @Override public void close() throws IOException { - writeBuffer.close(); + try { + writeBuffer.close(); + } finally { + // Release any source datasets opened to resolve blob references. Spark always calls close() + // (after commit, and after abort), so this is the single teardown point for the resolver. + if (blobResolver != null) { + blobResolver.close(); + } + } } static List stripRowIdMeta(List fragments) { @@ -245,6 +265,14 @@ public static class WriterFactory implements DataWriterFactory { private final List tableId; private final List partitionColumns; + /** + * Per-source blob credential/open contexts keyed by source dataset URI, captured on the driver + * (see {@code LanceBlobSourceContextRule}). Used by the per-task resolver to reopen source + * datasets and fetch blob bytes for references that flowed through the shuffle. Empty when no + * blob sources were detected or the SQL extension that captures them is not enabled. 
+ */ + private final Map blobSourceContexts; + public WriterFactory( StructType schema, LanceSparkWriteOptions writeOptions, @@ -259,7 +287,8 @@ public WriterFactory( namespaceImpl, namespaceProperties, tableId, - Collections.emptyList()); + Collections.emptyList(), + Collections.emptyMap()); } public WriterFactory( @@ -269,7 +298,8 @@ public WriterFactory( String namespaceImpl, Map namespaceProperties, List tableId, - List partitionColumns) { + List partitionColumns, + Map blobSourceContexts) { // Everything passed to writer factory should be serializable this.schema = schema; this.writeOptions = writeOptions; @@ -278,9 +308,11 @@ public WriterFactory( this.namespaceProperties = namespaceProperties; this.tableId = tableId; this.partitionColumns = partitionColumns == null ? Collections.emptyList() : partitionColumns; + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; } - private BufferAndTask buildBufferAndTask() { + private BufferAndTask buildBufferAndTask(BlobReferenceResolver resolver) { int batchSize = writeOptions.getBatchSize(); boolean useQueuedBuffer = writeOptions.isUseQueuedWriteBuffer(); boolean useLargeVarTypes = writeOptions.isUseLargeVarTypes(); @@ -295,10 +327,11 @@ private BufferAndTask buildBufferAndTask() { int queueDepth = writeOptions.getQueueDepth(); writeBuffer = new QueuedArrowBatchWriteBuffer( - schema, batchSize, queueDepth, useLargeVarTypes, maxBatchBytes); + schema, batchSize, queueDepth, useLargeVarTypes, maxBatchBytes, resolver); } else { writeBuffer = - new SemaphoreArrowBatchWriteBuffer(schema, batchSize, useLargeVarTypes, maxBatchBytes); + new SemaphoreArrowBatchWriteBuffer( + schema, batchSize, useLargeVarTypes, maxBatchBytes, resolver); } final ArrowBatchWriteBuffer bufferRef = writeBuffer; @@ -318,14 +351,24 @@ private BufferAndTask buildBufferAndTask() { @Override public DataWriter createWriter(int partitionId, long taskId) { - BufferAndTask initial = buildBufferAndTask(); + // 
One resolver per write task, shared across all batches and fragments (rolled by + // bufferTaskFactory) and closed when the LanceDataWriter is closed. Always created so blob + // references can be resolved even without captured contexts (it falls back to open-by-URI). + BlobReferenceResolver resolver = new BlobReferenceResolver(blobSourceContexts); + BufferAndTask initial = buildBufferAndTask(resolver); initial.thread.start(); int[] indices = resolvePartitionColumnIndices(); DataType[] types = resolvePartitionColumnTypes(indices); return new LanceDataWriter( - initial.buffer, initial.task, initial.thread, this::buildBufferAndTask, indices, types); + initial.buffer, + initial.task, + initial.thread, + () -> buildBufferAndTask(resolver), + indices, + types, + resolver); } private int[] resolvePartitionColumnIndices() { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java index 2fdc5c2b2..7e8982abb 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java @@ -15,6 +15,7 @@ import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.utils.BlobReferenceResolver; import com.google.common.base.Preconditions; import org.apache.arrow.memory.BufferAllocator; @@ -85,6 +86,9 @@ public class QueuedArrowBatchWriteBuffer extends ArrowBatchWriteBuffer { /** Arrow writer for current batch. */ private org.lance.spark.arrow.LanceArrowWriter currentArrowWriter; + /** Resolves blob references during writes; null when blob resolution is not needed. */ + private final BlobReferenceResolver resolver; + /** Count of rows in current batch. 
*/ private final AtomicInteger currentBatchRowCount = new AtomicInteger(0); @@ -124,12 +128,19 @@ public QueuedArrowBatchWriteBuffer( sparkSchema, batchSize, DEFAULT_QUEUE_DEPTH, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } /** Simplified constructor that uses LanceRuntime allocator and converts Spark schema to Arrow. */ public QueuedArrowBatchWriteBuffer(StructType sparkSchema, int batchSize, int queueDepth) { - this(sparkSchema, batchSize, queueDepth, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + this( + sparkSchema, + batchSize, + queueDepth, + false, + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } /** Constructor with large var types support, using LanceRuntime allocator. */ @@ -140,7 +151,8 @@ public QueuedArrowBatchWriteBuffer( batchSize, queueDepth, useLargeVarTypes, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } /** Constructor with all tuning parameters, using LanceRuntime allocator. 
*/ @@ -149,14 +161,16 @@ public QueuedArrowBatchWriteBuffer( int batchSize, int queueDepth, boolean useLargeVarTypes, - long maxBatchBytes) { + long maxBatchBytes, + BlobReferenceResolver resolver) { this( LanceRuntime.allocator(), LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, useLargeVarTypes), sparkSchema, batchSize, queueDepth, - maxBatchBytes); + maxBatchBytes, + resolver); } public QueuedArrowBatchWriteBuffer( @@ -171,7 +185,8 @@ public QueuedArrowBatchWriteBuffer( sparkSchema, batchSize, queueDepth, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } public QueuedArrowBatchWriteBuffer( @@ -180,7 +195,8 @@ public QueuedArrowBatchWriteBuffer( StructType sparkSchema, int batchSize, int queueDepth, - long maxBatchBytes) { + long maxBatchBytes, + BlobReferenceResolver resolver) { super(allocator); Preconditions.checkNotNull(schema); Preconditions.checkArgument(batchSize > 0, "Batch size must be positive"); @@ -192,6 +208,7 @@ public QueuedArrowBatchWriteBuffer( this.batchSize = batchSize; this.maxBatchBytes = maxBatchBytes; this.queueDepth = queueDepth; + this.resolver = resolver; this.batchQueue = new ArrayBlockingQueue<>(queueDepth); // Create a child allocator for producer that shares the same root as the consumer @@ -219,7 +236,7 @@ private void allocateNewBatch() { throw e; } currentArrowWriter = - org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(currentBatch, sparkSchema); + org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(currentBatch, sparkSchema, resolver); currentBatchRowCount.set(0); } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java index 0c5cd62b5..f641b531a 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java +++ 
b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java @@ -15,6 +15,7 @@ import org.lance.spark.LanceRuntime; import org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.utils.BlobReferenceResolver; import com.google.common.base.Preconditions; import org.apache.arrow.memory.BufferAllocator; @@ -79,12 +80,16 @@ public class SemaphoreArrowBatchWriteBuffer extends ArrowBatchWriteBuffer { private org.lance.spark.arrow.LanceArrowWriter arrowWriter = null; + /** Resolves blob references during writes; null when blob resolution is not needed. */ + private final BlobReferenceResolver resolver; + public SemaphoreArrowBatchWriteBuffer( BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize, - long maxBatchBytes) { + long maxBatchBytes, + BlobReferenceResolver resolver) { // Pass a child allocator to ArrowReader so VectorSchemaRoot allocation is tracked super(allocator.newChildAllocator("semaphore-buffer", 0, Long.MAX_VALUE)); Preconditions.checkNotNull(schema); @@ -94,6 +99,7 @@ public SemaphoreArrowBatchWriteBuffer( this.sparkSchema = sparkSchema; this.batchSize = batchSize; this.maxBatchBytes = maxBatchBytes; + this.resolver = resolver; // Start with count = batchSize so the writer blocks on canWrite.await() until the // reader's prepareLoadNextBatch() initializes arrowWriter and resets count to 0. this.count = batchSize; @@ -101,29 +107,45 @@ public SemaphoreArrowBatchWriteBuffer( public SemaphoreArrowBatchWriteBuffer( BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize) { - this(allocator, schema, sparkSchema, batchSize, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + this( + allocator, + schema, + sparkSchema, + batchSize, + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } /** Simplified constructor that uses LanceRuntime allocator and converts Spark schema to Arrow. 
*/ public SemaphoreArrowBatchWriteBuffer(StructType sparkSchema, int batchSize) { - this(sparkSchema, batchSize, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + this(sparkSchema, batchSize, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, null); } /** Constructor with large var types support, using LanceRuntime allocator. */ public SemaphoreArrowBatchWriteBuffer( StructType sparkSchema, int batchSize, boolean useLargeVarTypes) { - this(sparkSchema, batchSize, useLargeVarTypes, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + this( + sparkSchema, + batchSize, + useLargeVarTypes, + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES, + null); } /** Constructor with all tuning parameters, using LanceRuntime allocator. */ public SemaphoreArrowBatchWriteBuffer( - StructType sparkSchema, int batchSize, boolean useLargeVarTypes, long maxBatchBytes) { + StructType sparkSchema, + int batchSize, + boolean useLargeVarTypes, + long maxBatchBytes, + BlobReferenceResolver resolver) { this( LanceRuntime.allocator(), LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, useLargeVarTypes), sparkSchema, batchSize, - maxBatchBytes); + maxBatchBytes, + resolver); } @Override @@ -204,7 +226,8 @@ public void prepareLoadNextBatch() throws IOException { v.allocateNew(); } root.setRowCount(0); - arrowWriter = org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema); + arrowWriter = + org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema, resolver); lock.lock(); try { count = 0; diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java index a7401d1c8..fc25ecbdf 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SparkWrite.java @@ -16,6 +16,7 @@ import org.lance.WriteParams; import org.lance.spark.LanceConstant; import 
org.lance.spark.LanceSparkWriteOptions; +import org.lance.spark.utils.BlobSourceContext; import org.apache.spark.sql.connector.distributions.Distribution; import org.apache.spark.sql.connector.distributions.Distributions; @@ -72,6 +73,9 @@ LanceSparkWriteOptions getWriteOptions() { private final StagedCommit stagedCommit; private final Map tableProperties; + /** Per-source blob credential/open contexts keyed by source dataset URI. */ + private final Map blobSourceContexts; + SparkWrite( StructType schema, LanceSparkWriteOptions writeOptions, @@ -82,7 +86,8 @@ LanceSparkWriteOptions getWriteOptions() { List tableId, boolean managedVersioning, StagedCommit stagedCommit, - Map tableProperties) { + Map tableProperties, + Map blobSourceContexts) { this.schema = schema; this.writeOptions = writeOptions; this.overwrite = overwrite; @@ -96,6 +101,8 @@ LanceSparkWriteOptions getWriteOptions() { tableProperties != null ? Collections.unmodifiableMap(tableProperties) : Collections.emptyMap(); + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; } /** Returns partition column names from the table property, empty list if unset. 
*/ @@ -146,7 +153,8 @@ public BatchWrite toBatch() { tableId, managedVersioning, stagedCommit, - partitionColumnList()); + partitionColumnList(), + blobSourceContexts); } @Override @@ -174,6 +182,7 @@ public static class SparkWriteBuilder implements SupportsTruncate, WriteBuilder private final List tableId; private final boolean managedVersioning; private final Map tableProperties; + private final Map blobSourceContexts; public SparkWriteBuilder( StructType schema, @@ -183,7 +192,8 @@ public SparkWriteBuilder( Map namespaceProperties, List tableId, boolean managedVersioning, - Map tableProperties) { + Map tableProperties, + Map blobSourceContexts) { this.schema = schema; this.writeOptions = writeOptions; this.initialStorageOptions = initialStorageOptions; @@ -192,6 +202,8 @@ public SparkWriteBuilder( this.tableId = tableId; this.managedVersioning = managedVersioning; this.tableProperties = tableProperties; + this.blobSourceContexts = + blobSourceContexts == null ? Collections.emptyMap() : blobSourceContexts; } public void setStagedCommit(StagedCommit stagedCommit) { @@ -230,7 +242,8 @@ public Write build() { tableId, managedVersioning, stagedCommit, - tableProperties); + tableProperties, + blobSourceContexts); } @Override diff --git a/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRule.scala b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRule.scala new file mode 100644 index 000000000..1c2d2b854 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRule.scala @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.NamedRelation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.util.LanceSerializeUtil +import org.lance.spark.{LanceConstant, LanceDataset} +import org.lance.spark.utils.{BlobSourceContext, BlobUtils} + +/** + * Optimizer rule that propagates blob source credentials to the write side. + * + * When a Lance table with blob columns is read and its blob columns flow through a shuffle into a + * write (e.g. `INSERT INTO target SELECT ... [JOIN ...]`), the blob bytes are not materialized: a + * compact reference carrying the source dataset URI travels instead. To resolve those references the + * write executors must reopen the source dataset — but Spark's DSv2 write is never handed the read + * plan, so it has no way to learn the source's credentials. + * + * This rule bridges that gap on the driver, where both the write command and its source query are + * visible: it collects each blob source's [[BlobSourceContext]] (read options + namespace config for + * credential refresh), encodes them keyed by source URI, and stashes the result in the write + * command's options under [[LanceConstant.BLOB_SOURCE_CONTEXTS_KEY]]. 
`LanceDataset.newWriteBuilder` + * decodes it and threads it down to the per-task blob resolver. This keeps the credential context + * query-scoped (no global state) and rides the write's own options channel rather than bloating + * every shuffled row. + * + * No-op when the target is not a Lance table or the source query has no Lance blob tables. + */ +case class LanceBlobSourceContextRule() extends Rule[LogicalPlan] { + + private val key = LanceConstant.BLOB_SOURCE_CONTEXTS_KEY + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { + case a: AppendData if shouldAnnotate(a.table, a.writeOptions) => + annotate(a.query) match { + case Some(v) => a.copy(writeOptions = a.writeOptions + (key -> v)) + case None => a + } + case o: OverwriteByExpression if shouldAnnotate(o.table, o.writeOptions) => + annotate(o.query) match { + case Some(v) => o.copy(writeOptions = o.writeOptions + (key -> v)) + case None => o + } + case o: OverwritePartitionsDynamic if shouldAnnotate(o.table, o.writeOptions) => + annotate(o.query) match { + case Some(v) => o.copy(writeOptions = o.writeOptions + (key -> v)) + case None => o + } + } + + private def shouldAnnotate(table: NamedRelation, writeOptions: Map[String, String]): Boolean = + !writeOptions.contains(key) && isLanceTarget(table) + + private def isLanceTarget(table: NamedRelation): Boolean = table match { + case r: DataSourceV2Relation => r.table.isInstanceOf[LanceDataset] + case _ => false + } + + /** Encodes {sourceUri -> context} for the query's blob sources, or None if there are none. 
*/ + private def annotate(query: LogicalPlan): Option[String] = { + val contexts = new java.util.HashMap[String, BlobSourceContext]() + query.foreach { node => + lanceTableWithBlobs(node).foreach { ds => + contexts.put( + ds.readOptions().getDatasetUri, + new BlobSourceContext( + ds.readOptions(), + ds.getInitialStorageOptions(), + ds.getNamespaceImpl(), + ds.getNamespaceProperties())) + } + } + if (contexts.isEmpty) None else Some(LanceSerializeUtil.encode(contexts)) + } + + /** Returns the Lance table backing a relation iff it has blob columns (pre- or post-pushdown). */ + private def lanceTableWithBlobs(node: LogicalPlan): Option[LanceDataset] = { + val table: Option[Table] = node match { + case r: DataSourceV2Relation => Some(r.table) + case sr: DataSourceV2ScanRelation => Some(sr.relation.table) + case _ => None + } + table.collect { case ds: LanceDataset if hasBlobColumns(ds) => ds } + } + + private def hasBlobColumns(ds: LanceDataset): Boolean = + ds.schema().fields.exists((f: StructField) => BlobUtils.isBlobSparkField(f)) +} diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index 53872d3dc..79900ced4 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.types._ import org.apache.spark.sql.util.LanceArrowUtils -import org.lance.spark.utils.Float16Utils +import org.lance.spark.utils.{BlobReferenceResolver, Float16Utils} import scala.collection.JavaConverters._ @@ -51,11 +51,22 @@ object LanceArrowWriter { create(root, schema) } - def create(root: VectorSchemaRoot, sparkSchema: StructType): LanceArrowWriter = { + def create(root: VectorSchemaRoot, 
sparkSchema: StructType): LanceArrowWriter = + create(root, sparkSchema, null) + + /** + * Creates a writer, injecting a shared {@link BlobReferenceResolver} used by binary writers to + * resolve blob references that flow through a shuffle. The resolver is owned by the caller (one per + * write task) and reused across batches; pass {@code null} when blob resolution is not needed. + */ + def create( + root: VectorSchemaRoot, + sparkSchema: StructType, + resolver: BlobReferenceResolver): LanceArrowWriter = { val children = root.getFieldVectors().asScala.zipWithIndex.map { case (vector, index) => vector.allocateNew() val sparkField = sparkSchema.fields(index) - createFieldWriter(vector, sparkField.dataType, sparkField.metadata) + createFieldWriter(vector, sparkField.dataType, sparkField.metadata, resolver) } new LanceArrowWriter(root, children.toArray) } @@ -63,14 +74,15 @@ object LanceArrowWriter { private[arrow] def createFieldWriter( vector: ValueVector, sparkType: DataType, - metadata: org.apache.spark.sql.types.Metadata = null): LanceArrowFieldWriter = { + metadata: org.apache.spark.sql.types.Metadata = null, + resolver: BlobReferenceResolver = null): LanceArrowFieldWriter = { (sparkType, vector) match { case (ArrayType(elementType: NumericType, _), vector: FixedSizeListVector) => - val elementWriter = createFieldWriter(vector.getDataVector(), elementType, null) + val elementWriter = createFieldWriter(vector.getDataVector(), elementType, null, resolver) new FixedSizeListWriter(vector, elementWriter) case (ArrayType(elementType, _), vector: ListVector) => - val elementWriter = createFieldWriter(vector.getDataVector(), elementType, null) + val elementWriter = createFieldWriter(vector.getDataVector(), elementType, null, resolver) new ArrayWriter(vector, elementWriter) case (BooleanType, vector: BitVector) => new BooleanWriter(vector) @@ -96,7 +108,7 @@ object LanceArrowWriter { case (_: CharType | _: VarcharType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) - case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) + case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector, resolver) case (DateType, vector: DateDayVector) => new DateWriter(vector) case (DateType, vector: DateMilliVector) => new DateMilliWriter(vector) case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) @@ -106,15 +118,21 @@ object LanceArrowWriter { val keyWriter = createFieldWriter( structVector.getChild(MapVector.KEY_NAME), sparkType.asInstanceOf[MapType].keyType, - null) + null, + resolver) val valueWriter = createFieldWriter( structVector.getChild(MapVector.VALUE_NAME), sparkType.asInstanceOf[MapType].valueType, - null) + null, + resolver) new MapWriter(vector, structVector, keyWriter, valueWriter) case (StructType(fields), vector: StructVector) => val children = fields.zipWithIndex.map { case (field, ordinal) => - createFieldWriter(vector.getChildByOrdinal(ordinal), field.dataType, field.metadata) + createFieldWriter( + vector.getChildByOrdinal(ordinal), + field.dataType, + field.metadata, + resolver) } new StructWriter(vector, children.toArray) case (NullType, vector: NullVector) => new NullWriter(vector) @@ -123,7 +141,7 @@ object LanceArrowWriter { case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => new IntervalMonthDayNanoWriter(vector) case (udt: UserDefinedType[_], _) => - createFieldWriter(vector, udt.sqlType, metadata) + createFieldWriter(vector, udt.sqlType, metadata, resolver) case (dt, _) => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } @@ -320,50 +338,8 @@ private[arrow] class BinaryWriter(val valueVector: VarBinaryVector) extends Lanc } } -private[arrow] class LargeBinaryWriter(val valueVector: LargeVarBinaryVector) - extends LanceArrowFieldWriter { - - private val pendingIndices = new 
java.util.ArrayList[java.lang.Integer]() - private val pendingRefs = new java.util.ArrayList[org.lance.spark.utils.BlobReference]() - - @transient private lazy val resolver = new org.lance.spark.utils.BlobReferenceResolver() - - override def setNull(): Unit = {} - override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - val bytes = input.getBinary(ordinal) - if (bytes == null || bytes.length == 0) { - valueVector.setSafe(count, bytes) - } else if (org.lance.spark.utils.BlobReference.isBlobReference(bytes)) { - val ref = org.lance.spark.utils.BlobReference.deserialize(bytes) - pendingIndices.add(count) - pendingRefs.add(ref) - valueVector.setSafe(count, Array.emptyByteArray) - } else { - valueVector.setSafe(count, bytes) - } - } - - override def finish(): Unit = { - super.finish() - if (!pendingRefs.isEmpty) { - try { - resolver.resolveBatch(pendingIndices, pendingRefs, valueVector) - } catch { - case e: java.io.IOException => - throw new RuntimeException("Failed to resolve blob references", e) - } finally { - pendingIndices.clear() - pendingRefs.clear() - } - } - } - - override def reset(): Unit = { - super.reset() - pendingIndices.clear() - pendingRefs.clear() - } -} +// LargeBinaryWriter (BinaryType -> LargeVarBinaryVector) is a custom, non-trivial writer that also +// resolves blob references flowing through a shuffle; it lives in its own file LargeBinaryWriter.scala. 
private[arrow] class DateWriter(val valueVector: DateDayVector) extends LanceArrowFieldWriter { override def setNull(): Unit = {} diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala new file mode 100644 index 000000000..1b87180a0 --- /dev/null +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.arrow + +import org.apache.arrow.vector.LargeVarBinaryVector +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.lance.spark.utils.{BlobReference, BlobReferenceResolver} + +/** + * Writer for binary columns backed by a [[LargeVarBinaryVector]]. + * + * When a blob column flows through a Spark shuffle, its values arrive as serialized + * [[BlobReference]]s rather than the actual bytes. This writer detects those and resolves them to + * real blob bytes via the injected (shared, per-write-task) [[BlobReferenceResolver]]. + * + * All per-row values are buffered and the vector is emitted in a single ascending pass in + * [[finish]]. 
This ordering is required for correctness: resolving references produces bytes for + * arbitrary, non-contiguous indices, and writing into the middle of an already-populated + * variable-width Arrow vector corrupts its offset buffer (`setBytes` reads the start offset from the + * entry being overwritten and only rewrites the next offset, shifting every following row's bytes). + * Buffering one batch's values is bounded by the batch size. + */ +private[arrow] class LargeBinaryWriter( + val valueVector: LargeVarBinaryVector, + injectedResolver: BlobReferenceResolver) extends LanceArrowFieldWriter { + + // One buffered entry per row, in row order. Each is one of: + // null -> SQL NULL (validity bit left unset) + // Array[Byte] -> literal binary (possibly empty) + // BlobReference -> a reference to resolve to actual blob bytes + private val entries = new java.util.ArrayList[AnyRef]() + private var hasRefs = false + + // Only created when no resolver is injected (e.g. non-shuffle build paths). Owned and closed here. + private var localResolver: BlobReferenceResolver = _ + + private def resolver: BlobReferenceResolver = { + if (injectedResolver != null) { + injectedResolver + } else { + if (localResolver == null) { + localResolver = new BlobReferenceResolver() + } + localResolver + } + } + + override def setNull(): Unit = entries.add(null) + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + if (bytes != null && BlobReference.isBlobReference(bytes)) { + entries.add(BlobReference.deserialize(bytes)) + hasRefs = true + } else { + entries.add(bytes) + } + } + + override def finish(): Unit = { + try { + val resolved: java.util.Map[Integer, Array[Byte]] = + if (hasRefs) resolveReferences() else java.util.Collections.emptyMap() + + // Single ascending pass over the batch: write literals and resolved references in order. 
+ var i = 0 + while (i < entries.size()) { + entries.get(i) match { + case null => // SQL NULL: leave the validity bit unset + case _: BlobReference => + val data = resolved.get(i) + valueVector.setSafe(i, if (data != null) data else Array.emptyByteArray) + case bytes: Array[Byte] => + valueVector.setSafe(i, bytes) + case other => + throw new IllegalStateException(s"Unexpected buffered binary entry: $other") + } + i += 1 + } + super.finish() + } finally { + entries.clear() + hasRefs = false + if (localResolver != null) { + localResolver.close() + localResolver = null + } + } + } + + /** Collects the buffered references and resolves them to bytes keyed by their row index. */ + private def resolveReferences(): java.util.Map[Integer, Array[Byte]] = { + val indices = new java.util.ArrayList[Integer]() + val refs = new java.util.ArrayList[BlobReference]() + var i = 0 + while (i < entries.size()) { + entries.get(i) match { + case ref: BlobReference => + indices.add(i) + refs.add(ref) + case _ => + } + i += 1 + } + try { + resolver.resolveBatch(indices, refs) + } catch { + case e: java.io.IOException => + throw new RuntimeException("Failed to resolve blob references", e) + } + } + + override def reset(): Unit = { + super.reset() + entries.clear() + hasRefs = false + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java index ec6fe85ed..98539778d 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java @@ -56,6 +56,10 @@ void setup() { SparkSession.builder() .appName("blob-join-test") .master("local[*]") + // Enable the Lance extensions so LanceBlobSourceContextRule runs and propagates the + // source dataset's credentials/open context to the write side for blob resolution. 
+ .config( + "spark.sql.extensions", "org.lance.spark.extensions.LanceSparkSessionExtensions") .config( "spark.sql.catalog." + catalogName, "org.lance.spark.LanceNamespaceSparkCatalog") .config("spark.sql.catalog." + catalogName + ".impl", "dir") diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java index a640804a6..db9f545de 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java @@ -621,7 +621,7 @@ public void testByteBasedFlush() throws Exception { final int queueDepth = 4; final QueuedArrowBatchWriteBuffer writeBuffer = new QueuedArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -703,7 +703,7 @@ public void testByteBasedFlushWithSmallRows() throws Exception { final int queueDepth = 4; final QueuedArrowBatchWriteBuffer writeBuffer = new QueuedArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -772,7 +772,7 @@ public void testByteBasedFlushSingleLargeRow() throws Exception { final int queueDepth = 4; final QueuedArrowBatchWriteBuffer writeBuffer = new QueuedArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); 
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java index cf23c249e..2d5d67076 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java @@ -301,7 +301,7 @@ public void testByteBasedFlushWithSmallRows() throws Exception { final long maxBatchBytes = 100 * 1024 * 1024; // 100MB - should never be reached final SemaphoreArrowBatchWriteBuffer writeBuffer = new SemaphoreArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -331,7 +331,7 @@ public void testByteBasedFlush() throws Exception { final int rowSizeBytes = 100 * 1024; // ~100KB per row final SemaphoreArrowBatchWriteBuffer writeBuffer = new SemaphoreArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -406,7 +406,7 @@ public void testByteBasedFlushSingleLargeRow() throws Exception { final int rowSizeBytes = 10 * 1024; // 10KB per row final SemaphoreArrowBatchWriteBuffer writeBuffer = new SemaphoreArrowBatchWriteBuffer( - allocator, schema, sparkSchema, batchSize, maxBatchBytes); + allocator, schema, sparkSchema, batchSize, maxBatchBytes, null); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java index 
363b2165b..e19bdcf93 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SparkWriteTest.java @@ -74,6 +74,7 @@ private SparkWrite.SparkWriteBuilder createBuilder(String datasetUri) { Collections.emptyMap(), Arrays.asList("default", "test_table"), false, + Collections.emptyMap(), Collections.emptyMap()); } @@ -115,6 +116,7 @@ public void testTruncateThenToBatch(TestInfo testInfo) { Collections.emptyMap(), null, false, + Collections.emptyMap(), Collections.emptyMap()); assertSame(builder, builder.truncate()); BatchWrite batchWrite = builder.build().toBatch(); @@ -139,6 +141,7 @@ public void testTruncatePreservesUseLargeVarTypes(TestInfo testInfo) { Collections.emptyMap(), null, false, + Collections.emptyMap(), Collections.emptyMap()); builder.truncate(); SparkWrite sparkWrite = (SparkWrite) builder.build(); @@ -161,7 +164,8 @@ private SparkWrite createWriteWithTableProperties( Collections.emptyMap(), Arrays.asList("default", "test_table"), false, - tableProps); + tableProps, + Collections.emptyMap()); return (SparkWrite) builder.build(); } From c66a06fca3bd9996a999968f6277422946ad2145 Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 21 May 2026 11:29:47 -0500 Subject: [PATCH 3/4] fix: harden blob batch resolution; add writer/rule tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address PR review feedback on the blob join-preservation path: - BlobReferenceResolver.resolveBatch: deduplicate row addresses within each (datasetUri, columnName) group and fan resolved bytes back out by address instead of by list position. Guard against takeBlobs count mismatch and null elements by failing loudly rather than writing wrong bytes. Document the takeBlobs ordering/1:1 contract. 
- BlobStructAccessor.getBlobReference: test blob size with the primitive UInt8Vector accessor instead of boxing through a per-row BigInteger on the scan hot path. - LargeBinaryWriter: buffer lazily — write rows directly until the first blob reference, then buffer only the tail. The common non-blob binary case now buffers nothing. Tests (written as JUnit 5 in Scala so surefire actually executes them; the existing ScalaTest *Suite classes are not picked up by the build): - LargeBinaryWriterTest: direct/buffered ordering, reference resolution at correct indices, IOException->RuntimeException, reset. - LanceBlobSourceContextRuleTest: non-Lance target no-op, blob-free source no-op, positive annotation, existing-key guard, idempotence. - BaseBlobJoinTest: one-to-many JOIN test exercising resolver dedup. Verified across Spark 3.5 and 4.1 (Scala 2.13). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../spark/utils/BlobReferenceResolver.java | 109 +++++++---- .../spark/vectorized/BlobStructAccessor.java | 6 +- .../lance/spark/arrow/LargeBinaryWriter.scala | 88 +++++---- .../org/lance/spark/BaseBlobJoinTest.java | 91 +++++++++ .../LanceBlobSourceContextRuleTest.scala | 126 ++++++++++++ .../spark/arrow/LargeBinaryWriterTest.scala | 182 ++++++++++++++++++ 6 files changed, 532 insertions(+), 70 deletions(-) create mode 100644 lance-spark-base_2.12/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRuleTest.scala create mode 100644 lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java index a9bf17493..1ccf96de8 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReferenceResolver.java @@ -30,7 +30,18 @@ * *

Datasets are cached for the lifetime of this resolver to amortize open costs across batches. * Resolution is done in true batches: all pending references are grouped by (datasetUri, - * columnName) and each group is resolved with a single {@code takeBlobs()} call. + * columnName), deduplicated by row address within each group, and each group is resolved with a + * single {@code takeBlobs()} call. + * + *

takeBlobs ordering contract. {@code Dataset.takeBlobs(addresses, column)} returns one + * {@code BlobFile} per requested address, in the same order as the input (Lance's take operation + * sorts internally for IO efficiency but remaps the result back to the requested order; when row + * addresses are projected it also errors rather than silently dropping deleted rows). We still + * deduplicate addresses before calling it — a one-to-many JOIN (this feature's primary target) + * repeats the same source row across many output rows, so reading each distinct blob once both + * avoids redundant reads and removes any reliance on how takeBlobs treats repeated addresses. A + * returned-count mismatch (e.g. a null-descriptor row that take silently drops) is treated as a + * hard error rather than risking a positional skew that would write the wrong bytes downstream. * *

Source datasets are opened through {@link Utils#openDatasetBuilder(LanceSparkReadOptions)} * using the per-source {@link BlobSourceContext} captured on the driver (keyed by dataset URI). @@ -91,8 +102,9 @@ public byte[] resolveIfNeeded(byte[] bytes) throws IOException { /** * Resolves a batch of blob references to their actual bytes, keyed by the caller-supplied vector - * indices. References are grouped by (datasetUri, columnName) and each group is resolved with a - * single {@code takeBlobs()} call. + * indices. References are grouped by (datasetUri, columnName), deduplicated by row address, and + * each group is resolved with a single {@code takeBlobs()} call (see the class javadoc for the + * ordering/dedup contract this relies on). * *

The caller is responsible for writing the resolved bytes into the target vector. Resolved * bytes are returned as a map rather than written here because back-filling a variable-width @@ -102,43 +114,54 @@ public byte[] resolveIfNeeded(byte[] bytes) throws IOException { * @param indices vector indices corresponding to each blob reference * @param refs blob references to resolve * @return a map from vector index to resolved blob bytes - * @throws IOException if reading blobs fails + * @throws IOException if reading blobs fails, or if {@code takeBlobs} returns an unexpected + * count/null that would make positional mapping unsafe */ public Map resolveBatch(List indices, List refs) throws IOException { Map resolved = new HashMap<>(refs.size()); - // Group by (datasetUri, columnName) - Map> groups = new HashMap<>(); + // Group by (datasetUri, columnName), deduplicating row addresses within each group. + Map groups = new HashMap<>(); for (int i = 0; i < refs.size(); i++) { - int vectorIndex = indices.get(i); BlobReference ref = refs.get(i); String groupKey = ref.getDatasetUri() + "\0" + ref.getColumnName(); - groups - .computeIfAbsent(groupKey, k -> new ArrayList<>()) - .add(new IndexedRef(vectorIndex, ref)); + Group group = groups.computeIfAbsent(groupKey, k -> new Group(ref)); + group.add(ref.getRowAddress(), indices.get(i)); } - // Resolve each group with a single takeBlobs() call - for (List group : groups.values()) { - BlobReference first = group.get(0).ref; - Dataset dataset = getOrOpenDataset(first.getDatasetUri()); - - List rowAddresses = new ArrayList<>(group.size()); - for (IndexedRef ir : group) { - rowAddresses.add(ir.ref.getRowAddress()); + // Resolve each group with a single takeBlobs() call over its distinct addresses, then fan the + // bytes back out to every vector index that referenced that address. 
+ for (Group group : groups.values()) { + Dataset dataset = getOrOpenDataset(group.datasetUri); + List addresses = group.distinctAddresses; // requested order + List blobs = dataset.takeBlobs(addresses, group.columnName); + + // takeBlobs must return exactly one BlobFile per requested address, in order. A mismatch + // means the selection hit deleted/null-descriptor rows, in which case positional mapping + // would skew and silently write the wrong bytes into the target table — fail loudly instead. + if (blobs.size() != addresses.size()) { + throw new IOException( + String.format( + "takeBlobs returned %d blobs for %d requested addresses (column=%s, dataset=%s); " + + "cannot map results to rows", + blobs.size(), addresses.size(), group.columnName, group.datasetUri)); } - List blobs = dataset.takeBlobs(rowAddresses, first.getColumnName()); - - for (int i = 0; i < group.size(); i++) { - IndexedRef ir = group.get(i); - if (i < blobs.size()) { - try (BlobFile blob = blobs.get(i)) { - resolved.put(ir.vectorIndex, blob.read()); - } - } else { - resolved.put(ir.vectorIndex, new byte[0]); + for (int i = 0; i < addresses.size(); i++) { + BlobFile blob = blobs.get(i); + if (blob == null) { + throw new IOException( + String.format( + "takeBlobs returned a null blob for address %d (column=%s, dataset=%s)", + addresses.get(i), group.columnName, group.datasetUri)); + } + byte[] data; + try (BlobFile b = blob) { + data = b.read(); + } + for (int vectorIndex : group.indicesByAddress.get(addresses.get(i))) { + resolved.put(vectorIndex, data); } } } @@ -176,13 +199,31 @@ public void close() { datasetCache.clear(); } - private static class IndexedRef { - final int vectorIndex; - final BlobReference ref; + /** + * A set of references sharing one (datasetUri, columnName), with row addresses deduplicated. 
+ * {@link #distinctAddresses} preserves first-seen order and is the exact list handed to {@code + * takeBlobs}; {@link #indicesByAddress} fans each address's resolved bytes back to every vector + * index that referenced it. + */ + private static class Group { + final String datasetUri; + final String columnName; + final List distinctAddresses = new ArrayList<>(); + final Map> indicesByAddress = new HashMap<>(); + + Group(BlobReference first) { + this.datasetUri = first.getDatasetUri(); + this.columnName = first.getColumnName(); + } - IndexedRef(int vectorIndex, BlobReference ref) { - this.vectorIndex = vectorIndex; - this.ref = ref; + void add(long address, int vectorIndex) { + List forAddress = indicesByAddress.get(address); + if (forAddress == null) { + forAddress = new ArrayList<>(); + indicesByAddress.put(address, forAddress); + distinctAddresses.add(address); + } + forAddress.add(vectorIndex); } } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java index 1b1286e25..6e7a14569 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java @@ -84,8 +84,10 @@ public byte[] getBlobReference(int rowId) { if (!hasBlobReferenceContext()) { return new byte[0]; } - Long size = getSize(rowId); - if (size == null || size == 0) { + // Hot path (once per scanned blob row): test size with the primitive accessor instead of + // boxing through getObjectNoOverflow(), which allocates a BigInteger per row. The unsigned + // overflow handling is irrelevant for a null/zero check — only 0L compares equal to zero. 
+ if (sizeVector.isNull(rowId) || sizeVector.get(rowId) == 0L) { // Zero-size blob — either truly empty or null encoded as (0,0) return new byte[0]; } diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala index 1b87180a0..d29c7861f 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala @@ -24,23 +24,27 @@ import org.lance.spark.utils.{BlobReference, BlobReferenceResolver} * [[BlobReference]]s rather than the actual bytes. This writer detects those and resolves them to * real blob bytes via the injected (shared, per-write-task) [[BlobReferenceResolver]]. * - * All per-row values are buffered and the vector is emitted in a single ascending pass in - * [[finish]]. This ordering is required for correctness: resolving references produces bytes for - * arbitrary, non-contiguous indices, and writing into the middle of an already-populated - * variable-width Arrow vector corrupts its offset buffer (`setBytes` reads the start offset from the - * entry being overwritten and only rewrites the next offset, shifting every following row's bytes). - * Buffering one batch's values is bounded by the batch size. + * Buffering is lazy: rows are written straight to the vector in ascending order until the first blob + * reference is seen, after which every subsequent row is buffered and the tail is emitted in a + * single ascending pass in [[finish]]. 
Buffering only the tail is required for correctness once + * references are present: resolving references produces bytes for arbitrary, non-contiguous indices, + * and writing into the middle of an already-populated variable-width Arrow vector corrupts its + * offset buffer (`setBytes` reads the start offset from the entry being overwritten and only + * rewrites the next offset, shifting every following row's bytes). The common case — a binary column + * with no shuffled references — buffers nothing and writes directly. Buffering is bounded by the + * batch size. */ private[arrow] class LargeBinaryWriter( val valueVector: LargeVarBinaryVector, injectedResolver: BlobReferenceResolver) extends LanceArrowFieldWriter { - // One buffered entry per row, in row order. Each is one of: + // Buffered tail entries, in row order, starting at absolute row index `bufferStart`. Each is: // null -> SQL NULL (validity bit left unset) // Array[Byte] -> literal binary (possibly empty) // BlobReference -> a reference to resolve to actual blob bytes + // `bufferStart` is the absolute index of the first buffered row, or -1 while still writing direct. private val entries = new java.util.ArrayList[AnyRef]() - private var hasRefs = false + private var bufferStart = -1 // Only created when no resolver is injected (e.g. non-shuffle build paths). Owned and closed here. private var localResolver: BlobReferenceResolver = _ @@ -56,42 +60,58 @@ private[arrow] class LargeBinaryWriter( } } - override def setNull(): Unit = entries.add(null) + // `count` (from the base class) is the absolute index of the row currently being written. + override def setNull(): Unit = { + if (bufferStart >= 0) { + entries.add(null) + } + // Direct mode: leave the validity bit unset; setValueCount fills the offset hole on finish. 
+ } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val bytes = input.getBinary(ordinal) if (bytes != null && BlobReference.isBlobReference(bytes)) { + // First reference flips us into buffering mode, starting at this row. + if (bufferStart < 0) { + bufferStart = count + } entries.add(BlobReference.deserialize(bytes)) - hasRefs = true - } else { + } else if (bufferStart >= 0) { + // Buffering mode: defer literals (and null bytes) so the whole tail emits in one pass. entries.add(bytes) + } else if (bytes != null) { + // Direct mode: write literals straight through in ascending order, no buffering needed. + valueVector.setSafe(count, bytes) } + // Direct mode + null bytes: leave the validity bit unset (SQL null), same as setNull. } override def finish(): Unit = { try { - val resolved: java.util.Map[Integer, Array[Byte]] = - if (hasRefs) resolveReferences() else java.util.Collections.emptyMap() + if (bufferStart >= 0) { + val resolved: java.util.Map[Integer, Array[Byte]] = resolveReferences() - // Single ascending pass over the batch: write literals and resolved references in order. - var i = 0 - while (i < entries.size()) { - entries.get(i) match { - case null => // SQL NULL: leave the validity bit unset - case _: BlobReference => - val data = resolved.get(i) - valueVector.setSafe(i, if (data != null) data else Array.emptyByteArray) - case bytes: Array[Byte] => - valueVector.setSafe(i, bytes) - case other => - throw new IllegalStateException(s"Unexpected buffered binary entry: $other") + // Single ascending pass over the buffered tail (rows bufferStart .. count-1). 
+ var j = 0 + while (j < entries.size()) { + val rowId = bufferStart + j + entries.get(j) match { + case null => // SQL NULL: leave the validity bit unset + case _: BlobReference => + val data = resolved.get(rowId) + valueVector.setSafe(rowId, if (data != null) data else Array.emptyByteArray) + case bytes: Array[Byte] => + valueVector.setSafe(rowId, bytes) + case other => + throw new IllegalStateException(s"Unexpected buffered binary entry: $other") + } + j += 1 } - i += 1 } super.finish() } finally { entries.clear() - hasRefs = false + bufferStart = -1 if (localResolver != null) { localResolver.close() localResolver = null @@ -99,19 +119,19 @@ private[arrow] class LargeBinaryWriter( } } - /** Collects the buffered references and resolves them to bytes keyed by their row index. */ + /** Collects the buffered references and resolves them to bytes keyed by their absolute row index. */ private def resolveReferences(): java.util.Map[Integer, Array[Byte]] = { val indices = new java.util.ArrayList[Integer]() val refs = new java.util.ArrayList[BlobReference]() - var i = 0 - while (i < entries.size()) { - entries.get(i) match { + var j = 0 + while (j < entries.size()) { + entries.get(j) match { case ref: BlobReference => - indices.add(i) + indices.add(bufferStart + j) refs.add(ref) case _ => } - i += 1 + j += 1 } try { resolver.resolveBatch(indices, refs) @@ -124,6 +144,6 @@ private[arrow] class LargeBinaryWriter( override def reset(): Unit = { super.reset() entries.clear() - hasRefs = false + bufferStart = -1 } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java index 98539778d..c4b61a8cb 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/BaseBlobJoinTest.java @@ -313,6 +313,97 @@ public void testBlobPreservedDuringJoinAndInsert() throws Exception { spark.sql("DROP 
TABLE IF EXISTS " + fqTarget); } + /** + * Verifies that a one-to-many JOIN — where a single source blob row fans out to several output + * rows — preserves the blob content for every output row. This exercises the resolver's + * deduplication path: the same source row address is referenced by multiple shuffled rows and + * must resolve to identical bytes in each, rather than skewing positionally. + */ + @Test + public void testBlobPreservedDuringOneToManyJoin() throws Exception { + String blobTable = "blob_one_many_a_" + System.currentTimeMillis(); + String tagTable = "blob_one_many_b_" + System.currentTimeMillis(); + String targetTable = "blob_one_many_target_" + System.currentTimeMillis(); + String fqBlob = catalogName + ".default." + blobTable; + String fqTag = catalogName + ".default." + tagTable; + String fqTarget = catalogName + ".default." + targetTable; + + // Source blob table: one blob per id. + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqBlob + + " (id INT NOT NULL, blob_a BINARY) USING lance " + + "TBLPROPERTIES ('blob_a.lance.encoding' = 'blob')"); + + byte[] blob1 = "blob-for-id-1".getBytes(StandardCharsets.UTF_8); + byte[] blob2 = "blob-for-id-2".getBytes(StandardCharsets.UTF_8); + List blobRows = new ArrayList<>(); + blobRows.add(RowFactory.create(1, blob1)); + blobRows.add(RowFactory.create(2, blob2)); + StructType blobSchema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("blob_a", DataTypes.BinaryType, true) + }); + spark.createDataFrame(blobRows, blobSchema).coalesce(1).writeTo(fqBlob).append(); + + // Tag table with multiple rows per id, so the join fans each blob row out to several outputs. 
+ spark.sql("CREATE TABLE IF NOT EXISTS " + fqTag + " (id INT NOT NULL, tag STRING) USING lance"); + List tagRows = new ArrayList<>(); + tagRows.add(RowFactory.create(1, "a")); + tagRows.add(RowFactory.create(1, "b")); + tagRows.add(RowFactory.create(1, "c")); + tagRows.add(RowFactory.create(2, "d")); + StructType tagSchema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("tag", DataTypes.StringType, true) + }); + spark.createDataFrame(tagRows, tagSchema).coalesce(1).writeTo(fqTag).append(); + + // Target carries the (duplicated) blob plus the tag that made it duplicate. + spark.sql( + "CREATE TABLE IF NOT EXISTS " + + fqTarget + + " (id INT NOT NULL, blob_a BINARY, tag STRING) USING lance " + + "TBLPROPERTIES ('blob_a.lance.encoding' = 'blob')"); + + spark.sql( + "INSERT INTO " + + fqTarget + + " SELECT a.id, a.blob_a, b.tag FROM " + + fqBlob + + " a JOIN " + + fqTag + + " b ON a.id = b.id"); + + Dataset result = + spark.sql("SELECT id, blob_a, tag FROM " + fqTarget + " ORDER BY id, tag"); + List rows = result.collectAsList(); + assertEquals(4, rows.size(), "one-to-many join should produce 4 rows"); + + // id=1 fans out to 3 rows (tags a, b, c), id=2 to 1 row (tag d); blob content must match the + // source blob for the row's id in every case. + try (BlobReferenceResolver resolver = new BlobReferenceResolver()) { + for (Row row : rows) { + int id = row.getInt(0); + byte[] expected = id == 1 ? 
blob1 : blob2; + byte[] resolved = resolver.resolveIfNeeded((byte[]) row.get(1)); + assertArrayEquals( + expected, resolved, "id=" + id + " tag=" + row.getString(2) + " blob mismatch"); + } + } + + String[] tags = rows.stream().map(r -> r.getString(2)).toArray(String[]::new); + assertArrayEquals(new String[] {"a", "b", "c", "d"}, tags, "tags should be preserved"); + + spark.sql("DROP TABLE IF EXISTS " + fqBlob); + spark.sql("DROP TABLE IF EXISTS " + fqTag); + spark.sql("DROP TABLE IF EXISTS " + fqTarget); + } + /** * Verifies that non-blob columns are preserved correctly during JOIN + INSERT when blob columns * are also present. diff --git a/lance-spark-base_2.12/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRuleTest.scala b/lance-spark-base_2.12/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRuleTest.scala new file mode 100644 index 000000000..03deb7b60 --- /dev/null +++ b/lance-spark-base_2.12/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LanceBlobSourceContextRuleTest.scala @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical.AppendData +import org.apache.spark.sql.connector.catalog.{Table, TableCapability} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{BinaryType, DoubleType, IntegerType, MetadataBuilder, StructType} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.lance.spark.{LanceConstant, LanceDataset, LanceSparkReadOptions} +import org.lance.spark.utils.BlobUtils + +/** + * Unit tests for [[LanceBlobSourceContextRule]]'s application logic, exercised directly on logical + * plans without a SparkSession. Covers the no-op edge cases (non-Lance target, blob-free source, + * already-annotated) alongside the positive annotation path. + * + * Written with JUnit 5 (not ScalaTest) so surefire actually executes them. + */ +class LanceBlobSourceContextRuleTest { + + private val key = LanceConstant.BLOB_SOURCE_CONTEXTS_KEY + private val rule = LanceBlobSourceContextRule() + + private val blobMetadata = new MetadataBuilder() + .putString(BlobUtils.LANCE_ENCODING_BLOB_KEY, BlobUtils.LANCE_ENCODING_BLOB_VALUE) + .build() + + private def blobSchema: StructType = new StructType() + .add("id", IntegerType, nullable = false) + .add("data", BinaryType, nullable = true, blobMetadata) + + private def plainSchema: StructType = new StructType() + .add("id", IntegerType, nullable = false) + .add("score", DoubleType, nullable = true) + + private def lanceTable(uri: String, schema: StructType): LanceDataset = + new LanceDataset( + LanceSparkReadOptions.from(uri), + schema, + java.util.Collections.emptyMap[String, String](), + null, // namespaceImpl + java.util.Collections.emptyMap[String, String](), // namespaceProperties + false, // managedVersioning + null + ) // fileFormatVersion + + private def relation(table: Table): DataSourceV2Relation = + DataSourceV2Relation.create(table, None, None) + + 
/** A minimal non-Lance table so the rule's isLanceTarget guard sees a foreign target. */ + private class NonLanceTable(tableName: String, tableSchema: StructType) extends Table { + override def name(): String = tableName + override def schema(): StructType = tableSchema + override def capabilities(): java.util.Set[TableCapability] = + java.util.Collections.emptySet[TableCapability]() + } + + private def append( + target: Table, + source: Table, + writeOptions: Map[String, String] = Map.empty): AppendData = + AppendData.byName(relation(target), relation(source), writeOptions) + + @Test + def doesNotAnnotateNonLanceTargets(): Unit = { + val plan = AppendData.byName( + relation(new NonLanceTable("foreign", plainSchema)), + relation(lanceTable("file:///src.lance", blobSchema)), + Map.empty[String, String]) + val result = rule(plan).asInstanceOf[AppendData] + assertFalse(result.writeOptions.contains(key)) + } + + @Test + def doesNotAnnotateWhenSourceHasNoBlobColumns(): Unit = { + val plan = append( + lanceTable("file:///target.lance", blobSchema), + lanceTable("file:///src.lance", plainSchema)) + val result = rule(plan).asInstanceOf[AppendData] + assertFalse(result.writeOptions.contains(key)) + } + + @Test + def annotatesLanceTargetWithBlobSource(): Unit = { + val plan = append( + lanceTable("file:///target.lance", blobSchema), + lanceTable("file:///src.lance", blobSchema)) + val result = rule(plan).asInstanceOf[AppendData] + assertTrue(result.writeOptions.contains(key)) + assertTrue(result.writeOptions(key).nonEmpty) + } + + @Test + def doesNotOverwriteExistingContextsKey(): Unit = { + val plan = append( + lanceTable("file:///target.lance", blobSchema), + lanceTable("file:///src.lance", blobSchema), + Map(key -> "preexisting")) + val result = rule(plan).asInstanceOf[AppendData] + assertEquals("preexisting", result.writeOptions(key)) + } + + @Test + def isIdempotentAcrossRepeatedApplications(): Unit = { + val plan = append( + lanceTable("file:///target.lance", 
blobSchema), + lanceTable("file:///src.lance", blobSchema)) + val once = rule(plan).asInstanceOf[AppendData] + val twice = rule(once).asInstanceOf[AppendData] + assertTrue(once.writeOptions.contains(key)) + assertEquals(once, twice) + } +} diff --git a/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala b/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala new file mode 100644 index 000000000..6f3ec455b --- /dev/null +++ b/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.arrow + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.LargeVarBinaryVector +import org.apache.spark.sql.catalyst.InternalRow +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.lance.spark.utils.{BlobReference, BlobReferenceResolver} + +import java.nio.charset.StandardCharsets.UTF_8 + +/** + * Regression tests for [[LargeBinaryWriter]]'s buffering and reference-resolution logic. These guard + * the offset-buffer ordering contract (writes must land in ascending order) and the boundary between + * direct writes and the buffered tail that begins at the first blob reference. + * + * Written with JUnit 5 (not ScalaTest) so surefire actually executes them. 
+ */ +class LargeBinaryWriterTest { + + /** A resolver that records what it was asked to resolve and returns deterministic bytes. */ + private class RecordingResolver extends BlobReferenceResolver { + var capturedIndices: java.util.List[Integer] = _ + var capturedRefs: java.util.List[BlobReference] = _ + + override def resolveBatch( + indices: java.util.List[Integer], + refs: java.util.List[BlobReference]): java.util.Map[Integer, Array[Byte]] = { + capturedIndices = indices + capturedRefs = refs + val out = new java.util.HashMap[Integer, Array[Byte]]() + var i = 0 + while (i < indices.size()) { + out.put(indices.get(i), resolvedBytes(refs.get(i).getRowAddress)) + i += 1 + } + out + } + } + + private def resolvedBytes(addr: Long): Array[Byte] = s"resolved-$addr".getBytes(UTF_8) + + private def withVector(body: LargeVarBinaryVector => Unit): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val vector = new LargeVarBinaryVector("blob", allocator) + vector.allocateNew() + try body(vector) + finally { + vector.close() + allocator.close() + } + } + + /** Builds a single-column binary row; `null` produces a SQL NULL. 
*/ + private def row(bytes: Array[Byte]): InternalRow = InternalRow(bytes) + + private def ref(addr: Long): Array[Byte] = + new BlobReference("file:///src.lance", "blob", addr).serialize() + + private def assertBytes(vector: LargeVarBinaryVector, index: Int, expected: Array[Byte]): Unit = { + assertFalse(vector.isNull(index), s"row $index should not be null") + assertArrayEquals(expected, vector.get(index), s"row $index bytes mismatch") + } + + @Test + def directWritesAscendingWhenNoReferences(): Unit = { + withVector { vector => + val writer = new LargeBinaryWriter(vector, null) + writer.write(row("a".getBytes(UTF_8)), 0) + writer.write(row(null), 0) + writer.write(row("bb".getBytes(UTF_8)), 0) + writer.write(row(Array.emptyByteArray), 0) + writer.write(row(null), 0) + writer.finish() + + assertEquals(5, vector.getValueCount) + assertBytes(vector, 0, "a".getBytes(UTF_8)) + assertTrue(vector.isNull(1)) + assertBytes(vector, 2, "bb".getBytes(UTF_8)) + assertBytes(vector, 3, Array.emptyByteArray) // empty is distinct from null + assertTrue(vector.isNull(4)) + } + } + + @Test + def referencesBufferOnlyTheTailAndResolveAtRightIndices(): Unit = { + withVector { vector => + val resolver = new RecordingResolver + val writer = new LargeBinaryWriter(vector, resolver) + writer.write(row("x".getBytes(UTF_8)), 0) // direct + writer.write(row(ref(10)), 0) // first reference -> buffering starts at index 1 + writer.write(row("y".getBytes(UTF_8)), 0) // buffered literal + writer.write(row(ref(20)), 0) // buffered reference + writer.write(row(null), 0) // buffered null + writer.finish() + + // References were collected in ascending order, with their absolute row indices, before + // resolution — and only the references (not the interleaved literal/null) were handed over. 
+ assertEquals(java.util.Arrays.asList[Integer](1, 3), resolver.capturedIndices) + assertEquals(2, resolver.capturedRefs.size()) + assertEquals(10L, resolver.capturedRefs.get(0).getRowAddress) + assertEquals(20L, resolver.capturedRefs.get(1).getRowAddress) + + assertEquals(5, vector.getValueCount) + assertBytes(vector, 0, "x".getBytes(UTF_8)) + assertBytes(vector, 1, resolvedBytes(10)) + assertBytes(vector, 2, "y".getBytes(UTF_8)) + assertBytes(vector, 3, resolvedBytes(20)) + assertTrue(vector.isNull(4)) + } + } + + @Test + def referenceInFirstRowBuffersWholeBatch(): Unit = { + withVector { vector => + val resolver = new RecordingResolver + val writer = new LargeBinaryWriter(vector, resolver) + writer.write(row(ref(5)), 0) + writer.write(row("z".getBytes(UTF_8)), 0) + writer.finish() + + assertEquals(java.util.Arrays.asList[Integer](0), resolver.capturedIndices) + assertBytes(vector, 0, resolvedBytes(5)) + assertBytes(vector, 1, "z".getBytes(UTF_8)) + } + } + + @Test + def ioExceptionDuringResolutionIsPropagatedAsRuntimeException(): Unit = { + withVector { vector => + val resolver = new BlobReferenceResolver { + override def resolveBatch( + indices: java.util.List[Integer], + refs: java.util.List[BlobReference]): java.util.Map[Integer, Array[Byte]] = + throw new java.io.IOException("boom") + } + val writer = new LargeBinaryWriter(vector, resolver) + writer.write(row(ref(1)), 0) + + val ex = assertThrows(classOf[RuntimeException], () => writer.finish()) + assertTrue(ex.getMessage.contains("Failed to resolve blob references")) + assertTrue(ex.getCause.isInstanceOf[java.io.IOException]) + } + } + + @Test + def resetClearsBufferingStateForNextBatch(): Unit = { + withVector { vector => + val resolver = new RecordingResolver + val writer = new LargeBinaryWriter(vector, resolver) + // First batch flips into buffering mode via a reference. 
+ writer.write(row(ref(7)), 0) + writer.finish() + assertBytes(vector, 0, resolvedBytes(7)) + + // After reset, a reference-free batch must write directly (no stale buffer offset). + writer.reset() + resolver.capturedIndices = null + writer.write(row("p".getBytes(UTF_8)), 0) + writer.write(row("qq".getBytes(UTF_8)), 0) + writer.finish() + + assertNull(resolver.capturedIndices, "resolver must not be invoked without references") + assertEquals(2, vector.getValueCount) + assertBytes(vector, 0, "p".getBytes(UTF_8)) + assertBytes(vector, 1, "qq".getBytes(UTF_8)) + } + } +} From a28e26aea1aa88c7901c437c2dbfda6975644cbc Mon Sep 17 00:00:00 2001 From: Daniel Rammer Date: Thu, 21 May 2026 11:50:24 -0500 Subject: [PATCH 4/4] fix: bound blob batch by resolved size to restore maxBatchBytes guard Blob references flow through the shuffle as ~200-byte placeholders, so the per-batch byte guard sized the references, not the blobs they resolve to. The batch capped at 8192 rows; at finish() those references resolved to full blob bytes all at once (resolved map + vector copy), reliably OOMing an executor at the feature's target scale with maxBatchBytes providing no protection. Carry the resolved blob size (already known from the source size vector at read time) in BlobReference, and feed an exact buffered-bytes estimate into both the semaphore and queued write buffers' byte budgets. A batch now flushes before its resolved blobs exceed maxBatchBytes, bounding the Arrow vector; peak transient at resolution stays ~2x maxBatchBytes. 
- BlobReference: wire format v2 appends 8-byte size - BlobStructAccessor: stamp real size onto each emitted reference - LargeBinaryWriter: track pendingBytes, expose estimatedBufferedBytes - LanceArrowWriter + field-writer base + container writers: propagate estimate - Semaphore/Queued buffers: add estimate to per-batch byte total Tests: size round-trip, estimatedBufferedBytes accounting/reset, and an end-to-end buffer test proving references trip the byte budget. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../org/lance/spark/utils/BlobReference.java | 76 ++++++++---- .../spark/vectorized/BlobStructAccessor.java | 12 +- .../write/QueuedArrowBatchWriteBuffer.java | 6 +- .../write/SemaphoreArrowBatchWriteBuffer.java | 7 +- .../spark/arrow/LanceArrowFieldWriter.scala | 10 ++ .../lance/spark/arrow/LanceArrowWriter.scala | 29 +++++ .../lance/spark/arrow/LargeBinaryWriter.scala | 25 +++- .../lance/spark/utils/BlobReferenceTest.java | 30 +++++ .../SemaphoreArrowBatchWriteBufferTest.java | 112 ++++++++++++++++++ .../spark/arrow/LargeBinaryWriterTest.scala | 28 +++++ 10 files changed, 303 insertions(+), 32 deletions(-) diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java index 5ebb470cd..445f5b488 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/utils/BlobReference.java @@ -30,6 +30,12 @@ * the source dataset, fetches the actual blob bytes via {@code Dataset.takeBlobs()}, and writes * them to the target table. * + *

The trailing {@code size} is the resolved blob's byte length, read straight from the source + * blob descriptor's size vector. It carries no addressing information — resolution only needs the + * dataset/column/rowAddress — but it lets the write side budget the (potentially huge) resolved + * bytes against {@code maxBatchBytes} while the value is still just a ~200-byte reference, so a + * batch flushes before materialization OOMs the executor (see {@code LargeBinaryWriter}). + * *

<p>Wire format: * * <pre>
@@ -38,6 +44,7 @@
  *   [2+N bytes] datasetUri (length-prefixed UTF-8)
  *   [2+N bytes] columnName (length-prefixed UTF-8)
  *   [8 bytes] rowAddress
+ *   [8 bytes] size (resolved blob byte length)
  * </pre>
*/ public class BlobReference { @@ -45,19 +52,26 @@ public class BlobReference { /** 8-byte magic header to identify a serialized BlobReference. */ public static final byte[] MAGIC = {'L', 'A', 'N', 'C', 'E', 'R', 'E', 'F'}; - /** Min byte length: magic(8) + version(1) + two empty strings(2+2) + rowAddress(8). */ - private static final int MIN_SIZE = MAGIC.length + 1 + 2 + 2 + 8; + /** Min byte length: magic(8) + version(1) + two empty strings(2+2) + rowAddress(8) + size(8). */ + private static final int MIN_SIZE = MAGIC.length + 1 + 2 + 2 + 8 + 8; - private static final byte VERSION = 1; + private static final byte VERSION = 2; private final String datasetUri; private final String columnName; private final long rowAddress; + private final long size; + /** Constructs a reference with an unknown resolved size ({@code 0}); see {@link #getSize()}. */ public BlobReference(String datasetUri, String columnName, long rowAddress) { + this(datasetUri, columnName, rowAddress, 0L); + } + + public BlobReference(String datasetUri, String columnName, long rowAddress, long size) { this.datasetUri = datasetUri; this.columnName = columnName; this.rowAddress = rowAddress; + this.size = size; } /** @@ -90,7 +104,7 @@ public static boolean isBlobReference(byte[] bytes) { if (colLen < 0 || colLen > remaining) { return false; } - int expectedRemaining = colLen + 8; + int expectedRemaining = colLen + 8 + 8; // columnName + rowAddress + size return remaining == expectedRemaining; } catch (IOException e) { return false; @@ -99,16 +113,17 @@ public static boolean isBlobReference(byte[] bytes) { /** Serialize this reference to a compact byte array. */ public byte[] serialize() { - return appendRowAddress(serializePrefix(datasetUri, columnName), rowAddress); + return appendRowAddressAndSize(serializePrefix(datasetUri, columnName), rowAddress, size); } /** - * Serializes the constant portion of a reference: everything except the trailing 8-byte - * rowAddress (i.e. 
magic + version + datasetUri + columnName). + * Serializes the constant portion of a reference: everything except the trailing per-row + * rowAddress and size (i.e. magic + version + datasetUri + columnName). * *

{@code datasetUri} and {@code columnName} are constant for an entire scan batch, so callers * on the per-row read hot path should compute this prefix once and then call {@link - * #appendRowAddress(byte[], long)} per row instead of re-encoding the strings every time. + * #appendRowAddressAndSize(byte[], long, long)} per row instead of re-encoding the strings every + * time. */ public static byte[] serializePrefix(String datasetUri, String columnName) { try { @@ -127,22 +142,27 @@ public static byte[] serializePrefix(String datasetUri, String columnName) { /** * Returns a full serialized reference: {@code prefix} (from {@link #serializePrefix}) followed by - * the 8-byte big-endian {@code rowAddress}, matching {@link DataOutputStream#writeLong}. + * the 8-byte big-endian {@code rowAddress} and 8-byte big-endian {@code size}, each matching + * {@link DataOutputStream#writeLong}. */ - public static byte[] appendRowAddress(byte[] prefix, long rowAddress) { - byte[] out = Arrays.copyOf(prefix, prefix.length + 8); - int off = prefix.length; - out[off] = (byte) (rowAddress >>> 56); - out[off + 1] = (byte) (rowAddress >>> 48); - out[off + 2] = (byte) (rowAddress >>> 40); - out[off + 3] = (byte) (rowAddress >>> 32); - out[off + 4] = (byte) (rowAddress >>> 24); - out[off + 5] = (byte) (rowAddress >>> 16); - out[off + 6] = (byte) (rowAddress >>> 8); - out[off + 7] = (byte) rowAddress; + public static byte[] appendRowAddressAndSize(byte[] prefix, long rowAddress, long size) { + byte[] out = Arrays.copyOf(prefix, prefix.length + 16); + writeLongBE(out, prefix.length, rowAddress); + writeLongBE(out, prefix.length + 8, size); return out; } + private static void writeLongBE(byte[] out, int off, long value) { + out[off] = (byte) (value >>> 56); + out[off + 1] = (byte) (value >>> 48); + out[off + 2] = (byte) (value >>> 40); + out[off + 3] = (byte) (value >>> 32); + out[off + 4] = (byte) (value >>> 24); + out[off + 5] = (byte) (value >>> 16); + out[off + 6] = (byte) (value >>> 8); + 
out[off + 7] = (byte) value; + } + /** Deserialize a BlobReference from bytes. */ public static BlobReference deserialize(byte[] bytes) { if (!isBlobReference(bytes)) { @@ -155,7 +175,8 @@ public static BlobReference deserialize(byte[] bytes) { String datasetUri = readString(in); String columnName = readString(in); long rowAddress = in.readLong(); - return new BlobReference(datasetUri, columnName, rowAddress); + long size = in.readLong(); + return new BlobReference(datasetUri, columnName, rowAddress, size); } catch (IOException e) { throw new RuntimeException("Failed to deserialize BlobReference", e); } @@ -188,10 +209,19 @@ public long getRowAddress() { return rowAddress; } + /** + * The resolved blob's byte length, captured from the source size vector at read time. Used by the + * write side to budget resolved bytes against {@code maxBatchBytes}; {@code 0} when unknown (e.g. + * a reference constructed without a size). + */ + public long getSize() { + return size; + } + @Override public String toString() { return String.format( - "BlobReference{dataset=%s, column=%s, rowAddr=0x%016X}", - datasetUri, columnName, rowAddress); + "BlobReference{dataset=%s, column=%s, rowAddr=0x%016X, size=%d}", + datasetUri, columnName, rowAddress, size); } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java index 6e7a14569..c97027c72 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/vectorized/BlobStructAccessor.java @@ -84,14 +84,18 @@ public byte[] getBlobReference(int rowId) { if (!hasBlobReferenceContext()) { return new byte[0]; } - // Hot path (once per scanned blob row): test size with the primitive accessor instead of + // Hot path (once per scanned blob row): read size with the primitive accessor instead of // boxing through 
getObjectNoOverflow(), which allocates a BigInteger per row. The unsigned - // overflow handling is irrelevant for a null/zero check — only 0L compares equal to zero. - if (sizeVector.isNull(rowId) || sizeVector.get(rowId) == 0L) { + // overflow handling is irrelevant here — a real blob never approaches 2^63 bytes, and only 0L + // compares equal to zero for the null/empty check below. + long blobSize = sizeVector.isNull(rowId) ? 0L : sizeVector.get(rowId); + if (blobSize == 0L) { // Zero-size blob — either truly empty or null encoded as (0,0) return new byte[0]; } - return BlobReference.appendRowAddress(referencePrefix, rowAddresses[rowId]); + // Carry the blob size so the write side can budget resolved bytes against maxBatchBytes before + // the reference is materialized into actual blob bytes. + return BlobReference.appendRowAddressAndSize(referencePrefix, rowAddresses[rowId], blobSize); } public InternalRow getStruct(int rowId) { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java index 7e8982abb..de6c4a769 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java @@ -245,7 +245,11 @@ private boolean isBatchFullByBytes() { if (maxBatchBytes == Long.MAX_VALUE) { return false; } - return currentBatchAllocator.getAllocatedMemory() >= maxBatchBytes; + // Include bytes buffered outside the vector (unresolved blob references resolve to far larger + // bytes on finish); sizing only the ~200-byte references would let the batch grow until + // resolution OOMs the executor. 
+ return currentBatchAllocator.getAllocatedMemory() + currentArrowWriter.estimatedBufferedBytes() + >= maxBatchBytes; } /** diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java index f641b531a..7b7e6f0d4 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java @@ -186,7 +186,12 @@ public void write(InternalRow row) { } arrowWriter.write(row); - currentBatchBytes = this.allocator.getAllocatedMemory() - batchStartBytes; + // Add bytes buffered outside the vectors (unresolved blob references resolve to far larger + // bytes on finish); without this the guard would size the ~200-byte references and let the + // batch grow until resolution OOMs the executor. + currentBatchBytes = + (this.allocator.getAllocatedMemory() - batchStartBytes) + + arrowWriter.estimatedBufferedBytes(); count++; if (isBatchFull()) { diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowFieldWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowFieldWriter.scala index 61f019fe6..7eb528c1e 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowFieldWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowFieldWriter.scala @@ -40,6 +40,16 @@ abstract private[arrow] class LanceArrowFieldWriter { private[arrow] var count: Int = 0 + /** + * Estimated bytes this writer is holding outside its Arrow vector that will materialize into the + * vector on [[finish]]. Almost always 0 — writers put values straight into the vector, so the + * allocator already accounts for them. 
The exception is [[LargeBinaryWriter]], which buffers + * unresolved blob references (~200 bytes each on the shuffle path) whose resolved bytes can be + * orders of magnitude larger; reporting that pending size lets the write buffer flush against + * `maxBatchBytes` before resolution materializes the blobs and OOMs the executor. + */ + def estimatedBufferedBytes: Long = 0L + def write(input: SpecializedGetters, ordinal: Int): Unit = { if (input.isNullAt(ordinal)) { setNull() diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala index 79900ced4..f1ffb915b 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LanceArrowWriter.scala @@ -171,6 +171,21 @@ class LanceArrowWriter(root: VectorSchemaRoot, fields: Array[LanceArrowFieldWrit root.setRowCount(0) } + /** + * Bytes buffered outside the Arrow vectors that will materialize on [[finish]] (resolved blob + * references; see [[LargeBinaryWriter]]). Write buffers add this to the allocator's measured size + * so the per-batch byte guard accounts for blob bytes that are still cheap references. 
+ */ + def estimatedBufferedBytes: Long = { + var sum = 0L + var i = 0 + while (i < fields.length) { + sum += fields(i).estimatedBufferedBytes + i += 1 + } + sum + } + def field(index: Int): LanceArrowFieldWriter = fields(index) } @@ -209,6 +224,8 @@ private[arrow] class FixedSizeListWriter( elementWriter.finish() } + override def estimatedBufferedBytes: Long = elementWriter.estimatedBufferedBytes + override def reset(): Unit = { super.reset() elementWriter.reset() @@ -391,6 +408,7 @@ private[arrow] class ArrayWriter( super.finish() elementWriter.finish() } + override def estimatedBufferedBytes: Long = elementWriter.estimatedBufferedBytes override def reset(): Unit = { super.reset() elementWriter.reset() @@ -428,6 +446,8 @@ private[arrow] class MapWriter( keyWriter.finish() valueWriter.finish() } + override def estimatedBufferedBytes: Long = + keyWriter.estimatedBufferedBytes + valueWriter.estimatedBufferedBytes override def reset(): Unit = { super.reset() structVector.reset() @@ -468,6 +488,15 @@ private[arrow] class StructWriter( super.finish() children.foreach(_.finish()) } + override def estimatedBufferedBytes: Long = { + var sum = 0L + var i = 0 + while (i < children.length) { + sum += children(i).estimatedBufferedBytes + i += 1 + } + sum + } override def reset(): Unit = { super.reset() children.foreach(_.reset()) diff --git a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala index d29c7861f..9b30d3999 100644 --- a/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/org/lance/spark/arrow/LargeBinaryWriter.scala @@ -31,8 +31,12 @@ import org.lance.spark.utils.{BlobReference, BlobReferenceResolver} * and writing into the middle of an already-populated variable-width Arrow vector corrupts its * offset buffer (`setBytes` reads the start offset from the entry being overwritten 
and only * rewrites the next offset, shifting every following row's bytes). The common case — a binary column - * with no shuffled references — buffers nothing and writes directly. Buffering is bounded by the - * batch size. + * with no shuffled references — buffers nothing and writes directly. + * + * Resolved blobs can be orders of magnitude larger than the buffered references, so the buffered + * tail must stay bounded by bytes, not just row count: [[estimatedBufferedBytes]] reports the + * resolved size carried by each reference, which the write buffer adds to its per-batch byte budget + * so the batch flushes (and this tail resolves) before materialization exceeds `maxBatchBytes`. */ private[arrow] class LargeBinaryWriter( val valueVector: LargeVarBinaryVector, @@ -46,6 +50,13 @@ private[arrow] class LargeBinaryWriter( private val entries = new java.util.ArrayList[AnyRef]() private var bufferStart = -1 + // Sum of the bytes the buffered tail will occupy in the vector once emitted in `finish`: each + // reference's resolved blob size (carried in the reference) plus each buffered literal's length. + // The write buffer reads this via `estimatedBufferedBytes` to budget against maxBatchBytes — the + // buffered references are tiny (~200 bytes) but resolve to potentially huge blobs, so without this + // the byte guard sizes the references and the batch can balloon to tens of GB at resolution time. + private var pendingBytes = 0L + // Only created when no resolver is injected (e.g. non-shuffle build paths). Owned and closed here. private var localResolver: BlobReferenceResolver = _ @@ -75,10 +86,13 @@ private[arrow] class LargeBinaryWriter( if (bufferStart < 0) { bufferStart = count } - entries.add(BlobReference.deserialize(bytes)) + val ref = BlobReference.deserialize(bytes) + entries.add(ref) + pendingBytes += ref.getSize } else if (bufferStart >= 0) { // Buffering mode: defer literals (and null bytes) so the whole tail emits in one pass. 
entries.add(bytes) + if (bytes != null) pendingBytes += bytes.length } else if (bytes != null) { // Direct mode: write literals straight through in ascending order, no buffering needed. valueVector.setSafe(count, bytes) @@ -86,6 +100,9 @@ private[arrow] class LargeBinaryWriter( // Direct mode + null bytes: leave the validity bit unset (SQL null), same as setNull. } + // Resolved blob bytes (plus deferred literals) the buffered tail will add to the vector on finish. + override def estimatedBufferedBytes: Long = pendingBytes + override def finish(): Unit = { try { if (bufferStart >= 0) { @@ -112,6 +129,7 @@ private[arrow] class LargeBinaryWriter( } finally { entries.clear() bufferStart = -1 + pendingBytes = 0L if (localResolver != null) { localResolver.close() localResolver = null @@ -145,5 +163,6 @@ private[arrow] class LargeBinaryWriter( super.reset() entries.clear() bufferStart = -1 + pendingBytes = 0L } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java index 396333bf2..f0e652afd 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/utils/BlobReferenceTest.java @@ -32,6 +32,36 @@ public void testRoundTripSerialization() { assertEquals(original.getRowAddress(), deserialized.getRowAddress()); } + @Test + public void testRoundTripPreservesSize() { + BlobReference original = + new BlobReference("/tmp/my-dataset", "image_col", 0x0003_0000_0042L, 5L * 1024 * 1024); + + byte[] serialized = original.serialize(); + assertTrue(BlobReference.isBlobReference(serialized)); + + BlobReference deserialized = BlobReference.deserialize(serialized); + assertEquals(original.getRowAddress(), deserialized.getRowAddress()); + assertEquals(5L * 1024 * 1024, deserialized.getSize()); + } + + @Test + public void testAppendRowAddressAndSizeMatchesSerialize() { + byte[] 
prefix = BlobReference.serializePrefix("/tmp/ds", "col"); + byte[] perRow = BlobReference.appendRowAddressAndSize(prefix, 99L, 4096L); + + BlobReference deserialized = BlobReference.deserialize(perRow); + assertEquals("/tmp/ds", deserialized.getDatasetUri()); + assertEquals("col", deserialized.getColumnName()); + assertEquals(99L, deserialized.getRowAddress()); + assertEquals(4096L, deserialized.getSize()); + } + + @Test + public void testUnsizedConstructorDefaultsToZero() { + assertEquals(0L, new BlobReference("uri", "col", 1L).getSize()); + } + @Test public void testRoundTripWithUnicodeUri() { BlobReference original = new BlobReference("s3://bucket/path/日本語", "データ", 123456789L); diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java index 2d5d67076..e273cba21 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java @@ -13,6 +13,9 @@ */ package org.lance.spark.write; +import org.lance.spark.utils.BlobReference; +import org.lance.spark.utils.BlobReferenceResolver; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; @@ -31,6 +34,9 @@ import org.junit.jupiter.api.Test; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; @@ -288,6 +294,112 @@ private UTF8String generateLargeString(int sizeBytes) { return UTF8String.fromBytes(data); } + private Schema createBlobSchema() { + Field field = + new Field( + "blob", + FieldType.nullable( + 
org.apache.arrow.vector.types.Types.MinorType.LARGEVARBINARY.getType()), + null); + return new Schema(Collections.singletonList(field)); + } + + private StructType createBlobSparkSchema() { + return new StructType( + new StructField[] {DataTypes.createStructField("blob", DataTypes.BinaryType, true)}); + } + + /** Resolver that fabricates {@code size} bytes per reference — no source dataset needed. */ + private static class SizedFakeResolver extends BlobReferenceResolver { + @Override + public Map resolveBatch(List indices, List refs) { + Map out = new HashMap<>(); + for (int i = 0; i < indices.size(); i++) { + out.put(indices.get(i), new byte[(int) refs.get(i).getSize()]); + } + return out; + } + } + + @Test + public void testByteBasedFlushAccountsForUnresolvedBlobReferences() throws Exception { + // Regression: blob references are ~200-byte placeholders in the vector but resolve to large + // blobs on finish(). The byte guard must size the *resolved* blobs (carried in each reference), + // otherwise the references never trip maxBatchBytes and the whole batch materializes at once. 
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + Schema schema = createBlobSchema(); + StructType sparkSchema = createBlobSparkSchema(); + + final int totalRows = 12; + final int batchSize = 1000; // High row limit - only the byte budget should flush + final long maxBatchBytes = 1024 * 1024; // 1MB + final long blobSize = 400 * 1024; // 400KB resolved per reference -> flush every ~3 rows + final SemaphoreArrowBatchWriteBuffer writeBuffer = + new SemaphoreArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, maxBatchBytes, new SizedFakeResolver()); + + AtomicInteger rowsWritten = new AtomicInteger(0); + AtomicInteger rowsRead = new AtomicInteger(0); + AtomicInteger batchCount = new AtomicInteger(0); + AtomicInteger maxRowsInBatch = new AtomicInteger(0); + + Thread writerThread = + new Thread( + () -> { + try { + for (int i = 0; i < totalRows; i++) { + byte[] reference = + new BlobReference("file:///src.lance", "blob", i, blobSize).serialize(); + writeBuffer.write(new GenericInternalRow(new Object[] {reference})); + rowsWritten.incrementAndGet(); + } + } finally { + writeBuffer.setFinished(); + } + }); + + Callable readerCallable = + () -> { + while (writeBuffer.loadNextBatch()) { + VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); + int rowCount = root.getRowCount(); + for (int i = 0; i < rowCount; i++) { + // Each placeholder reference must have resolved to a full-size blob in the vector. 
+ byte[] resolved = (byte[]) root.getVector("blob").getObject(i); + assertEquals(blobSize, resolved.length); + } + rowsRead.addAndGet(rowCount); + batchCount.incrementAndGet(); + maxRowsInBatch.updateAndGet(prev -> Math.max(prev, rowCount)); + } + return null; + }; + + FutureTask readerTask = writeBuffer.createTrackedTask(readerCallable); + Thread readerThread = new Thread(readerTask); + writerThread.start(); + readerThread.start(); + writerThread.join(); + readerThread.join(); + + try { + assertEquals(totalRows, rowsWritten.get()); + assertEquals(totalRows, rowsRead.get()); + // Without sizing the resolved blobs, 12 tiny references stay under 1MB and form one batch. + Assertions.assertTrue( + batchCount.get() > 1, + "Blob references should trip the byte budget, but got " + + batchCount.get() + + " batches"); + Assertions.assertTrue( + maxRowsInBatch.get() < batchSize, + "Max rows per batch (" + maxRowsInBatch.get() + ") should be bounded by bytes"); + } finally { + writeBuffer.close(); + } + } + } + @Test public void testByteBasedFlushWithSmallRows() throws Exception { // With small rows, the row count limit should be reached before byte limit. 
diff --git a/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala b/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala index 6f3ec455b..70e4e3818 100644 --- a/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala +++ b/lance-spark-base_2.12/src/test/scala/org/lance/spark/arrow/LargeBinaryWriterTest.scala @@ -70,6 +70,9 @@ class LargeBinaryWriterTest { private def ref(addr: Long): Array[Byte] = new BlobReference("file:///src.lance", "blob", addr).serialize() + private def ref(addr: Long, size: Long): Array[Byte] = + new BlobReference("file:///src.lance", "blob", addr, size).serialize() + private def assertBytes(vector: LargeVarBinaryVector, index: Int, expected: Array[Byte]): Unit = { assertFalse(vector.isNull(index), s"row $index should not be null") assertArrayEquals(expected, vector.get(index), s"row $index bytes mismatch") @@ -156,6 +159,31 @@ class LargeBinaryWriterTest { } } + @Test + def estimatedBufferedBytesTracksResolvedSizesNotReferenceSizes(): Unit = { + withVector { vector => + val writer = new LargeBinaryWriter(vector, new RecordingResolver) + // Direct mode (no references yet): nothing buffered, so the byte budget sees nothing extra. + writer.write(row("x".getBytes(UTF_8)), 0) + assertEquals(0L, writer.estimatedBufferedBytes) + + // First reference flips into buffering: the budget must reflect the resolved blob size + // (1 MB here), not the ~200-byte reference that is actually buffered. + writer.write(row(ref(10, 1024 * 1024)), 0) + assertEquals(1024L * 1024, writer.estimatedBufferedBytes) + + // Buffered literals count their own length; a second reference adds its resolved size. 
+ writer.write(row("yz".getBytes(UTF_8)), 0) + writer.write(row(ref(20, 512)), 0) + writer.write(row(null), 0) // buffered null contributes nothing + assertEquals(1024L * 1024 + 2 + 512, writer.estimatedBufferedBytes) + + // finish() drains the tail and clears the running total for the next batch. + writer.finish() + assertEquals(0L, writer.estimatedBufferedBytes) + } + } + @Test def resetClearsBufferingStateForNextBatch(): Unit = { withVector { vector =>