diff --git a/common/src/main/java/org/apache/comet/parquet/ImmutableConstantColumnReader.java b/common/src/main/java/org/apache/comet/parquet/ImmutableConstantColumnReader.java
new file mode 100644
index 0000000000..3cb475cd74
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/parquet/ImmutableConstantColumnReader.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet;
+
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns;
+import org.apache.spark.sql.types.*;
+
+import org.apache.comet.vector.CometConstantVector;
+import org.apache.comet.vector.CometVector;
+
+/**
+ * A column reader that returns constant vectors without using native mutable buffers. This is used
+ * for reading partition columns and missing columns in NativeBatchReader.
+ *
+ * <p>Unlike {@link ConstantColumnReader} which uses native Rust code with mutable buffers, this
+ * implementation creates Arrow vectors directly in Java using Arrow's immutable buffer APIs.
+ */
+public class ImmutableConstantColumnReader extends AbstractColumnReader {
+
+  /**
+   * Checks if the given Spark DataType is supported by this reader. This is used at query planning
+   * time to determine if NativeBatchReader can handle the partition schema or if it should fall
+   * back to Spark.
+   *
+   * @param type the Spark DataType to check
+   * @return true if the type is supported, false otherwise
+   */
+  public static boolean isTypeSupported(DataType type) {
+    if (type == DataTypes.BooleanType
+        || type == DataTypes.ByteType
+        || type == DataTypes.ShortType
+        || type == DataTypes.IntegerType
+        || type == DataTypes.LongType
+        || type == DataTypes.FloatType
+        || type == DataTypes.DoubleType
+        || type == DataTypes.StringType
+        || type == DataTypes.BinaryType
+        || type == DataTypes.DateType
+        || type == DataTypes.TimestampType
+        || type == TimestampNTZType$.MODULE$
+        || type == DataTypes.NullType
+        || type instanceof DecimalType) {
+      return true;
+    }
+    // Complex types (StructType, ArrayType, MapType) and other types are not supported
+    return false;
+  }
+
+  /** Whether all the values in this constant column are nulls */
+  private final boolean isNull;
+
+  /** The constant value */
+  private final Object value;
+
+  /** The current vector */
+  private CometVector vector;
+
+  /** The Arrow field type for this column */
+  private final Field arrowField;
+
+  /** Constructor for missing columns with default values */
+  ImmutableConstantColumnReader(StructField field, int batchSize, boolean useDecimal128) {
+    super(field.dataType(), TypeUtil.convertToParquet(field), useDecimal128, false);
+    this.batchSize = batchSize;
+    this.arrowField = toArrowField(field);
+    this.value =
+        ResolveDefaultColumns.getExistenceDefaultValues(new StructType(new StructField[] {field}))[
+            0];
+    this.isNull = (this.value == null);
+  }
+
+  /** Constructor for partition columns */
+  ImmutableConstantColumnReader(
+      StructField field, int batchSize, InternalRow values, int index, boolean useDecimal128) {
+    super(field.dataType(), TypeUtil.convertToParquet(field), useDecimal128, false);
+    this.batchSize = batchSize;
+    this.arrowField = toArrowField(field);
+    this.value = values.get(index, field.dataType());
+    this.isNull = (this.value == null);
+  }
+
+  @Override
+  public void setBatchSize(int batchSize) {
+    close();
+    this.batchSize = batchSize;
+  }
+
+  @Override
+  public void readBatch(int total) {
+    // Release the previous batch's vector before creating the one for this batch.
+    close();
+    vector = createConstantVector(total);
+  }
+
+  @Override
+  public CometVector currentBatch() {
+    return vector;
+  }
+
+  @Override
+  public void close() {
+    if (vector != null) {
+      vector.close();
+      vector = null;
+    }
+  }
+
+  @Override
+  protected void initNative() {
+    // No native initialization needed - we create vectors purely in Java
+    nativeHandle = 0;
+  }
+
+  /** Creates a constant vector with the specified logical row count. */
+  private CometVector createConstantVector(int numRows) {
+    return new CometConstantVector(type, arrowField, useDecimal128, value, isNull, numRows);
+  }
+
+  /** Converts a Spark StructField to an Arrow Field. */
+  private static Field toArrowField(StructField field) {
+    ArrowType arrowType = toArrowType(field.dataType());
+    FieldType fieldType = new FieldType(field.nullable(), arrowType, null);
+    return new Field(field.name(), fieldType, null);
+  }
+
+  /**
+   * Converts a Spark DataType to an Arrow ArrowType. Must stay in sync with {@link
+   * #isTypeSupported(DataType)}: every type accepted there needs a mapping here.
+   *
+   * @throws UnsupportedOperationException if the type has no mapping
+   */
+  private static ArrowType toArrowType(DataType type) {
+    if (type == DataTypes.BooleanType) {
+      return ArrowType.Bool.INSTANCE;
+    } else if (type == DataTypes.ByteType) {
+      return new ArrowType.Int(8, true);
+    } else if (type == DataTypes.ShortType) {
+      return new ArrowType.Int(16, true);
+    } else if (type == DataTypes.IntegerType) {
+      return new ArrowType.Int(32, true);
+    } else if (type == DataTypes.LongType) {
+      return new ArrowType.Int(64, true);
+    } else if (type == DataTypes.FloatType) {
+      return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
+    } else if (type == DataTypes.DoubleType) {
+      return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+    } else if (type == DataTypes.StringType) {
+      return ArrowType.Utf8.INSTANCE;
+    } else if (type == DataTypes.BinaryType) {
+      return ArrowType.Binary.INSTANCE;
+    } else if (type == DataTypes.DateType) {
+      return new ArrowType.Date(DateUnit.DAY);
+    } else if (type == DataTypes.TimestampType) {
+      return new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC");
+    } else if (type == TimestampNTZType$.MODULE$) {
+      return new ArrowType.Timestamp(TimeUnit.MICROSECOND, null);
+    } else if (type instanceof DecimalType) {
+      DecimalType dt = (DecimalType) type;
+      return new ArrowType.Decimal(dt.precision(), dt.scale(), 128);
+    } else if (type == DataTypes.NullType) {
+      return ArrowType.Null.INSTANCE;
+    } else {
+      throw new UnsupportedOperationException("Unsupported Spark type: " + type);
+    }
+  }
+}
diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index d10a8932be..e49aad383c 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -473,8 +473,8 @@ public void init() throws Throwable {
                     + filePath);
           }
           if (field.isPrimitive()) {
-            ConstantColumnReader reader =
-                new ConstantColumnReader(nonPartitionFields[i],
capacity, useDecimal128);
+            ImmutableConstantColumnReader reader =
+                new ImmutableConstantColumnReader(nonPartitionFields[i], capacity, useDecimal128);
             columnReaders[i] = reader;
             missingColumns[i] = true;
           } else {
@@ -492,8 +492,9 @@ public void init() throws Throwable {
       for (int i = fields.size(); i < columnReaders.length; i++) {
         int fieldIndex = i - fields.size();
         StructField field = partitionFields[fieldIndex];
-        ConstantColumnReader reader =
-            new ConstantColumnReader(field, capacity, partitionValues, fieldIndex, useDecimal128);
+        ImmutableConstantColumnReader reader =
+            new ImmutableConstantColumnReader(
+                field, capacity, partitionValues, fieldIndex, useDecimal128);
         columnReaders[i] = reader;
       }
     }
diff --git a/common/src/main/java/org/apache/comet/vector/CometConstantVector.java b/common/src/main/java/org/apache/comet/vector/CometConstantVector.java
new file mode 100644
index 0000000000..c82af431b4
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/vector/CometConstantVector.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.vector;
+
+import java.math.BigDecimal;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.*;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A CometVector that stores a single constant value. For native export, it lazily creates a
+ * 1-element Arrow vector, avoiding the cost of materializing N identical elements. The native side
+ * detects the 1-element array and expands it to the actual batch size.
+ *
+ * <p>For Spark-direct consumption (e.g., ColumnarToRow), all getters return the constant value
+ * regardless of rowId.
+ */
+public class CometConstantVector extends CometVector {
+  /**
+   * Allocator backing the exported 1-element vector. Created on first export so that vectors
+   * which are only read on the JVM side (or never exported) do not hold an allocator.
+   */
+  private BufferAllocator allocator;
+
+  /** Whether the constant value is null */
+  private final boolean isNull;
+
+  /** The constant value (null if isNull is true) */
+  private final Object value;
+
+  /** The Spark data type */
+  private final DataType sparkType;
+
+  /** The Arrow field for creating the 1-element vector */
+  private final Field arrowField;
+
+  /** Logical number of rows this vector represents */
+  private int numValues;
+
+  /** Lazily created 1-element Arrow vector for native export */
+  private ValueVector lazyVector;
+
+  /**
+   * Creates a vector in which every row holds the same value.
+   *
+   * @param sparkType Spark type of the column
+   * @param arrowField Arrow field used to build the exported 1-element vector
+   * @param useDecimal128 passed through to {@link CometVector}
+   * @param value the constant, as the object type the typed getters cast to; unused when null
+   * @param isNull whether every row is null
+   * @param numValues logical number of rows
+   */
+  public CometConstantVector(
+      DataType sparkType,
+      Field arrowField,
+      boolean useDecimal128,
+      Object value,
+      boolean isNull,
+      int numValues) {
+    super(sparkType, useDecimal128);
+    this.sparkType = sparkType;
+    this.arrowField = arrowField;
+    this.value = value;
+    this.isNull = isNull;
+    this.numValues = numValues;
+  }
+
+  @Override
+  public void setNumNulls(int numNulls) {
+    // No-op: null status is determined by the constant isNull flag
+  }
+
+  @Override
+  public void setNumValues(int numValues) {
+    this.numValues = numValues;
+  }
+
+  @Override
+  public int numValues() {
+    return numValues;
+  }
+
+  @Override
+  public boolean hasNull() {
+    return isNull;
+  }
+
+  @Override
+  public int numNulls() {
+    return isNull ? numValues : 0;
+  }
+
+  @Override
+  public boolean isNullAt(int rowId) {
+    return isNull;
+  }
+
+  @Override
+  public boolean isFixedLength() {
+    return !(sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType);
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    return (Boolean) value;
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    return (Byte) value;
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    return (Short) value;
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    return (Integer) value;
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    return (Long) value;
+  }
+
+  @Override
+  public long getLongDecimal(int rowId) {
+    // Decimal constants are held as Spark Decimal (see getDecimal), not as a boxed Long.
+    return value instanceof Decimal ? ((Decimal) value).toUnscaledLong() : (Long) value;
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    return (Float) value;
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    return (Double) value;
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    if (isNull) return null;
+    return (UTF8String) value;
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    if (isNull) return null;
+    return (byte[]) value;
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNull) return null;
+    return (Decimal) value;
+  }
+
+  @Override
+  public byte[] copyBinaryDecimal(int i, byte[] dest) {
+    // Override to avoid the memory-address code path in CometVector.
+    // For constant decimals, we go through getDecimal() instead.
+    throw new UnsupportedOperationException(
+        "CometConstantVector does not support copyBinaryDecimal; use getDecimal() instead");
+  }
+
+  @Override
+  public ValueVector getValueVector() {
+    if (lazyVector == null) {
+      lazyVector = createOneElementVector();
+    }
+    return lazyVector;
+  }
+
+  @Override
+  public CometVector slice(int offset, int length) {
+    return new CometConstantVector(sparkType, arrowField, useDecimal128, value, isNull, length);
+  }
+
+  @Override
+  public void close() {
+    if (lazyVector != null) {
+      lazyVector.close();
+      lazyVector = null;
+    }
+    if (allocator != null) {
+      allocator.close();
+      allocator = null;
+    }
+  }
+
+  /** Creates a 1-element Arrow vector holding the constant value. */
+  private ValueVector createOneElementVector() {
+    if (isNull) {
+      // NOTE(review): a null constant is exported with Arrow's Null type, not the column's own
+      // type; confirm the native side tolerates that for typed columns.
+      return new NullVector(arrowField.getName(), 1);
+    }
+
+    if (allocator == null) {
+      allocator = new RootAllocator();
+    }
+
+    if (sparkType == DataTypes.BooleanType) {
+      BitVector v = new BitVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Boolean) value ? 1 : 0);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.ByteType) {
+      TinyIntVector v = new TinyIntVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Byte) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.ShortType) {
+      SmallIntVector v = new SmallIntVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Short) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.IntegerType) {
+      IntVector v = new IntVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Integer) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.LongType) {
+      BigIntVector v = new BigIntVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Long) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.FloatType) {
+      Float4Vector v = new Float4Vector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Float) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.DoubleType) {
+      Float8Vector v = new Float8Vector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Double) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.StringType) {
+      VarCharVector v = new VarCharVector(arrowField, allocator);
+      byte[] bytes = ((UTF8String) value).getBytes();
+      v.allocateNew((long) bytes.length, 1);
+      v.set(0, bytes);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.BinaryType) {
+      VarBinaryVector v = new VarBinaryVector(arrowField, allocator);
+      byte[] bytes = (byte[]) value;
+      v.allocateNew((long) bytes.length, 1);
+      v.set(0, bytes);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.DateType) {
+      DateDayVector v = new DateDayVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Integer) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == DataTypes.TimestampType) {
+      TimeStampMicroTZVector v = new TimeStampMicroTZVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Long) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType == TimestampNTZType$.MODULE$) {
+      // TIMESTAMP_NTZ carries no time zone: use the zone-less vector class for it.
+      TimeStampMicroVector v = new TimeStampMicroVector(arrowField, allocator);
+      v.allocateNew(1);
+      v.set(0, (Long) value);
+      v.setValueCount(1);
+      return v;
+    } else if (sparkType instanceof DecimalType) {
+      DecimalType dt = (DecimalType) sparkType;
+      DecimalVector v =
+          new DecimalVector(arrowField.getName(), allocator, dt.precision(), dt.scale());
+      v.allocateNew(1);
+      BigDecimal bigDecimal = ((Decimal) value).toJavaBigDecimal();
+      v.set(0, bigDecimal);
+      v.setValueCount(1);
+      return v;
+    } else {
+      throw new UnsupportedOperationException("Unsupported Spark type: " + sparkType);
+    }
+  }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 45245121a0..c6bbb8e405 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -140,6 +140,20 @@ class NativeUtil {
               provider,
               arrowArray,
               arrowSchema)
+          case constantVector: CometConstantVector =>
+            // Export the 1-element Arrow vector for scalar constant columns.
+            // Skip adding its value count to numRows validation since it's 1, not N.
+            // The native side will detect and expand 1-element arrays to the actual batch size.
+            val valueVector = constantVector.getValueVector
+
+            val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+            val arrowArray = ArrowArray.wrap(arrayAddrs(index))
+            Data.exportVector(
+              allocator,
+              getFieldVector(valueVector, "export"),
+              null,
+              arrowArray,
+              arrowSchema)
           case a: CometVector =>
             val valueVector = a.getValueVector
 
diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 2543705fb0..a5ab7878a4 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -28,6 +28,7 @@ use arrow::compute::{cast_with_options, take, CastOptions};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::ffi::FFI_ArrowArray;
 use arrow::ffi::FFI_ArrowSchema;
+use datafusion::common::ScalarValue;
 use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::metrics::{
@@ -209,16 +210,23 @@ impl ScanExec {
 
             let array = make_array(array_data);
 
-            // Apply selection if selection vectors exist (applies to all columns)
+            // Apply selection if selection vectors exist (applies to all columns).
+            // Skip take() for 1-element scalar constant arrays since they represent
+            // constant values unaffected by row deletion.
             let array = if let Some(ref selection_arrays) = selection_indices_arrays {
-                let indices = &selection_arrays[i];
-                // Apply the selection using Arrow's take kernel
-                match take(&*array, &**indices, None) {
-                    Ok(selected_array) => selected_array,
-                    Err(e) => {
-                        return Err(CometError::from(ExecutionError::ArrowError(format!(
-                            "Failed to apply selection for column {i}: {e}",
-                        ))));
+                if array.len() == 1 {
+                    // Scalar constant column - skip selection, will be expanded later
+                    array
+                } else {
+                    let indices = &selection_arrays[i];
+                    // Apply the selection using Arrow's take kernel
+                    match take(&*array, &**indices, None) {
+                        Ok(selected_array) => selected_array,
+                        Err(e) => {
+                            return Err(CometError::from(ExecutionError::ArrowError(format!(
+                                "Failed to apply selection for column {i}: {e}",
+                            ))));
+                        }
                     }
                 }
             } else {
@@ -256,6 +264,19 @@ impl ScanExec {
             num_rows as usize
         };
 
+        // Resize 1-element scalar constant columns to the actual batch size.
+        // The JVM side exports constant columns (partition/missing) as 1-element arrays
+        // to avoid materializing N identical values. The check is `!= 1` rather than `> 1`
+        // so a constant column is also shrunk when the batch ends up with 0 rows.
+        if actual_num_rows != 1 {
+            for col in &mut inputs {
+                if col.len() == 1 {
+                    let scalar = ScalarValue::try_from_array(col, 0)?;
+                    *col = scalar.to_array_of_size(actual_num_rows)?;
+                }
+            }
+        }
+
         Ok(InputBatch::new(inputs, Some(actual_num_rows)))
     }
 
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
index 45faa4d940..31d4ebdb88 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
@@ -47,7 +47,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, withInfo, wi
 import org.apache.comet.DataTypeSupport.isComplexType
 import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection}
 import org.apache.comet.objectstore.NativeConfig
-import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet}
+import org.apache.comet.parquet.{CometParquetScan, ImmutableConstantColumnReader, Native, SupportsComet}
 import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled, isEncryptionConfigSupported}
 import org.apache.comet.serde.operator.CometNativeScan
 import org.apache.comet.shims.CometTypeShim
@@ -664,7 +664,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
         val schemaSupported =
           typeChecker.isSchemaSupported(scanExec.requiredSchema, fallbackReasons)
         val partitionSchemaSupported =
-          typeChecker.isSchemaSupported(partitionSchema, fallbackReasons)
+          typeChecker.isPartitionSchemaSupported(partitionSchema, fallbackReasons)
 
         val cometExecEnabled = COMET_EXEC_ENABLED.get()
         if (!cometExecEnabled) {
@@ -702,7 +702,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
           return false
         }
         val partitionSchemaSupported =
-          typeChecker.isSchemaSupported(r.partitionSchema, fallbackReasons)
+          typeChecker.isPartitionSchemaSupported(r.partitionSchema, fallbackReasons)
         if (!partitionSchemaSupported) {
           withInfo(
             scanExec,
@@ -745,6 +745,32 @@ case class CometScanTypeChecker(scanImpl: String) extends DataTypeSupport with C
         super.isTypeSupported(dt, name, fallbackReasons)
     }
   }
+
+  /**
+   * Checks if the partition schema is supported for constant column readers. For
+   * native_iceberg_compat scan, partition columns use ImmutableConstantColumnReader which only
+   * supports primitive types.
+   */
+  def isPartitionSchemaSupported(
+      partitionSchema: StructType,
+      fallbackReasons: ListBuffer[String]): Boolean = {
+    if (scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT) {
+      // For native_iceberg_compat, partition columns must be supported by
+      // ImmutableConstantColumnReader which only supports primitive types
+      partitionSchema.fields.forall { field =>
+        if (ImmutableConstantColumnReader.isTypeSupported(field.dataType)) {
+          true
+        } else {
+          fallbackReasons += s"Partition column '${field.name}' has unsupported type " +
+            s"${field.dataType} for ImmutableConstantColumnReader in $scanImpl scan"
+          false
+        }
+      }
+    } else {
+      // For other scan implementations, use the standard type check
+      isSchemaSupported(partitionSchema, fallbackReasons)
+    }
+  }
 }
 
 object CometScanRule extends Logging {
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPartitionColumnBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPartitionColumnBenchmark.scala
new file mode 100644
index 0000000000..9f52ba5ef0
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPartitionColumnBenchmark.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+
+import org.apache.comet.CometConf
+import org.apache.comet.CometConf.{SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT}
+
+/**
+ * Benchmark to measure partition column scan performance. This exercises the CometConstantVector
+ * path where constant columns are exported as 1-element Arrow arrays and expanded on the native
+ * side.
+ *
+ * To run this benchmark:
+ * {{{
+ *   SPARK_GENERATE_BENCHMARK_FILES=1 make \
+ *     benchmark-org.apache.spark.sql.benchmark.CometPartitionColumnBenchmark
+ * }}}
+ *
+ * Results will be written to "spark/benchmarks/CometPartitionColumnBenchmark-**results.txt".
+ */
+object CometPartitionColumnBenchmark extends CometBenchmarkBase {
+
+  def partitionColumnScanBenchmark(values: Int, numPartitionCols: Int): Unit = {
+    val sqlBenchmark = new Benchmark(
+      s"Partitioned Scan with $numPartitionCols partition column(s)",
+      values,
+      output = output)
+
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        val partCols =
+          (1 to numPartitionCols).map(i => s"'part$i' as p$i").mkString(", ")
+        val partNames = (1 to numPartitionCols).map(i => s"p$i")
+        val df = spark.sql(s"SELECT value as id, $partCols FROM $tbl")
+        val parquetDir = dir.getCanonicalPath + "/parquetV1"
+        df.write
+          .partitionBy(partNames: _*)
+          .mode("overwrite")
+          .option("compression", "snappy")
+          .parquet(parquetDir)
+        spark.read.parquet(parquetDir).createOrReplaceTempView("parquetV1Table")
+
+        sqlBenchmark.addCase("SQL Parquet - Spark") { _ =>
+          spark.sql("select sum(id) from parquetV1Table").noop()
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION) {
+            spark.sql("select sum(id) from parquetV1Table").noop()
+          }
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet Native Iceberg Compat") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT) {
+            spark.sql("select sum(id) from parquetV1Table").noop()
+          }
+        }
+
+        // Also benchmark reading partition columns themselves
+        val partSumExpr =
+          (1 to numPartitionCols).map(i => s"sum(length(p$i))").mkString(", ")
+
+        sqlBenchmark.addCase("SQL Parquet - Spark (read partition cols)") { _ =>
+          spark.sql(s"select $partSumExpr from parquetV1Table").noop()
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion (partition cols)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION) {
+            spark.sql(s"select $partSumExpr from parquetV1Table").noop()
+          }
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet Native Iceberg Compat (partition cols)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT) {
+            spark.sql(s"select $partSumExpr from parquetV1Table").noop()
+          }
+        }
+
+        sqlBenchmark.run()
+      }
+    }
+  }
+
+  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+    runBenchmarkWithTable("Partitioned Column Scan", 1024 * 1024 * 15) { v =>
+      for (numPartCols <- List(1, 5)) {
+        partitionColumnScanBenchmark(v, numPartCols)
+      }
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
index 3bfbdee91a..25dab067fb 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnVector
 
 import org.apache.comet.{CometConf, WithHdfsCluster}
-import org.apache.comet.CometConf.{SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT}
+import org.apache.comet.CometConf.{SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT}
 import org.apache.comet.parquet.BatchReader
 
 /**
@@ -67,6 +67,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql(s"select $query from parquetV1Table").noop()
         }
 
+        sqlBenchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql(s"select $query from parquetV1Table").noop()
+          }
+        }
+
         sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -167,6 +175,21 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           }
         }
 
+        sqlBenchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            "spark.memory.offHeap.enabled" -> "true",
+            "spark.memory.offHeap.size" -> "10g",
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET,
+            DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+            KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+              "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+            InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+              s"footerKey: ${footerKey}, key1: ${key1}") {
+            spark.sql(s"select $query from parquetV1Table").noop()
+          }
+        }
+
         sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             "spark.memory.offHeap.enabled" -> "true",
@@ -222,6 +245,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql("select sum(id) from parquetV1Table").noop()
         }
 
+        sqlBenchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("select sum(id) from parquetV1Table").noop()
+          }
+        }
+
         sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -342,6 +373,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql("select sum(c2) from parquetV1Table where c1 + 1 > 0").noop()
         }
 
+        benchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("select sum(c2) from parquetV1Table where c1 + 1 > 0").noop()
+          }
+        }
+
         benchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -392,6 +431,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql("select sum(length(id)) from parquetV1Table").noop()
         }
 
+        sqlBenchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("select sum(length(id)) from parquetV1Table").noop()
+          }
+        }
+
         sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -435,6 +482,17 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
             .noop()
         }
 
+        benchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark
+              .sql("select sum(length(c2)) from parquetV1Table where c1 is " +
+                "not NULL and c2 is not NULL")
+              .noop()
+          }
+        }
+
         benchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -480,6 +538,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql(s"SELECT sum(c$middle) FROM parquetV1Table").noop()
         }
 
+        benchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql(s"SELECT sum(c$middle) FROM parquetV1Table").noop()
+          }
+        }
+
         benchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -523,6 +589,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql("SELECT * FROM parquetV1Table WHERE c1 + 1 > 0").noop()
         }
 
+        benchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("SELECT * FROM parquetV1Table WHERE c1 + 1 > 0").noop()
+          }
+        }
+
         benchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -546,6 +620,74 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
     }
   }
 
+  def partitionColumnScanBenchmark(values: Int, numPartitionCols: Int): Unit = {
+    val sqlBenchmark = new Benchmark(
+      s"Partitioned Scan with $numPartitionCols partition column(s)",
+      values,
+      output = output)
+
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        // Create a table with data columns and partition columns
+        val partCols = (1 to numPartitionCols).map(i => s"'part$i' as p$i").mkString(", ")
+        val partNames = (1 to numPartitionCols).map(i => s"p$i")
+        // Write the table directly instead of via prepareTable, whose partition argument is one
+        // Option[String]. NOTE(review): a comma-joined list is presumably passed to partitionBy
+        // as a single column name, which breaks for numPartitionCols > 1 -- confirm.
+        val parquetDir = dir.getCanonicalPath + "/parquetV1"
+        spark
+          .sql(s"SELECT value as id, $partCols FROM $tbl")
+          .write
+          .partitionBy(partNames: _*)
+          .mode("overwrite")
+          .option("compression", "snappy")
+          .parquet(parquetDir)
+        spark.read.parquet(parquetDir).createOrReplaceTempView("parquetV1Table")
+
+        sqlBenchmark.addCase("SQL Parquet - Spark") { _ =>
+          spark.sql("select sum(id) from parquetV1Table").noop()
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet (Scan Only)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("select sum(id) from parquetV1Table").noop()
+          }
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet (Scan + Exec)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT) {
+            spark.sql("select sum(id) from parquetV1Table").noop()
+          }
+        }
+
+        // Also benchmark reading partition columns themselves
+        val partSumExpr = (1 to numPartitionCols)
+          .map(i => s"sum(length(p$i))")
+          .mkString(", ")
+
+        sqlBenchmark.addCase("SQL Parquet - Spark (read partition cols)") { _ =>
+          spark.sql(s"select $partSumExpr from parquetV1Table").noop()
+        }
+
+        sqlBenchmark.addCase("SQL Parquet - Comet (read partition cols)") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT) {
+            spark.sql(s"select $partSumExpr from parquetV1Table").noop()
+          }
+        }
+
+        sqlBenchmark.run()
+      }
+    }
+  }
+
   def sortedLgStrFilterScanBenchmark(values: Int, fractionOfZeros: Double): Unit = {
     val percentageOfZeros = fractionOfZeros * 100
     val benchmark =
@@ -566,6 +708,14 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
           spark.sql("SELECT * FROM parquetV1Table WHERE c1 + 1 > 0").noop()
         }
 
+        benchmark.addCase("SQL Parquet - Comet") { _ =>
+          withSQLConf(
+            CometConf.COMET_ENABLED.key -> "true",
+            CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET) {
+            spark.sql("SELECT * FROM parquetV1Table WHERE c1 + 1 > 0").noop()
+          }
+        }
+
         benchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
@@ -669,6 +819,12 @@ class CometReadBaseBenchmark extends CometBenchmarkBase {
         sortedLgStrFilterScanBenchmark(v, fractionOfZeros)
       }
     }
+
+    runBenchmarkWithTable("Partitioned Column Scan", 1024 * 1024 * 15) { v =>
+      for (numPartCols <- List(1, 5)) {
+        partitionColumnScanBenchmark(v, numPartCols)
+      }
+    }
   }
 }
 