diff --git a/docs/source/contributor-guide/jvm_shuffle.md b/docs/source/contributor-guide/jvm_shuffle.md index e011651d2c..285693a7b9 100644 --- a/docs/source/contributor-guide/jvm_shuffle.md +++ b/docs/source/contributor-guide/jvm_shuffle.md @@ -49,9 +49,9 @@ JVM shuffle (`CometColumnarExchange`) is used instead of native shuffle (`CometE 3. **Unsupported partitioning type**: Native shuffle only supports `HashPartitioning`, `RangePartitioning`, and `SinglePartition`. JVM shuffle additionally supports `RoundRobinPartitioning`. -4. **Unsupported partition key types**: For `HashPartitioning` and `RangePartitioning`, native shuffle - only supports primitive types as partition keys. Complex types (struct, array, map) cannot be used - as partition keys in native shuffle, though they are fully supported as data columns in both implementations. +4. **Unsupported partition key types**: For `RangePartitioning`, native shuffle only supports primitive + types as partition keys. Complex types (struct, array, map) are supported as hash partition keys in + native shuffle. ## Input Handling diff --git a/docs/source/contributor-guide/native_shuffle.md b/docs/source/contributor-guide/native_shuffle.md index e3d2dea473..7a316383c7 100644 --- a/docs/source/contributor-guide/native_shuffle.md +++ b/docs/source/contributor-guide/native_shuffle.md @@ -55,9 +55,9 @@ Native shuffle (`CometExchange`) is selected when all of the following condition `RoundRobinPartitioning` requires JVM shuffle. -4. **Supported partition key types**: For `HashPartitioning` and `RangePartitioning`, partition - keys must be primitive types. Complex types (struct, array, map) as partition keys require - JVM shuffle. Note that complex types are fully supported as data columns in native shuffle. +4. **Supported partition key types**: For `HashPartitioning`, both primitive and complex types + (struct, array, map) are supported as partition keys. For `RangePartitioning`, only primitive + types are supported as partition keys. ## Architecture diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md index 5939e89ef3..e69cc5ed5d 100644 --- a/docs/source/user-guide/latest/tuning.md +++ b/docs/source/user-guide/latest/tuning.md @@ -141,8 +141,8 @@ back to Spark for shuffle operations. #### Native Shuffle Comet provides a fully native shuffle implementation, which generally provides the best performance. Native shuffle -supports `HashPartitioning`, `RangePartitioning` and `SinglePartitioning` but currently only supports primitive type -partitioning keys. Columns that are not partitioning keys may contain complex types like maps, structs, and arrays. +supports `HashPartitioning`, `RangePartitioning` and `SinglePartitioning`. Complex types (structs, arrays, and maps) +are supported as hash partition keys. Range partitioning only supports primitive types as partition keys. #### Columnar (JVM) Shuffle diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1805711d01..a8216a27f2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -260,14 +260,20 @@ object CometShuffleExchangeExec * Determine which data types are supported as partition columns in native shuffle. * * For HashPartitioning this defines the key that determines how data should be collocated for - * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support - * hashing complex types, see hash_funcs/utils.rs + * operations like `groupByKey`, `reduceByKey`, or `join`. Native code supports hashing both + * primitive and complex types. */ def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | _: DecimalType | _: DateType => true + case StructType(fields) => + fields.nonEmpty && fields.forall(f => supportedHashPartitioningDataType(f.dataType)) + case ArrayType(elementType, _) => + supportedHashPartitioningDataType(elementType) + case MapType(keyType, valueType, _) => + supportedHashPartitioningDataType(keyType) && supportedHashPartitioningDataType(valueType) case _ => false } diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index fc3db183b3..2eae878ae8 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,8 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ +// Import MapSort for Spark 4.0 support +import org.apache.spark.sql.catalyst.expressions.MapSort import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation @@ -55,6 +57,11 @@ trait CometExprShim extends CommonStringExprs { inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { expr match { + // MapSort is used by Spark 4.0+ to make maps comparable for partitioning. + // For hash partitioning, we can just use the underlying map expression. + case MapSort(child) => + exprToProtoInternal(child, inputs, binding) + case s: StaticInvoke if s.staticObject == classOf[StringDecode] && s.dataType.isInstanceOf[StringType] && diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index 833314a5c6..5202e57ddd 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -179,7 +179,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase { df.createOrReplaceTempView("t1") val columns = df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name) for (col <- columns) { - // DISTRIBUTE BY is equivalent to df.repartition($col) and uses + // DISTRIBUTE BY is equivalent to df.repartition($col) val sql = s"SELECT $col FROM t1 DISTRIBUTE BY $col" val df = spark.sql(sql) df.collect() @@ -191,13 +191,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase { // native_comet does not support reading complex types 0 case _ => - CometConf.COMET_SHUFFLE_MODE.get() match { - case "jvm" => - 1 - case "native" => - // native shuffle does not support complex types as partitioning keys - 0 - } + 1 } assert(cometShuffleExchanges.length == expectedNumCometShuffles) } diff --git a/spark/src/test/scala/org/apache/comet/CometHashExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometHashExpressionSuite.scala index 563ee18520..209d1cbc80 100644 --- a/spark/src/test/scala/org/apache/comet/CometHashExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometHashExpressionSuite.scala @@ -161,6 +161,10 @@ class CometHashExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe } } + // ==================== Complex Types ==================== + // Note: The SQL hash() expression for complex types falls back to Spark execution. + // These tests verify correctness of the hash values (used by native shuffle partitioning). + test("hash - array of decimal (precision > 18) falls back to Spark") { withTable("t") { sql("CREATE TABLE t(c ARRAY) USING parquet") diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index a682ff91a5..1fae78166f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -68,7 +68,9 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 1000) - var allTypes: Seq[Int] = (1 to 20) + // Exclude _17 which is DECIMAL(38, 37) - high precision decimals are not supported + // as partitioning keys in native shuffle + var allTypes: Seq[Int] = (1 to 20).filterNot(_ == 17) allTypes.map(i => s"_$i").foreach { c => withSQLConf( CometConf.COMET_EXEC_ENABLED.key -> execEnabled.toString, @@ -100,8 +102,8 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper .filter($"_3" > 10) .repartition(numPartitions, $"_2") - // Partitioning on nested array falls back to Spark - checkShuffleAnswer(df, 0) + // Partitioning on nested array works with native shuffle + checkShuffleAnswer(df, 1) df = sql("SELECT * FROM tbl") .filter($"_3" > 10) @@ -116,6 +118,36 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } + test("native shuffle with struct as partition key") { + Seq(10, 201).foreach { numPartitions => + withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { + withParquetTable((0 until 50).map(i => (i, (i % 10, s"value_${i % 10}"), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_3" > 10) + .repartition(numPartitions, $"_2") + + // Partitioning on struct works with native shuffle + checkShuffleAnswer(df, 1) + } + } + } + } + + test("native shuffle with map as partition key") { + Seq(10, 201).foreach { numPartitions => + withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { + withParquetTable((0 until 50).map(i => (i, Map("key" -> (i % 10)), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_3" > 10) + .repartition(numPartitions, $"_2") + + // Partitioning on map works with native shuffle + checkShuffleAnswer(df, 1) + } + } + } + } + test("hash-based native shuffle") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)