Skip to content

Commit 174c939

Browse files
committed
add NullType to shuffles
1 parent 810e5d5 commit 174c939

4 files changed

Lines changed: 27 additions & 23 deletions

File tree

native/shuffle/src/spark_unsafe/row.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ use arrow::array::{
2828
builder::{
2929
ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder,
3030
Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder,
31-
Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder,
32-
StructBuilder, TimestampMicrosecondBuilder,
31+
Int64Builder, Int8Builder, ListBuilder, MapBuilder, NullBuilder, StringBuilder,
32+
StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
3333
},
3434
types::Int32Type,
3535
Array, ArrayRef, RecordBatch, RecordBatchOptions,
@@ -267,6 +267,10 @@ pub(super) fn append_field(
267267
append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder
268268
.append_value(row.get_date(idx)));
269269
}
270+
DataType::Null => {
271+
let field_builder = get_field_builder!(struct_builder, NullBuilder, idx);
272+
field_builder.append_null();
273+
}
270274
DataType::Timestamp(TimeUnit::Microsecond, _) => {
271275
append_field_to_builder!(
272276
TimestampMicrosecondBuilder,
@@ -1148,6 +1152,12 @@ fn append_columns(
11481152
.append_value(row.get_date(idx))
11491153
);
11501154
}
1155+
DataType::Null => {
1156+
let null_builder = downcast_builder_ref!(NullBuilder, builder);
1157+
for _ in row_start..row_end {
1158+
null_builder.append_null();
1159+
}
1160+
}
11511161
DataType::Timestamp(TimeUnit::Microsecond, _) => {
11521162
append_column_to_builder!(
11531163
TimestampMicrosecondBuilder,
@@ -1252,6 +1262,7 @@ fn make_builders(
12521262
}
12531263
}
12541264
DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)),
1265+
DataType::Null => Box::new(NullBuilder::new()),
12551266
DataType::Timestamp(TimeUnit::Microsecond, _) => {
12561267
Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone()))
12571268
}

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
4040
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
4141
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
4242
import org.apache.spark.sql.internal.SQLConf
43-
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
43+
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
4444
import org.apache.spark.sql.vectorized.ColumnarBatch
4545
import org.apache.spark.util.MutablePair
4646
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -364,7 +364,7 @@ object CometShuffleExchangeExec
364364
def supportedSerializableDataType(dt: DataType): Boolean = dt match {
365365
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
366366
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
367-
_: TimestampNTZType | _: DecimalType | _: DateType =>
367+
_: TimestampNTZType | _: DecimalType | _: DateType | _: NullType =>
368368
true
369369
case StructType(fields) =>
370370
fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
@@ -487,7 +487,7 @@ object CometShuffleExchangeExec
487487
def supportedSerializableDataType(dt: DataType): Boolean = dt match {
488488
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
489489
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
490-
_: TimestampNTZType | _: DecimalType | _: DateType =>
490+
_: TimestampNTZType | _: DecimalType | _: DateType | _: NullType =>
491491
true
492492
case StructType(fields) =>
493493
fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@ package org.apache.comet.exec
2222
import java.nio.file.{Files, Paths}
2323

2424
import scala.reflect.runtime.universe._
25-
import scala.util.Random
2625

2726
import org.scalactic.source.Position
2827
import org.scalatest.Tag
2928

3029
import org.apache.hadoop.fs.Path
3130
import org.apache.spark.{Partitioner, SparkConf}
32-
import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row}
31+
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
3332
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
3433
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
3534
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -94,22 +93,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
9493
""".stripMargin))
9594
}
9695

97-
test("Fallback to Spark for unsupported input besides ordering") {
98-
val dataGenerator = RandomDataGenerator
99-
.forType(
100-
dataType = NullType,
101-
nullable = true,
102-
new Random(System.nanoTime()),
103-
validJulianDatetime = false)
104-
.get
105-
106-
val schema = new StructType()
107-
.add("index", IntegerType, nullable = false)
108-
.add("col", NullType, nullable = true)
109-
val rdd =
110-
spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator())))
111-
val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
112-
checkSparkAnswer(df)
96+
test("columnar shuffle with NullType passthrough column") {
97+
val df = sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)")
98+
val shuffled = df.repartition(2, $"x")
99+
checkShuffleAnswer(shuffled, 1)
113100
}
114101

115102
test("columnar shuffle on nested struct including nulls") {

spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,12 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
218218
}
219219
}
220220

221+
test("native shuffle with NullType passthrough column") {
222+
val df = spark.sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)")
223+
val shuffled = df.repartition(2, $"x")
224+
checkShuffleAnswer(shuffled, 1)
225+
}
226+
221227
test("fix: Comet native shuffle with binary data") {
222228
withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") {
223229
val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl")

0 commit comments

Comments
 (0)