Skip to content

Commit 5d137e2

Browse files
authored
fix: [Spark 4.1] preserve union output partitioning in CometUnionExec (#4207)
1 parent c799d62 commit 5d137e2

10 files changed

Lines changed: 401 additions & 107 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ jobs:
355355
org.apache.comet.exec.CometWindowExecSuite
356356
org.apache.comet.exec.CometJoinSuite
357357
org.apache.comet.CometNativeSuite
358+
org.apache.comet.CometSetOpWithGroupBySuite
358359
org.apache.comet.CometSparkSessionExtensionsSuite
359360
org.apache.spark.CometPluginsSuite
360361
org.apache.spark.CometPluginsDefaultSuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ jobs:
194194
org.apache.comet.exec.CometWindowExecSuite
195195
org.apache.comet.exec.CometJoinSuite
196196
org.apache.comet.CometNativeSuite
197+
org.apache.comet.CometSetOpWithGroupBySuite
197198
org.apache.comet.CometSparkSessionExtensionsSuite
198199
org.apache.spark.CometPluginsSuite
199200
org.apache.spark.CometPluginsDefaultSuite

dev/diffs/4.1.1.diff

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -150,50 +150,6 @@ index 4410fe50912..43bcce2a038 100644
150150
case _ => Map[String, String]()
151151
}
152152
val childrenInfo = children.flatMap {
153-
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out
154-
index 69b4001ff34..6fda691652d 100644
155-
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out
156-
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/intersect-all.sql.out
157-
@@ -1,7 +1,7 @@
158-
-- Automatically generated by SQLQueryTestSuite
159-
-- !query
160-
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
161-
- (1, 2),
162-
+ (1, 2),
163-
(1, 2),
164-
(1, 3),
165-
(1, 3),
166-
@@ -11,7 +11,7 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
167-
AS tab1(k, v)
168-
-- !query analysis
169-
CreateViewCommand `tab1`, SELECT * FROM VALUES
170-
- (1, 2),
171-
+ (1, 2),
172-
(1, 2),
173-
(1, 3),
174-
(1, 3),
175-
@@ -26,8 +26,8 @@ CreateViewCommand `tab1`, SELECT * FROM VALUES
176-
177-
-- !query
178-
CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
179-
- (1, 2),
180-
- (1, 2),
181-
+ (1, 2),
182-
+ (1, 2),
183-
(2, 3),
184-
(3, 4),
185-
(null, null),
186-
@@ -35,8 +35,8 @@ CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
187-
AS tab2(k, v)
188-
-- !query analysis
189-
CreateViewCommand `tab2`, SELECT * FROM VALUES
190-
- (1, 2),
191-
- (1, 2),
192-
+ (1, 2),
193-
+ (1, 2),
194-
(2, 3),
195-
(3, 4),
196-
(null, null),
197153
diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
198154
index 13bbd9d81b7..541cdfb1e04 100644
199155
--- a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
@@ -211,18 +167,6 @@ index 13bbd9d81b7..541cdfb1e04 100644
211167
CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b;
212168

213169
-- division, remainder and pmod by 0 return NULL
214-
diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
215-
index e28f0721a64..788b43c242a 100644
216-
--- a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
217-
+++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql
218-
@@ -1,3 +1,7 @@
219-
+-- TODO(https://github.com/apache/datafusion-comet/issues/4122)
220-
+-- EXCEPT ALL with GROUP BY returns incorrect results on Spark 4.1
221-
+--SET spark.comet.enabled = false
222-
+
223-
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
224-
(0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1);
225-
CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
226170
diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql b/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
227171
index 7aef901da4f..f3d6e18926d 100644
228172
--- a/sql/core/src/test/resources/sql-tests/inputs/explain-aqe.sql
@@ -280,32 +224,6 @@ index 35128da97fd..25b873ae859 100644
280224
-- Positive test cases
281225
-- Create a table with some testing data.
282226
DROP TABLE IF EXISTS t1;
283-
diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
284-
index 077caa5dd44..697457d4251 100644
285-
--- a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
286-
+++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
287-
@@ -1,5 +1,9 @@
288-
+-- TODO(https://github.com/apache/datafusion-comet/issues/4122)
289-
+-- INTERSECT ALL with GROUP BY returns incorrect results on Spark 4.1
290-
+--SET spark.comet.enabled = false
291-
+
292-
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
293-
- (1, 2),
294-
+ (1, 2),
295-
(1, 2),
296-
(1, 3),
297-
(1, 3),
298-
@@ -8,8 +12,8 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
299-
(null, null)
300-
AS tab1(k, v);
301-
CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
302-
- (1, 2),
303-
- (1, 2),
304-
+ (1, 2),
305-
+ (1, 2),
306-
(2, 3),
307-
(3, 4),
308-
(null, null),
309227
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
310228
index 41fd4de2a09..162d5a817b6 100644
311229
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -428,30 +346,6 @@ index 21a3ce1e122..f4762ab98f0 100644
428346
SET spark.sql.ansi.enabled = false;
429347

430348
-- In COMPENSATION views get invalidated if the type can't cast
431-
diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
432-
index 44f95f225ab..361866fc298 100644
433-
--- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
434-
+++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
435-
@@ -1,7 +1,7 @@
436-
-- Automatically generated by SQLQueryTestSuite
437-
-- !query
438-
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
439-
- (1, 2),
440-
+ (1, 2),
441-
(1, 2),
442-
(1, 3),
443-
(1, 3),
444-
@@ -17,8 +17,8 @@ struct<>
445-
446-
-- !query
447-
CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES
448-
- (1, 2),
449-
- (1, 2),
450-
+ (1, 2),
451-
+ (1, 2),
452-
(2, 3),
453-
(3, 4),
454-
(null, null),
455349
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
456350
index 0d807aeae4d..6d7744e771b 100644
457351
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1342,8 +1342,37 @@ case class CometUnionExec(
13421342
children: Seq[SparkPlan])
13431343
extends CometExec {
13441344

1345+
// Why this is overridden: the inherited CometExec.outputPartitioning defers to `originalPlan`,
// and `originalPlan` still holds the children that were live at CometExecRule conversion time.
// AQE post-stage rewrites (coalesce, skew join, etc.) re-parent our `children` field without
// updating `originalPlan`, so the frozen snapshot can describe a pre-coalesce layout with more
// partitions than the RDDs will actually produce. Re-deriving the value from the current
// children keeps SPARK-52921's union-output-partitioning inference tied to the live plan.
// This is safe on older Spark as well: UnionExec.outputPartitioning returns
// UnknownPartitioning when UNION_OUTPUT_PARTITIONING is off (the pre-4.1 default).
//
// Only SinglePartition and HashPartitioningLike are passed through. That is the same whitelist
// used by Spark's UnionExec.comparePartitioning and honored by ShimCometUnionExec.unionRDDs via
// SQLPartitioningAwareUnionRDD. Any other result is reported as UnknownPartitioning (keeping
// its partition count) so the declared partitioning and the RDD layer always agree.
override lazy val outputPartitioning: Partitioning = {
  val inferred = originalPlan.withNewChildren(children).outputPartitioning
  inferred match {
    case SinglePartition | _: HashPartitioningLike => inferred
    case _ => UnknownPartitioning(inferred.numPartitions)
  }
}
1364+
13451365
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
  // Since SPARK-52921, Spark 4.1's UnionExec can report a non-trivial output partitioning when
  // all children share the same hash/single partitioning, and downstream plans may skip an
  // otherwise-required shuffle in response. Plain `sparkContext.union` only concatenates
  // partitions (partition i of the result holds just one child's partition i), which violates
  // that partitioning claim and silently corrupts aggregates layered above the union. The shim
  // routes through SQLPartitioningAwareUnionRDD on 4.1+ when a known partitioning is declared.
  val context = sparkContext
  val childBatches = children.map(_.executeColumnar())
  shims.ShimCometUnionExec.unionRDDs(context, childBatches, outputPartitioning)
}
13481377

13491378
override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.shims
21+
22+
import scala.reflect.ClassTag
23+
24+
import org.apache.spark.SparkContext
25+
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
27+
28+
object ShimCometUnionExec {

  /**
   * Combines the given RDDs by plain partition concatenation via `SparkContext.union`.
   *
   * Before Spark 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports
   * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so there is no declared
   * partitioning for this shim to honor. The partitioning-aware path is only needed on Spark
   * 4.1+ (see SPARK-52921).
   *
   * @param sc
   *   the Spark context used to build the union
   * @param rdds
   *   the child RDDs to combine
   * @param outputPartitioning
   *   not used by this shim; the parameter is declared but never read
   * @return
   *   the result of `sc.union(rdds)`
   */
  def unionRDDs[T: ClassTag](
      sc: SparkContext,
      rdds: Seq[RDD[T]],
      @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] =
    sc.union(rdds)
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.shims
21+
22+
import scala.reflect.ClassTag
23+
24+
import org.apache.spark.SparkContext
25+
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
27+
28+
object ShimCometUnionExec {

  /**
   * Combines the given RDDs by plain partition concatenation via `SparkContext.union`.
   *
   * Before Spark 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports
   * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so there is no declared
   * partitioning for this shim to honor. The partitioning-aware path is only needed on Spark
   * 4.1+ (see SPARK-52921).
   *
   * @param sc
   *   the Spark context used to build the union
   * @param rdds
   *   the child RDDs to combine
   * @param outputPartitioning
   *   not used by this shim; the parameter is declared but never read
   * @return
   *   the result of `sc.union(rdds)`
   */
  def unionRDDs[T: ClassTag](
      sc: SparkContext,
      rdds: Seq[RDD[T]],
      @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] =
    sc.union(rdds)
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.shims
21+
22+
import scala.reflect.ClassTag
23+
24+
import org.apache.spark.SparkContext
25+
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
27+
28+
object ShimCometUnionExec {

  /**
   * Combines the given RDDs by plain partition concatenation via `SparkContext.union`.
   *
   * Before Spark 4.1, [[org.apache.spark.sql.execution.UnionExec]] always reports
   * [[org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning]], so there is no declared
   * partitioning for this shim to honor. The partitioning-aware path is only needed on Spark
   * 4.1+ (see SPARK-52921).
   *
   * @param sc
   *   the Spark context used to build the union
   * @param rdds
   *   the child RDDs to combine
   * @param outputPartitioning
   *   not used by this shim; the parameter is declared but never read
   * @return
   *   the result of `sc.union(rdds)`
   */
  def unionRDDs[T: ClassTag](
      sc: SparkContext,
      rdds: Seq[RDD[T]],
      @annotation.nowarn("cat=unused") outputPartitioning: Partitioning): RDD[T] =
    sc.union(rdds)
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.shims
21+
22+
import scala.reflect.ClassTag
23+
24+
import org.apache.spark.SparkContext
25+
import org.apache.spark.internal.Logging
26+
import org.apache.spark.rdd.{RDD, SQLPartitioningAwareUnionRDD}
27+
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, SinglePartition}
28+
29+
object ShimCometUnionExec extends Logging {

  /**
   * Unions a sequence of RDDs while preserving the declared output partitioning. Spark 4.1
   * introduced [[org.apache.spark.sql.internal.SQLConf.UNION_OUTPUT_PARTITIONING]] (SPARK-52921),
   * which lets [[org.apache.spark.sql.execution.UnionExec]] report a non-trivial output
   * partitioning when all children share the same partitioning. Downstream operators may then
   * skip an otherwise-required shuffle, so the columnar Union path must honor that contract by
   * routing through [[SQLPartitioningAwareUnionRDD]] rather than plain `SparkContext.union`,
   * which concatenates partitions and breaks the partitioning invariant.
   *
   * @param sc
   *   the Spark context used to build the union
   * @param rdds
   *   the child RDDs to combine
   * @param outputPartitioning
   *   the partitioning the union declares; only `SinglePartition` and `HashPartitioningLike`
   *   select the partitioning-aware path
   * @return
   *   a [[SQLPartitioningAwareUnionRDD]] over the non-empty children when a known partitioning
   *   is declared and every non-empty child has exactly `numPartitions` partitions; otherwise
   *   the result of `sc.union(rdds)`
   */
  def unionRDDs[T: ClassTag](
      sc: SparkContext,
      rdds: Seq[RDD[T]],
      outputPartitioning: Partitioning): RDD[T] = {
    outputPartitioning match {
      case SinglePartition | _: HashPartitioningLike =>
        val numPartitions = outputPartitioning.numPartitions
        // Children with zero partitions are left out of the partitioning-aware union.
        val nonEmpty = rdds.filter(_.partitions.nonEmpty)
        if (nonEmpty.isEmpty) {
          // Every child has zero partitions. This is not a sign of a stale declared
          // partitioning, so fall back to plain concat without logging a mismatch warning.
          sc.union(rdds)
        } else if (nonEmpty.exists(_.partitions.length != numPartitions)) {
          // SQLPartitioningAwareUnionRDD indexes every child at every output partition, so any
          // child whose partition count diverges from the declared numPartitions would raise
          // ArrayIndexOutOfBoundsException. That would only happen if the declared partitioning
          // is stale relative to the RDDs (e.g. children were coalesced by AQE but the reported
          // partitioning was not). Fall back to plain concat in that case.
          val childCounts = rdds.map(_.partitions.length).mkString(", ")
          logWarning(
            s"CometUnionExec: child partition counts ($childCounts) do not match " +
              s"declared output partitioning numPartitions=$numPartitions; " +
              "falling back to SparkContext.union concat.")
          sc.union(rdds)
        } else {
          new SQLPartitioningAwareUnionRDD(sc, nonEmpty, numPartitions)
        }
      case _ => sc.union(rdds)
    }
  }
}

0 commit comments

Comments
 (0)