|
19 | 19 |
|
20 | 20 | package org.apache.comet.parquet; |
21 | 21 |
|
22 | | -import org.apache.arrow.memory.BufferAllocator; |
23 | | -import org.apache.arrow.memory.RootAllocator; |
24 | | -import org.apache.arrow.vector.*; |
25 | 22 | import org.apache.arrow.vector.types.DateUnit; |
26 | 23 | import org.apache.arrow.vector.types.FloatingPointPrecision; |
27 | 24 | import org.apache.arrow.vector.types.TimeUnit; |
|
31 | 28 | import org.apache.spark.sql.catalyst.InternalRow; |
32 | 29 | import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns; |
33 | 30 | import org.apache.spark.sql.types.*; |
34 | | -import org.apache.spark.unsafe.types.UTF8String; |
35 | 31 |
|
36 | | -import org.apache.comet.vector.CometPlainVector; |
| 32 | +import org.apache.comet.vector.CometConstantVector; |
37 | 33 | import org.apache.comet.vector.CometVector; |
38 | 34 |
|
39 | 35 | /** |
@@ -74,8 +70,6 @@ public static boolean isTypeSupported(DataType type) { |
74 | 70 | return false; |
75 | 71 | } |
76 | 72 |
|
77 | | - private final BufferAllocator allocator = new RootAllocator(); |
78 | | - |
79 | 73 | /** Whether all the values in this constant column are nulls */ |
80 | 74 | private boolean isNull; |
81 | 75 |
|
@@ -143,174 +137,9 @@ protected void initNative() { |
143 | 137 | nativeHandle = 0; |
144 | 138 | } |
145 | 139 |
|
146 | | - /** Creates a constant Arrow vector with the specified number of rows. */ |
| 140 | + /** Creates a constant vector with the specified logical row count. */ |
147 | 141 | private CometVector createConstantVector(int numRows) { |
148 | | - ValueVector arrowVector = createArrowVector(numRows); |
149 | | - return new CometPlainVector(arrowVector, useDecimal128); |
150 | | - } |
151 | | - |
152 | | - /** Creates an Arrow vector filled with constant values. */ |
153 | | - private ValueVector createArrowVector(int numRows) { |
154 | | - if (isNull) { |
155 | | - return createNullVector(numRows); |
156 | | - } |
157 | | - |
158 | | - if (type == DataTypes.BooleanType) { |
159 | | - return createBooleanVector(numRows, (Boolean) value); |
160 | | - } else if (type == DataTypes.ByteType) { |
161 | | - return createByteVector(numRows, (Byte) value); |
162 | | - } else if (type == DataTypes.ShortType) { |
163 | | - return createShortVector(numRows, (Short) value); |
164 | | - } else if (type == DataTypes.IntegerType) { |
165 | | - return createIntVector(numRows, (Integer) value); |
166 | | - } else if (type == DataTypes.LongType) { |
167 | | - return createLongVector(numRows, (Long) value); |
168 | | - } else if (type == DataTypes.FloatType) { |
169 | | - return createFloatVector(numRows, (Float) value); |
170 | | - } else if (type == DataTypes.DoubleType) { |
171 | | - return createDoubleVector(numRows, (Double) value); |
172 | | - } else if (type == DataTypes.StringType) { |
173 | | - return createStringVector(numRows, (UTF8String) value); |
174 | | - } else if (type == DataTypes.BinaryType) { |
175 | | - return createBinaryVector(numRows, (byte[]) value); |
176 | | - } else if (type == DataTypes.DateType) { |
177 | | - return createDateVector(numRows, (Integer) value); |
178 | | - } else if (type == DataTypes.TimestampType || type == TimestampNTZType$.MODULE$) { |
179 | | - return createTimestampVector(numRows, (Long) value); |
180 | | - } else if (type instanceof DecimalType) { |
181 | | - return createDecimalVector(numRows, (Decimal) value, (DecimalType) type); |
182 | | - } else { |
183 | | - throw new UnsupportedOperationException("Unsupported Spark type: " + type); |
184 | | - } |
185 | | - } |
186 | | - |
187 | | - private ValueVector createNullVector(int numRows) { |
188 | | - NullVector vector = new NullVector(arrowField.getName(), numRows); |
189 | | - return vector; |
190 | | - } |
191 | | - |
192 | | - private ValueVector createBooleanVector(int numRows, boolean value) { |
193 | | - BitVector vector = new BitVector(arrowField, allocator); |
194 | | - vector.allocateNew(numRows); |
195 | | - for (int i = 0; i < numRows; i++) { |
196 | | - vector.set(i, value ? 1 : 0); |
197 | | - } |
198 | | - vector.setValueCount(numRows); |
199 | | - return vector; |
200 | | - } |
201 | | - |
202 | | - private ValueVector createByteVector(int numRows, byte value) { |
203 | | - TinyIntVector vector = new TinyIntVector(arrowField, allocator); |
204 | | - vector.allocateNew(numRows); |
205 | | - for (int i = 0; i < numRows; i++) { |
206 | | - vector.set(i, value); |
207 | | - } |
208 | | - vector.setValueCount(numRows); |
209 | | - return vector; |
210 | | - } |
211 | | - |
212 | | - private ValueVector createShortVector(int numRows, short value) { |
213 | | - SmallIntVector vector = new SmallIntVector(arrowField, allocator); |
214 | | - vector.allocateNew(numRows); |
215 | | - for (int i = 0; i < numRows; i++) { |
216 | | - vector.set(i, value); |
217 | | - } |
218 | | - vector.setValueCount(numRows); |
219 | | - return vector; |
220 | | - } |
221 | | - |
222 | | - private ValueVector createIntVector(int numRows, int value) { |
223 | | - IntVector vector = new IntVector(arrowField, allocator); |
224 | | - vector.allocateNew(numRows); |
225 | | - for (int i = 0; i < numRows; i++) { |
226 | | - vector.set(i, value); |
227 | | - } |
228 | | - vector.setValueCount(numRows); |
229 | | - return vector; |
230 | | - } |
231 | | - |
232 | | - private ValueVector createLongVector(int numRows, long value) { |
233 | | - BigIntVector vector = new BigIntVector(arrowField, allocator); |
234 | | - vector.allocateNew(numRows); |
235 | | - for (int i = 0; i < numRows; i++) { |
236 | | - vector.set(i, value); |
237 | | - } |
238 | | - vector.setValueCount(numRows); |
239 | | - return vector; |
240 | | - } |
241 | | - |
242 | | - private ValueVector createFloatVector(int numRows, float value) { |
243 | | - Float4Vector vector = new Float4Vector(arrowField, allocator); |
244 | | - vector.allocateNew(numRows); |
245 | | - for (int i = 0; i < numRows; i++) { |
246 | | - vector.set(i, value); |
247 | | - } |
248 | | - vector.setValueCount(numRows); |
249 | | - return vector; |
250 | | - } |
251 | | - |
252 | | - private ValueVector createDoubleVector(int numRows, double value) { |
253 | | - Float8Vector vector = new Float8Vector(arrowField, allocator); |
254 | | - vector.allocateNew(numRows); |
255 | | - for (int i = 0; i < numRows; i++) { |
256 | | - vector.set(i, value); |
257 | | - } |
258 | | - vector.setValueCount(numRows); |
259 | | - return vector; |
260 | | - } |
261 | | - |
262 | | - private ValueVector createStringVector(int numRows, UTF8String value) { |
263 | | - VarCharVector vector = new VarCharVector(arrowField, allocator); |
264 | | - byte[] bytes = value.getBytes(); |
265 | | - vector.allocateNew((long) bytes.length * numRows, numRows); |
266 | | - for (int i = 0; i < numRows; i++) { |
267 | | - vector.set(i, bytes); |
268 | | - } |
269 | | - vector.setValueCount(numRows); |
270 | | - return vector; |
271 | | - } |
272 | | - |
273 | | - private ValueVector createBinaryVector(int numRows, byte[] value) { |
274 | | - VarBinaryVector vector = new VarBinaryVector(arrowField, allocator); |
275 | | - vector.allocateNew((long) value.length * numRows, numRows); |
276 | | - for (int i = 0; i < numRows; i++) { |
277 | | - vector.set(i, value); |
278 | | - } |
279 | | - vector.setValueCount(numRows); |
280 | | - return vector; |
281 | | - } |
282 | | - |
283 | | - private ValueVector createDateVector(int numRows, int value) { |
284 | | - DateDayVector vector = new DateDayVector(arrowField, allocator); |
285 | | - vector.allocateNew(numRows); |
286 | | - for (int i = 0; i < numRows; i++) { |
287 | | - vector.set(i, value); |
288 | | - } |
289 | | - vector.setValueCount(numRows); |
290 | | - return vector; |
291 | | - } |
292 | | - |
293 | | - private ValueVector createTimestampVector(int numRows, long value) { |
294 | | - TimeStampMicroTZVector vector = new TimeStampMicroTZVector(arrowField, allocator); |
295 | | - vector.allocateNew(numRows); |
296 | | - for (int i = 0; i < numRows; i++) { |
297 | | - vector.set(i, value); |
298 | | - } |
299 | | - vector.setValueCount(numRows); |
300 | | - return vector; |
301 | | - } |
302 | | - |
303 | | - private ValueVector createDecimalVector(int numRows, Decimal value, DecimalType dt) { |
304 | | - DecimalVector vector = |
305 | | - new DecimalVector(arrowField.getName(), allocator, dt.precision(), dt.scale()); |
306 | | - vector.allocateNew(numRows); |
307 | | - |
308 | | - java.math.BigDecimal bigDecimal = value.toJavaBigDecimal(); |
309 | | - for (int i = 0; i < numRows; i++) { |
310 | | - vector.set(i, bigDecimal); |
311 | | - } |
312 | | - vector.setValueCount(numRows); |
313 | | - return vector; |
| 142 | + return new CometConstantVector(type, arrowField, useDecimal128, value, isNull, numRows); |
314 | 143 | } |
315 | 144 |
|
316 | 145 | /** Converts a Spark StructField to an Arrow Field. */ |
|
0 commit comments