Skip to content

Commit bb0678f

Browse files
committed
test: expect Spark fallback for shuffle on map with array/struct keys
Spark 4.0 wraps map shuffle keys in mapsort(...). Comet's map_sort relies on Arrow's sort_to_indices, which only supports scalar key types, so maps with array or struct keys fall back to Spark. Update the 'columnar shuffle on array/struct map key/value' test to expect 0 Comet shuffles for the array-key and struct-key cases on Spark 4.0+, while keeping the scalar-key cases at 1.
1 parent 77cd1d9 commit bb0678f

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
152152
}
153153

154154
test("columnar shuffle on array/struct map key/value") {
155+
// Spark 4.0 normalizes maps used as shuffle keys with mapsort(...). Comet's map_sort
156+
// relies on Arrow's sort_to_indices, which only supports scalar key types, so a map
157+
// with array or struct keys cannot be sorted natively and the shuffle falls back.
158+
val complexKeyShuffles = if (isSpark40Plus) 0 else 1
155159
Seq("false", "true").foreach { execEnabled =>
156160
Seq(10, 201).foreach { numPartitions =>
157161
Seq("1.0", "10.0").foreach { ratio =>
@@ -164,7 +168,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
164168
.repartition(numPartitions, $"_1", $"_2")
165169
.sortWithinPartitions($"_2")
166170

167-
checkShuffleAnswer(df, 1)
171+
checkShuffleAnswer(df, complexKeyShuffles)
168172
}
169173

170174
withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") {
@@ -182,7 +186,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
182186
.repartition(numPartitions, $"_1", $"_2")
183187
.sortWithinPartitions($"_2")
184188

185-
checkShuffleAnswer(df, 1)
189+
checkShuffleAnswer(df, complexKeyShuffles)
186190
}
187191

188192
withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") {

0 commit comments

Comments
 (0)