Skip to content

Commit b71b53b

Browse files
fix: enable Corr (#3892)
## Which issue does this PR close? Closes #2646. ## Rationale for this change This fixes the behavior described in #2646 (comment). ## What changes are included in this PR? When both inputs to `Corr` are `NaN`, return `Null`. ## How are these changes tested? Added tests.
1 parent 9372a5e commit b71b53b

File tree

6 files changed

+67
-45
lines changed

6 files changed

+67
-45
lines changed

docs/source/user-guide/latest/compatibility.md

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -76,11 +76,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
7676
timezone is UTC.
7777
[#2649](https://github.com/apache/datafusion-comet/issues/2649)
7878

79-
### Aggregate Expressions
80-
81-
- **Corr**: Returns null instead of NaN in some edge cases.
82-
[#2646](https://github.com/apache/datafusion-comet/issues/2646)
83-
8479
### Struct Expressions
8580

8681
- **StructsToJson (to_json)**: Does not support `+Infinity` and `-Infinity` for numeric types (float, double).

docs/source/user-guide/latest/expressions.md

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -195,27 +195,27 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
195195

196196
## Aggregate Expressions
197197

198-
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
199-
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------- |
200-
| Average | | Yes, except for ANSI mode | |
201-
| BitAndAgg | | Yes | |
202-
| BitOrAgg | | Yes | |
203-
| BitXorAgg | | Yes | |
204-
| BoolAnd | `bool_and` | Yes | |
205-
| BoolOr | `bool_or` | Yes | |
206-
| Corr | | No | Returns null instead of NaN in some edge cases ([#2646](https://github.com/apache/datafusion-comet/issues/2646)) |
207-
| Count | | Yes | |
208-
| CovPopulation | | Yes | |
209-
| CovSample | | Yes | |
210-
| First | | No | This function is not deterministic. Results may not match Spark. |
211-
| Last | | No | This function is not deterministic. Results may not match Spark. |
212-
| Max | | Yes | |
213-
| Min | | Yes | |
214-
| StddevPop | | Yes | |
215-
| StddevSamp | | Yes | |
216-
| Sum | | Yes, except for ANSI mode | |
217-
| VariancePop | | Yes | |
218-
| VarianceSamp | | Yes | |
198+
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
199+
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------- |
200+
| Average | | Yes, except for ANSI mode | |
201+
| BitAndAgg | | Yes | |
202+
| BitOrAgg | | Yes | |
203+
| BitXorAgg | | Yes | |
204+
| BoolAnd | `bool_and` | Yes | |
205+
| BoolOr | `bool_or` | Yes | |
206+
| Corr | | Yes | |
207+
| Count | | Yes | |
208+
| CovPopulation | | Yes | |
209+
| CovSample | | Yes | |
210+
| First | | No | This function is not deterministic. Results may not match Spark. |
211+
| Last | | No | This function is not deterministic. Results may not match Spark. |
212+
| Max | | Yes | |
213+
| Min | | Yes | |
214+
| StddevPop | | Yes | |
215+
| StddevSamp | | Yes | |
216+
| Sum | | Yes, except for ANSI mode | |
217+
| VariancePop | | Yes | |
218+
| VarianceSamp | | Yes | |
219219

220220
## Window Functions
221221

native/spark-expr/src/agg_funcs/correlation.rs

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -221,19 +221,22 @@ impl Accumulator for CorrelationAccumulator {
221221
let stddev1 = self.stddev1.evaluate()?;
222222
let stddev2 = self.stddev2.evaluate()?;
223223

224+
if self.covar.get_count() == 0.0 {
225+
return Ok(ScalarValue::Float64(None));
226+
} else if self.covar.get_count() == 1.0 {
227+
if self.null_on_divide_by_zero {
228+
return Ok(ScalarValue::Float64(None));
229+
} else {
230+
return Ok(ScalarValue::Float64(Some(f64::NAN)));
231+
}
232+
}
224233
match (covar, stddev1, stddev2) {
225234
(
226235
ScalarValue::Float64(Some(c)),
227236
ScalarValue::Float64(Some(s1)),
228237
ScalarValue::Float64(Some(s2)),
229238
) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
230-
_ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
231-
_ => {
232-
if self.covar.get_count() == 1.0 {
233-
return Ok(ScalarValue::Float64(Some(f64::NAN)));
234-
}
235-
Ok(ScalarValue::Float64(None))
236-
}
239+
_ => Ok(ScalarValue::Float64(None)),
237240
}
238241
}
239242

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -584,13 +584,6 @@ object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with Come
584584
}
585585

586586
object CometCorr extends CometAggregateExpressionSerde[Corr] {
587-
588-
override def getSupportLevel(expr: Corr): SupportLevel =
589-
Incompatible(
590-
Some(
591-
"Returns null instead of NaN in some edge cases" +
592-
" (https://github.com/apache/datafusion-comet/issues/2646)"))
593-
594587
override def convert(
595588
aggExpr: AggregateExpression,
596589
corr: Corr,

spark/src/test/resources/sql-tests/expressions/aggregate/corr.sql

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18-
-- Config: spark.comet.expression.Corr.allowIncompatible=true
1918

2019
statement
2120
CREATE TABLE test_corr(x double, y double, grp string) USING parquet
@@ -28,3 +27,13 @@ SELECT corr(x, y) FROM test_corr
2827

2928
query tolerance=1e-6
3029
SELECT grp, corr(x, y) FROM test_corr GROUP BY grp ORDER BY grp
30+
31+
-- Test permutations of NULL and NaN
32+
statement
33+
CREATE TABLE test_corr_nan(x double, y double, grp string) USING parquet
34+
35+
statement
36+
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'both_nan'), (cast('NaN' as double), 1.0, 'nan_val'), (1.0, cast('NaN' as double), 'val_nan'), (NULL, cast('NaN' as double), 'null_nan'), (cast('NaN' as double), NULL, 'nan_null'), (NULL, NULL, 'both_null'), (NULL, 1.0, 'null_val'), (1.0, NULL, 'val_null'), (cast('NaN' as double), cast('NaN' as double), 'mixed'), (1.0, 2.0, 'mixed'), (3.0, 4.0, 'mixed'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan')
37+
38+
query tolerance=1e-6
39+
SELECT grp, corr(x, y) FROM test_corr_nan GROUP BY grp ORDER BY grp

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ import scala.util.Random
2424
import org.apache.hadoop.fs.Path
2525
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
2626
import org.apache.spark.sql.catalyst.expressions.Cast
27-
import org.apache.spark.sql.catalyst.expressions.aggregate.Corr
2827
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
2928
import org.apache.spark.sql.comet.CometHashAggregateExec
3029
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1306,9 +1305,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13061305
}
13071306

13081307
test("covariance & correlation") {
1309-
withSQLConf(
1310-
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
1311-
CometConf.getExprAllowIncompatConfigKey(classOf[Corr]) -> "true") {
1308+
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
13121309
Seq("jvm", "native").foreach { cometShuffleMode =>
13131310
withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
13141311
Seq(true, false).foreach { dictionary =>
@@ -1379,6 +1376,31 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13791376
}
13801377
}
13811378

1379+
test("corr - nan/null") {
1380+
Seq(true, false).foreach { nullOnDivideByZero =>
1381+
withSQLConf("spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
1382+
withTable("t") {
1383+
sql("""create table t using parquet as
1384+
select cast(null as float) f1, CAST('NaN' AS float) f2, cast(null as double) d1, CAST('NaN' AS double) d2
1385+
from range(1)
1386+
""")
1387+
1388+
checkSparkAnswerAndOperator("""
1389+
|select
1390+
| corr(f1, f2) c1,
1391+
| corr(f1, f1) c2,
1392+
| corr(f2, f1) c3,
1393+
| corr(f2, f2) c4,
1394+
| corr(d1, d2) c5,
1395+
| corr(d1, d1) c6,
1396+
| corr(d2, d1) c7,
1397+
| corr(d2, d2) c8
1398+
| FROM t""".stripMargin)
1399+
}
1400+
}
1401+
}
1402+
}
1403+
13821404
test("var_pop and var_samp") {
13831405
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
13841406
Seq("native", "jvm").foreach { cometShuffleMode =>

0 commit comments

Comments (0)