Skip to content

Commit cab69a1

Browse files
nathanb9Nathan Bezualemnathanb9
authored
Fix correlated subquery empty defaults for regr_count and approx_distinct (#22319)
## Which issue does this PR close?

- Closes #22317.

## Rationale for this change

Correlated scalar subqueries with ungrouped aggregates are decorrelated into joins. For unmatched outer rows, the rewritten join naturally produces NULLs on the right side, so DataFusion has compensation logic for aggregates that should return a non-NULL value on empty input.

That compensation previously special-cased `count` by name. As a result, other aggregates with non-NULL empty-input results, such as `regr_count` and `approx_distinct`, incorrectly returned NULL after decorrelation.

## What changes are included in this PR?

This PR updates decorrelation to use each aggregate UDF's `default_value()` instead of hard-coding `count`.

It also adds empty-input defaults for:

- `regr_count`: `UInt64(0)`
- `approx_distinct`: `UInt64(0)`

Regression coverage is added for correlated scalar subqueries using these aggregates in projection expressions and filters.

## Are these changes tested?

Yes.

```bash
cargo fmt --all
cargo test -p datafusion-sqllogictest --test sqllogictests -- subquery.slt
```

## Are there any user-facing changes?

Yes. Queries using `regr_count` or `approx_distinct` in correlated scalar subqueries now return `0` for unmatched outer rows instead of `NULL`, matching the aggregate behavior on empty input.

---------

Co-authored-by: Nathan Bezualem <nbez@amazon.com>
Co-authored-by: nathanb9 <nathanb9@amazon.com>
1 parent c48e993 commit cab69a1

6 files changed

Lines changed: 92 additions & 16 deletions

File tree

datafusion/core/tests/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@ async fn window_using_aggregates() -> Result<()> {
12041204
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
12051205
| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |
12061206
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
1207-
| | | | | | | | 1 | -85 |
1207+
| | | 0 | | | | | 1 | -85 |
12081208
| -85 | -101 | 14 | -12.0 | -12.0 | 83 | -101 | 4 | -54 |
12091209
| -85 | -101 | 17 | -25.0 | -25.0 | 83 | -101 | 5 | -31 |
12101210
| -85 | -12 | 10 | -32.75 | -34.0 | 83 | -85 | 3 | 13 |

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,14 @@ impl AggregateUDFImpl for ApproxDistinct {
381381
Ok(DataType::UInt64)
382382
}
383383

384+
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
385+
Ok(ScalarValue::UInt64(Some(0)))
386+
}
387+
388+
fn is_nullable(&self) -> bool {
389+
false
390+
}
391+
384392
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
385393
let data_type = args.input_fields[0].data_type();
386394
match data_type {

datafusion/functions-aggregate/src/regr.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,18 @@ impl AggregateUDFImpl for Regr {
457457
}
458458
}
459459

460+
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
461+
if self.regr_type == RegrType::Count {
462+
Ok(ScalarValue::UInt64(Some(0)))
463+
} else {
464+
Ok(ScalarValue::Float64(None))
465+
}
466+
}
467+
468+
fn is_nullable(&self) -> bool {
469+
self.regr_type != RegrType::Count
470+
}
471+
460472
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
461473
Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
462474
}

datafusion/optimizer/src/decorrelate.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use datafusion_expr::utils::{
3535
collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
3636
};
3737
use datafusion_expr::{
38-
BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder,
39-
Operator, expr, lit,
38+
BinaryExpr, Cast, EmptyRelation, Expr, ExprSchemable, FetchType, LogicalPlan,
39+
LogicalPlanBuilder, Operator, expr, lit,
4040
};
4141

4242
/// This struct rewrites the sub query plan by pulling up the correlated
@@ -512,18 +512,12 @@ fn agg_exprs_evaluation_result_on_empty_batch(
512512
let result_expr = e
513513
.clone()
514514
.transform_up(|expr| {
515-
let new_expr = match expr {
516-
Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
517-
if func.name() == "count" {
518-
Transformed::yes(Expr::Literal(
519-
ScalarValue::Int64(Some(0)),
520-
None,
521-
))
522-
} else {
523-
Transformed::yes(Expr::Literal(ScalarValue::Null, None))
524-
}
525-
}
526-
_ => Transformed::no(expr),
515+
let new_expr = if let Expr::AggregateFunction(agg) = &expr {
516+
let return_type = expr.get_type(schema.as_ref())?;
517+
let default_value = agg.func.default_value(&return_type)?;
518+
Transformed::yes(Expr::Literal(default_value, None))
519+
} else {
520+
Transformed::no(expr)
527521
};
528522
Ok(new_expr)
529523
})

datafusion/optimizer/src/scalar_subquery_to_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ mod tests {
819819
assert_optimized_plan_equal!(
820820
plan,
821821
@r#"
822-
Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
822+
Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(Float64(NULL) AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
823823
Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
824824
TableScan: customer [c_custkey:Int64, c_name:Utf8]
825825
SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,68 @@ SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from
888888
33 3
889889
44 0
890890

891+
#correlated_scalar_subquery_non_count_agg_empty_defaults
892+
query III rowsort
893+
SELECT
894+
t1_id,
895+
(
896+
SELECT regr_count(1.0, 1.0)
897+
FROM t2
898+
WHERE t2.t2_int = t1.t1_int
899+
) AS r,
900+
(
901+
SELECT approx_distinct(t2.t2_id)
902+
FROM t2
903+
WHERE t2.t2_int = t1.t1_int
904+
) AS d
905+
FROM t1
906+
----
907+
11 1 1
908+
22 0 0
909+
33 3 3
910+
44 0 0
911+
912+
query II rowsort
913+
SELECT
914+
t1_id,
915+
(
916+
SELECT regr_count(1.0, 1.0) + approx_distinct(t2.t2_id)
917+
FROM t2
918+
WHERE t2.t2_int = t1.t1_int
919+
) AS combined
920+
FROM t1
921+
----
922+
11 2
923+
22 0
924+
33 6
925+
44 0
926+
927+
query I rowsort
928+
SELECT t1_id
929+
FROM t1
930+
WHERE
931+
(
932+
SELECT approx_distinct(t2.t2_id)
933+
FROM t2
934+
WHERE t2.t2_int = t1.t1_int
935+
) = 0
936+
----
937+
22
938+
44
939+
940+
query I rowsort
941+
SELECT t1_id
942+
FROM t1
943+
WHERE
944+
(
945+
SELECT regr_count(1.0, 1.0)
946+
FROM t2
947+
WHERE t2.t2_int = t1.t1_int
948+
) = 0
949+
----
950+
22
951+
44
952+
891953
#correlated_scalar_subquery_count_agg_with_alias
892954
query TT
893955
explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1

0 commit comments

Comments
 (0)