Skip to content

Commit 33d87cd

Browse files
andygroveclaude
andcommitted
Replace mutable buffers with immutable Arrow vectors in NativeBatchReader
Add ImmutableConstantColumnReader that creates Arrow vectors directly in Java without using native Rust mutable buffers. This is used for partition columns and missing columns in NativeBatchReader. Key changes: - New ImmutableConstantColumnReader creates Arrow vectors using Arrow Java APIs, supporting primitive types (Boolean, Byte, Short, Integer, Long, Float, Double, String, Binary, Date, Timestamp, Decimal, Null) - NativeBatchReader now uses ImmutableConstantColumnReader instead of ConstantColumnReader for partition and missing columns - CometScanRule checks partition column types at planning time and falls back to Spark if complex types (StructType, ArrayType, MapType) are used, since ImmutableConstantColumnReader only supports primitives Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 0cc8fbe commit 33d87cd

3 files changed

Lines changed: 392 additions & 7 deletions

File tree

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.parquet;
21+
22+
import org.apache.arrow.memory.BufferAllocator;
23+
import org.apache.arrow.memory.RootAllocator;
24+
import org.apache.arrow.vector.*;
25+
import org.apache.arrow.vector.types.DateUnit;
26+
import org.apache.arrow.vector.types.FloatingPointPrecision;
27+
import org.apache.arrow.vector.types.TimeUnit;
28+
import org.apache.arrow.vector.types.pojo.ArrowType;
29+
import org.apache.arrow.vector.types.pojo.Field;
30+
import org.apache.arrow.vector.types.pojo.FieldType;
31+
import org.apache.spark.sql.catalyst.InternalRow;
32+
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns;
33+
import org.apache.spark.sql.types.*;
34+
import org.apache.spark.unsafe.types.UTF8String;
35+
36+
import org.apache.comet.vector.CometPlainVector;
37+
import org.apache.comet.vector.CometVector;
38+
39+
/**
40+
* A column reader that returns constant vectors without using native mutable buffers. This is used
41+
* for reading partition columns and missing columns in NativeBatchReader.
42+
*
43+
* <p>Unlike {@link ConstantColumnReader} which uses native Rust code with mutable buffers, this
44+
* implementation creates Arrow vectors directly in Java using Arrow's immutable buffer APIs.
45+
*/
46+
public class ImmutableConstantColumnReader extends AbstractColumnReader {
47+
48+
/**
49+
* Checks if the given Spark DataType is supported by this reader. This is used at query planning
50+
* time to determine if NativeBatchReader can handle the partition schema or if it should fall
51+
* back to Spark.
52+
*
53+
* @param type the Spark DataType to check
54+
* @return true if the type is supported, false otherwise
55+
*/
56+
public static boolean isTypeSupported(DataType type) {
57+
if (type == DataTypes.BooleanType
58+
|| type == DataTypes.ByteType
59+
|| type == DataTypes.ShortType
60+
|| type == DataTypes.IntegerType
61+
|| type == DataTypes.LongType
62+
|| type == DataTypes.FloatType
63+
|| type == DataTypes.DoubleType
64+
|| type == DataTypes.StringType
65+
|| type == DataTypes.BinaryType
66+
|| type == DataTypes.DateType
67+
|| type == DataTypes.TimestampType
68+
|| type == TimestampNTZType$.MODULE$
69+
|| type == DataTypes.NullType
70+
|| type instanceof DecimalType) {
71+
return true;
72+
}
73+
// Complex types (StructType, ArrayType, MapType) and other types are not supported
74+
return false;
75+
}
76+
77+
private final BufferAllocator allocator = new RootAllocator();
78+
79+
/** Whether all the values in this constant column are nulls */
80+
private boolean isNull;
81+
82+
/** The constant value */
83+
private Object value;
84+
85+
/** The current vector */
86+
private CometVector vector;
87+
88+
/** The Arrow field type for this column */
89+
private final Field arrowField;
90+
91+
/** Constructor for missing columns with default values */
92+
ImmutableConstantColumnReader(StructField field, int batchSize, boolean useDecimal128) {
93+
super(field.dataType(), TypeUtil.convertToParquet(field), useDecimal128, false);
94+
this.batchSize = batchSize;
95+
this.arrowField = toArrowField(field);
96+
this.value =
97+
ResolveDefaultColumns.getExistenceDefaultValues(new StructType(new StructField[] {field}))[
98+
0];
99+
this.isNull = (this.value == null);
100+
}
101+
102+
/** Constructor for partition columns */
103+
ImmutableConstantColumnReader(
104+
StructField field, int batchSize, InternalRow values, int index, boolean useDecimal128) {
105+
super(field.dataType(), TypeUtil.convertToParquet(field), useDecimal128, false);
106+
this.batchSize = batchSize;
107+
this.arrowField = toArrowField(field);
108+
this.value = values.get(index, field.dataType());
109+
this.isNull = (this.value == null);
110+
}
111+
112+
@Override
113+
public void setBatchSize(int batchSize) {
114+
close();
115+
this.batchSize = batchSize;
116+
}
117+
118+
@Override
119+
public void readBatch(int total) {
120+
if (vector != null) {
121+
vector.close();
122+
vector = null;
123+
}
124+
vector = createConstantVector(total);
125+
}
126+
127+
@Override
128+
public CometVector currentBatch() {
129+
return vector;
130+
}
131+
132+
@Override
133+
public void close() {
134+
if (vector != null) {
135+
vector.close();
136+
vector = null;
137+
}
138+
}
139+
140+
@Override
141+
protected void initNative() {
142+
// No native initialization needed - we create vectors purely in Java
143+
nativeHandle = 0;
144+
}
145+
146+
/** Creates a constant Arrow vector with the specified number of rows. */
147+
private CometVector createConstantVector(int numRows) {
148+
ValueVector arrowVector = createArrowVector(numRows);
149+
return new CometPlainVector(arrowVector, useDecimal128);
150+
}
151+
152+
/** Creates an Arrow vector filled with constant values. */
153+
private ValueVector createArrowVector(int numRows) {
154+
if (isNull) {
155+
return createNullVector(numRows);
156+
}
157+
158+
if (type == DataTypes.BooleanType) {
159+
return createBooleanVector(numRows, (Boolean) value);
160+
} else if (type == DataTypes.ByteType) {
161+
return createByteVector(numRows, (Byte) value);
162+
} else if (type == DataTypes.ShortType) {
163+
return createShortVector(numRows, (Short) value);
164+
} else if (type == DataTypes.IntegerType) {
165+
return createIntVector(numRows, (Integer) value);
166+
} else if (type == DataTypes.LongType) {
167+
return createLongVector(numRows, (Long) value);
168+
} else if (type == DataTypes.FloatType) {
169+
return createFloatVector(numRows, (Float) value);
170+
} else if (type == DataTypes.DoubleType) {
171+
return createDoubleVector(numRows, (Double) value);
172+
} else if (type == DataTypes.StringType) {
173+
return createStringVector(numRows, (UTF8String) value);
174+
} else if (type == DataTypes.BinaryType) {
175+
return createBinaryVector(numRows, (byte[]) value);
176+
} else if (type == DataTypes.DateType) {
177+
return createDateVector(numRows, (Integer) value);
178+
} else if (type == DataTypes.TimestampType || type == TimestampNTZType$.MODULE$) {
179+
return createTimestampVector(numRows, (Long) value);
180+
} else if (type instanceof DecimalType) {
181+
return createDecimalVector(numRows, (Decimal) value, (DecimalType) type);
182+
} else {
183+
throw new UnsupportedOperationException("Unsupported Spark type: " + type);
184+
}
185+
}
186+
187+
private ValueVector createNullVector(int numRows) {
188+
NullVector vector = new NullVector(arrowField.getName(), numRows);
189+
return vector;
190+
}
191+
192+
private ValueVector createBooleanVector(int numRows, boolean value) {
193+
BitVector vector = new BitVector(arrowField, allocator);
194+
vector.allocateNew(numRows);
195+
for (int i = 0; i < numRows; i++) {
196+
vector.set(i, value ? 1 : 0);
197+
}
198+
vector.setValueCount(numRows);
199+
return vector;
200+
}
201+
202+
private ValueVector createByteVector(int numRows, byte value) {
203+
TinyIntVector vector = new TinyIntVector(arrowField, allocator);
204+
vector.allocateNew(numRows);
205+
for (int i = 0; i < numRows; i++) {
206+
vector.set(i, value);
207+
}
208+
vector.setValueCount(numRows);
209+
return vector;
210+
}
211+
212+
private ValueVector createShortVector(int numRows, short value) {
213+
SmallIntVector vector = new SmallIntVector(arrowField, allocator);
214+
vector.allocateNew(numRows);
215+
for (int i = 0; i < numRows; i++) {
216+
vector.set(i, value);
217+
}
218+
vector.setValueCount(numRows);
219+
return vector;
220+
}
221+
222+
private ValueVector createIntVector(int numRows, int value) {
223+
IntVector vector = new IntVector(arrowField, allocator);
224+
vector.allocateNew(numRows);
225+
for (int i = 0; i < numRows; i++) {
226+
vector.set(i, value);
227+
}
228+
vector.setValueCount(numRows);
229+
return vector;
230+
}
231+
232+
private ValueVector createLongVector(int numRows, long value) {
233+
BigIntVector vector = new BigIntVector(arrowField, allocator);
234+
vector.allocateNew(numRows);
235+
for (int i = 0; i < numRows; i++) {
236+
vector.set(i, value);
237+
}
238+
vector.setValueCount(numRows);
239+
return vector;
240+
}
241+
242+
private ValueVector createFloatVector(int numRows, float value) {
243+
Float4Vector vector = new Float4Vector(arrowField, allocator);
244+
vector.allocateNew(numRows);
245+
for (int i = 0; i < numRows; i++) {
246+
vector.set(i, value);
247+
}
248+
vector.setValueCount(numRows);
249+
return vector;
250+
}
251+
252+
private ValueVector createDoubleVector(int numRows, double value) {
253+
Float8Vector vector = new Float8Vector(arrowField, allocator);
254+
vector.allocateNew(numRows);
255+
for (int i = 0; i < numRows; i++) {
256+
vector.set(i, value);
257+
}
258+
vector.setValueCount(numRows);
259+
return vector;
260+
}
261+
262+
private ValueVector createStringVector(int numRows, UTF8String value) {
263+
VarCharVector vector = new VarCharVector(arrowField, allocator);
264+
byte[] bytes = value.getBytes();
265+
vector.allocateNew((long) bytes.length * numRows, numRows);
266+
for (int i = 0; i < numRows; i++) {
267+
vector.set(i, bytes);
268+
}
269+
vector.setValueCount(numRows);
270+
return vector;
271+
}
272+
273+
private ValueVector createBinaryVector(int numRows, byte[] value) {
274+
VarBinaryVector vector = new VarBinaryVector(arrowField, allocator);
275+
vector.allocateNew((long) value.length * numRows, numRows);
276+
for (int i = 0; i < numRows; i++) {
277+
vector.set(i, value);
278+
}
279+
vector.setValueCount(numRows);
280+
return vector;
281+
}
282+
283+
private ValueVector createDateVector(int numRows, int value) {
284+
DateDayVector vector = new DateDayVector(arrowField, allocator);
285+
vector.allocateNew(numRows);
286+
for (int i = 0; i < numRows; i++) {
287+
vector.set(i, value);
288+
}
289+
vector.setValueCount(numRows);
290+
return vector;
291+
}
292+
293+
private ValueVector createTimestampVector(int numRows, long value) {
294+
TimeStampMicroTZVector vector = new TimeStampMicroTZVector(arrowField, allocator);
295+
vector.allocateNew(numRows);
296+
for (int i = 0; i < numRows; i++) {
297+
vector.set(i, value);
298+
}
299+
vector.setValueCount(numRows);
300+
return vector;
301+
}
302+
303+
private ValueVector createDecimalVector(int numRows, Decimal value, DecimalType dt) {
304+
DecimalVector vector =
305+
new DecimalVector(arrowField.getName(), allocator, dt.precision(), dt.scale());
306+
vector.allocateNew(numRows);
307+
308+
java.math.BigDecimal bigDecimal = value.toJavaBigDecimal();
309+
for (int i = 0; i < numRows; i++) {
310+
vector.set(i, bigDecimal);
311+
}
312+
vector.setValueCount(numRows);
313+
return vector;
314+
}
315+
316+
/** Converts a Spark StructField to an Arrow Field. */
317+
private Field toArrowField(StructField field) {
318+
ArrowType arrowType = toArrowType(field.dataType());
319+
FieldType fieldType = new FieldType(field.nullable(), arrowType, null);
320+
return new Field(field.name(), fieldType, null);
321+
}
322+
323+
/** Converts a Spark DataType to an Arrow ArrowType. */
324+
private ArrowType toArrowType(DataType type) {
325+
if (type == DataTypes.BooleanType) {
326+
return ArrowType.Bool.INSTANCE;
327+
} else if (type == DataTypes.ByteType) {
328+
return new ArrowType.Int(8, true);
329+
} else if (type == DataTypes.ShortType) {
330+
return new ArrowType.Int(16, true);
331+
} else if (type == DataTypes.IntegerType) {
332+
return new ArrowType.Int(32, true);
333+
} else if (type == DataTypes.LongType) {
334+
return new ArrowType.Int(64, true);
335+
} else if (type == DataTypes.FloatType) {
336+
return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
337+
} else if (type == DataTypes.DoubleType) {
338+
return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
339+
} else if (type == DataTypes.StringType) {
340+
return ArrowType.Utf8.INSTANCE;
341+
} else if (type == DataTypes.BinaryType) {
342+
return ArrowType.Binary.INSTANCE;
343+
} else if (type == DataTypes.DateType) {
344+
return new ArrowType.Date(DateUnit.DAY);
345+
} else if (type == DataTypes.TimestampType) {
346+
return new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC");
347+
} else if (type == TimestampNTZType$.MODULE$) {
348+
return new ArrowType.Timestamp(TimeUnit.MICROSECOND, null);
349+
} else if (type instanceof DecimalType) {
350+
DecimalType dt = (DecimalType) type;
351+
return new ArrowType.Decimal(dt.precision(), dt.scale(), 128);
352+
} else if (type == DataTypes.NullType) {
353+
return ArrowType.Null.INSTANCE;
354+
} else {
355+
throw new UnsupportedOperationException("Unsupported Spark type: " + type);
356+
}
357+
}
358+
}

common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,8 @@ public void init() throws Throwable {
473473
+ filePath);
474474
}
475475
if (field.isPrimitive()) {
476-
ConstantColumnReader reader =
477-
new ConstantColumnReader(nonPartitionFields[i], capacity, useDecimal128);
476+
ImmutableConstantColumnReader reader =
477+
new ImmutableConstantColumnReader(nonPartitionFields[i], capacity, useDecimal128);
478478
columnReaders[i] = reader;
479479
missingColumns[i] = true;
480480
} else {
@@ -492,8 +492,9 @@ public void init() throws Throwable {
492492
for (int i = fields.size(); i < columnReaders.length; i++) {
493493
int fieldIndex = i - fields.size();
494494
StructField field = partitionFields[fieldIndex];
495-
ConstantColumnReader reader =
496-
new ConstantColumnReader(field, capacity, partitionValues, fieldIndex, useDecimal128);
495+
ImmutableConstantColumnReader reader =
496+
new ImmutableConstantColumnReader(
497+
field, capacity, partitionValues, fieldIndex, useDecimal128);
497498
columnReaders[i] = reader;
498499
}
499500
}

0 commit comments

Comments
 (0)