diff --git a/src/query/expression/src/constant_folder.rs b/src/query/expression/src/constant_folder.rs index 8855356e9cf9b..5be030a4890be 100644 --- a/src/query/expression/src/constant_folder.rs +++ b/src/query/expression/src/constant_folder.rs @@ -517,6 +517,13 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> { ); } + // `grouping` is a placeholder before the aggregate rewriter rewrites it to + // `grouping<...>(_grouping_id)`. Folding it here can reach the dummy + // implementation and panic on invalid queries. + if function.signature.name == "grouping" { + return (func_expr, func_domain); + } + if all_args_is_scalar { let block = DataBlock::empty_with_rows(1); let evaluator = Evaluator::new(&block, self.func_ctx, self.fn_registry); diff --git a/src/query/service/tests/it/sql/planner/semantic/type_check.rs b/src/query/service/tests/it/sql/planner/semantic/type_check.rs index 8f16ac8e610fd..f169387190dc0 100644 --- a/src/query/service/tests/it/sql/planner/semantic/type_check.rs +++ b/src/query/service/tests/it/sql/planner/semantic/type_check.rs @@ -15,7 +15,10 @@ use databend_common_catalog::table_context::TableContext; use databend_common_expression::ColumnIndex; use databend_common_expression::Expr; +use databend_common_sql::Planner; use databend_common_sql::parse_exprs; +use databend_common_sql::plans::Plan; +use databend_query::physical_plans::PhysicalPlanBuilder; use databend_query::test_kits::TestFixture; #[tokio::test(flavor = "multi_thread")] @@ -85,6 +88,120 @@ async fn test_inlist_with_null_builds_shallow_or_tree() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread")] +async fn test_invalid_grouping_returns_semantic_error() -> anyhow::Result<()> { + let fixture = TestFixture::setup().await?; + let ctx = fixture.new_query_ctx().await?; + let mut planner = Planner::new(ctx.clone()); + fixture + .execute_command("CREATE TABLE students(course STRING, type STRING)") + .await?; + + for (sql, expected) in [ + ( + "SELECT GROUPING()", + "grouping requires at least one argument", + ), + ( + "SELECT GROUPING() FROM students", + "grouping requires at least one argument", + ), + ( + "SELECT count() FROM students WHERE GROUPING() = 0 GROUP BY course", + "grouping requires at least one argument", + ), + ( + "SELECT count() OVER () FROM students GROUP BY course QUALIFY GROUPING() = 0", + "grouping requires at least one argument", + ), + ( + "SELECT count() \ + FROM students s1 \ + JOIN students s2 ON GROUPING() = 0 \ + GROUP BY s1.course", + "grouping requires at least one argument", + ), + ( + "SELECT 1 FROM students GROUP BY GROUPING SETS ((GROUPING()))", + "grouping requires at least one argument", + ), + ] { + let err = planner + .plan_sql(sql) + .await + .expect_err("invalid grouping() should return a semantic error"); + assert!( + err.message().contains(expected), + "unexpected error for `{sql}`: {err}", + ); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_grouping_qualify_rewrites_before_semantic_checks() -> anyhow::Result<()> { + let fixture = TestFixture::setup().await?; + let ctx = fixture.new_query_ctx().await?; + let mut planner = Planner::new(ctx.clone()); + fixture + .execute_command("CREATE TABLE students(course STRING, type STRING)") + .await?; + + for sql in [ + "SELECT count() OVER () \ + FROM students \ + GROUP BY GROUPING SETS ((course), ()) \ + QUALIFY GROUPING(course) = 0", + "SELECT GROUPING(course) AS g, count() OVER () \ + FROM students \ + GROUP BY GROUPING SETS ((course), ()) \ + QUALIFY g = 0", + ] { + planner + .plan_sql(sql) + .await + .unwrap_or_else(|err| panic!("expected valid grouping QUALIFY for `{sql}`: {err}")); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_grouping_sets_to_union_keeps_grouping_id_for_qualify_windows() -> anyhow::Result<()> { + let fixture = TestFixture::setup().await?; + let ctx = fixture.new_query_ctx().await?; + ctx.get_settings() + .set_setting("grouping_sets_to_union".to_string(), "1".to_string())?; + + fixture + .execute_command("CREATE TABLE students(course STRING, type STRING)") + .await?; + + let sql = "SELECT course, GROUPING(course) AS g, count() OVER () AS w \ + FROM students \ + GROUP BY GROUPING SETS ((course), ()) \ + QUALIFY GROUPING(course) = 0"; + + let mut planner = Planner::new(ctx.clone()); + let (plan, _) = planner.plan_sql(sql).await?; + + let Plan::Query { + s_expr, + metadata, + bind_context, + .. + } = plan + else { + panic!("expected query plan"); + }; + + let mut builder = PhysicalPlanBuilder::new(metadata, ctx, false); + builder.build(&s_expr, bind_context.column_set()).await?; + + Ok(()) +} + fn max_or_depth(expr: &Expr) -> usize { match expr { Expr::Cast(cast) => max_or_depth(&cast.expr), diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 2041174a4a139..6ead44f43ba18 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -34,6 +34,7 @@ use itertools::Itertools; use super::ExprContext; use super::Finder; use super::prune_by_children; +use super::reject_grouping_functions; use crate::BindContext; use crate::MetadataRef; use crate::Symbol; @@ -660,6 +661,18 @@ impl AggregateInfo { } fn replace_grouping(&self, function: &FunctionCall) -> Result { + // `grouping<...>(_grouping_id)` is the internal rewritten form. Alias-expanded + // QUALIFY expressions can bind to it directly and must not be rewritten again. + if !function.params.is_empty() { + return Ok(function.clone()); + } + + if function.arguments.is_empty() { + return Err(ErrorCode::BadArguments( + "grouping requires at least one argument", + )); + } + if self.grouping_sets.is_none() { return Err(ErrorCode::SemanticError( "grouping can only be called in GROUP BY GROUPING SETS clauses", @@ -1120,16 +1133,14 @@ impl Binder { column: grouping_id_column.clone(), }; - if !self.ctx.get_settings().get_grouping_sets_to_union()? { - agg_info.group_items_map.insert( - bound_grouping_id_col.clone().into(), - agg_info.group_items.len(), - ); - agg_info.group_items.push(ScalarItem { - index: grouping_id_column.index, - scalar: bound_grouping_id_col.into(), - }); - } + agg_info.group_items_map.insert( + bound_grouping_id_col.clone().into(), + agg_info.group_items.len(), + ); + agg_info.group_items.push(ScalarItem { + index: grouping_id_column.index, + scalar: bound_grouping_id_col.into(), + }); let grouping_sets_info = GroupingSetsInfo { grouping_id_column, @@ -1264,6 +1275,15 @@ impl Binder { scalar.is_aggregate() || matches!(scalar, ScalarExpr::WindowFunction(_)) }; + reject_grouping_functions( + bind_context + .aggregate_info + .group_items + .iter() + .map(|item| &item.scalar), + "GROUP BY items", + )?; + for item in bind_context.aggregate_info.group_items.iter() { let mut finder = Finder::new(&f); finder.visit(&item.scalar)?; diff --git a/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs index 6028ab36cf897..2a8a172429e0e 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs @@ -32,6 +32,7 @@ use crate::MetadataRef; use crate::binder::Finder; use crate::binder::JoinPredicate; use crate::binder::Visibility; +use crate::binder::reject_grouping_functions; use crate::binder::wrap_nullable; use crate::normalize_identifier; use crate::optimizer::OptimizerContext; @@ -759,6 +760,8 @@ impl<'a> JoinConditionResolver<'a> { } fn check_join_allowed_scalar_expr(&mut self, scalars: &Vec) -> Result<()> { + reject_grouping_functions(scalars.iter(), "Join condition")?; + let f = |scalar: &ScalarExpr| { matches!( scalar, diff --git a/src/query/sql/src/planner/binder/qualify.rs b/src/query/sql/src/planner/binder/qualify.rs index 6553e668ccaf1..f30b75b24ad2b 100644 --- a/src/query/sql/src/planner/binder/qualify.rs +++ b/src/query/sql/src/planner/binder/qualify.rs @@ -22,6 +22,7 @@ use crate::BindContext; use crate::Binder; use crate::binder::ExprContext; use crate::binder::ScalarBinder; +use crate::binder::aggregate::AggregateRewriter; use crate::binder::split_conjunctions; use crate::binder::window::WindowRewriter; use crate::binder::window::find_replaced_window_function; @@ -52,6 +53,11 @@ impl Binder { aliases, ); let (mut scalar, _) = scalar_binder.bind(qualify)?; + AggregateRewriter::rewrite_existing_expr( + &bind_context.aggregate_info, + &mut scalar, + "Qualify clause must not contain aggregate functions", + )?; let mut rewriter = WindowRewriter::new(bind_context, self.metadata.clone()); rewriter.visit(&mut scalar)?; Ok(scalar) diff --git a/src/query/sql/src/planner/binder/scalar_common.rs b/src/query/sql/src/planner/binder/scalar_common.rs index 05a37a62f2327..36bb0e295cd28 100644 --- a/src/query/sql/src/planner/binder/scalar_common.rs +++ b/src/query/sql/src/planner/binder/scalar_common.rs @@ -15,6 +15,7 @@ use std::borrow::Cow; use std::collections::HashSet; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::Scalar; use databend_common_expression::types::DataType; @@ -28,6 +29,8 @@ use crate::plans::ScalarExpr; use crate::plans::Visitor; use crate::plans::walk_expr; +pub const GROUPING_FUNC_NAME: &str = "grouping"; + // Visitor that find Expressions that match a particular predicate pub struct Finder<'a, F> where F: Fn(&ScalarExpr) -> bool @@ -75,6 +78,46 @@ where F: Fn(&ScalarExpr) -> bool } } +pub fn is_grouping_function(scalar: &ScalarExpr) -> bool { + matches!( + scalar, + ScalarExpr::FunctionCall(func) if func.func_name.eq_ignore_ascii_case(GROUPING_FUNC_NAME) + ) +} + +pub fn is_raw_grouping_function(scalar: &ScalarExpr) -> bool { + matches!( + scalar, + ScalarExpr::FunctionCall(func) + if func.func_name.eq_ignore_ascii_case(GROUPING_FUNC_NAME) && func.params.is_empty() + ) +} + +pub fn grouping_clause_error(function: &FunctionCall, clause_name: &str) -> ErrorCode { + let err = if function.params.is_empty() && function.arguments.is_empty() { + ErrorCode::BadArguments("grouping requires at least one argument") + } else { + ErrorCode::SemanticError(format!("{clause_name} can't contain grouping functions")) + }; + + err.set_span(function.span) +} + +pub fn reject_grouping_functions<'a>( + scalars: impl IntoIterator, + clause_name: &str, +) -> Result<()> { + for scalar in scalars { + let mut grouping_finder = Finder::new(&is_grouping_function); + grouping_finder.visit(scalar)?; + if let Some(ScalarExpr::FunctionCall(func)) = grouping_finder.scalars().first() { + return Err(grouping_clause_error(func, clause_name)); + } + } + + Ok(()) +} + pub fn split_conjunctions(scalar: &ScalarExpr) -> Vec { match scalar { ScalarExpr::FunctionCall(func) if func.func_name == "and" => [ diff --git a/src/query/sql/src/planner/binder/select.rs b/src/query/sql/src/planner/binder/select.rs index 7f123f3abacc9..635e0e8284f5f 100644 --- a/src/query/sql/src/planner/binder/select.rs +++ b/src/query/sql/src/planner/binder/select.rs @@ -32,6 +32,7 @@ use databend_common_expression::types::DataType; use databend_common_functions::BUILTIN_FUNCTIONS; use super::Finder; +use super::reject_grouping_functions; use super::sort::OrderItem; use crate::ColumnEntry; use crate::ColumnSet; @@ -100,6 +101,8 @@ impl Binder { .set_span(scalar.span())); } + reject_grouping_functions(std::iter::once(&scalar), "Where clause")?; + let filter_plan = Filter { predicates: split_conjunctions(&scalar), }; diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_grouping_sets_to_union.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_grouping_sets_to_union.rs index f920a605d068e..25416898ee0e1 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_grouping_sets_to_union.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_grouping_sets_to_union.rs @@ -38,6 +38,7 @@ use crate::plans::EvalScalar; use crate::plans::MaterializedCTE; use crate::plans::MaterializedCTERef; use crate::plans::RelOp; +use crate::plans::ScalarItem; use crate::plans::Sequence; use crate::plans::UnionAll; use crate::plans::VisitorMut; @@ -85,6 +86,13 @@ impl RuleGroupingSetsToUnion { } } +fn is_grouping_id_item(item: &ScalarItem) -> bool { + matches!( + &item.scalar, + ScalarExpr::BoundColumnRef(col) if col.column.column_name == "_grouping_id" + ) +} + // Must go before `RuleSplitAggregate` impl Rule for RuleGroupingSetsToUnion { fn id(&self) -> RuleID { @@ -92,7 +100,7 @@ impl Rule for RuleGroupingSetsToUnion { } fn apply(&self, s_expr: &SExpr, state: &mut TransformResult) -> Result<()> { - let eval_scalar: EvalScalar = s_expr.plan().clone().try_into()?; + let mut eval_scalar: EvalScalar = s_expr.plan().clone().try_into()?; let agg: Aggregate = s_expr.child(0)?.plan().clone().try_into()?; if agg.mode != AggregateMode::Initial { return Ok(()); @@ -113,6 +121,20 @@ impl Rule for RuleGroupingSetsToUnion { if let Some(grouping_sets) = &agg.grouping_sets { if !grouping_sets.sets.is_empty() { + if !eval_scalar + .items + .iter() + .any(|item| item.index == grouping_sets.grouping_id_index) + { + eval_scalar.items.push(ScalarItem { + index: grouping_sets.grouping_id_index, + scalar: ScalarExpr::ConstantExpr(ConstantExpr { + value: Scalar::Number(NumberScalar::UInt32(0)), + span: None, + }), + }); + } + let mut children = Vec::with_capacity(grouping_sets.sets.len()); let mut hasher = DefaultHasher::new(); @@ -143,6 +165,7 @@ impl Rule for RuleGroupingSetsToUnion { let group_bys = agg .group_items .iter() + .filter(|item| !is_grouping_id_item(item)) .map(|i| { agg_input_columns .iter() @@ -174,6 +197,7 @@ impl Rule for RuleGroupingSetsToUnion { let null_group_ids: Vec = agg .group_items .iter() + .filter(|item| !is_grouping_id_item(item)) .map(|i| i.index) .filter(|index| !set.contains(index)) .clone() @@ -193,6 +217,20 @@ impl Rule for RuleGroupingSetsToUnion { visitor.visit(&mut scalar.scalar)?; } + if !eval_scalar + .items + .iter() + .any(|item| item.index == grouping_sets.grouping_id_index) + { + eval_scalar.items.push(ScalarItem { + index: grouping_sets.grouping_id_index, + scalar: ScalarExpr::ConstantExpr(ConstantExpr { + value: Scalar::Number(NumberScalar::UInt32(grouping_id)), + span: None, + }), + }); + } + let agg_plan = SExpr::create_unary(agg, cte_consumer.clone()); let eval_plan = SExpr::create_unary(eval_scalar, agg_plan); children.push(eval_plan); diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_hierarchical_grouping_sets.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_hierarchical_grouping_sets.rs index cc4f4e53c4dfa..b3299f90da1e4 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_hierarchical_grouping_sets.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_hierarchical_grouping_sets.rs @@ -402,7 +402,8 @@ impl RuleHierarchicalGroupingSetsToUnion { )?; // Step 4: Assemble the complete plan - let union_result = self.create_union_all(&union_branches, eval_scalar)?; + let union_result = + self.create_union_all(&union_branches, eval_scalar, grouping_id_index)?; // Step 5: Chain all CTEs in correct dependency order // Sequence semantics: left executes first, right executes after @@ -777,6 +778,12 @@ impl RuleHierarchicalGroupingSetsToUnion { let null_group_ids: Vec = agg .group_items .iter() + .filter(|item| { + !matches!( + &item.scalar, + ScalarExpr::BoundColumnRef(col) if col.column.column_name == "_grouping_id" + ) + }) .map(|i| i.index) .filter(|index| !group_columns.contains(index)) .collect(); @@ -793,21 +800,45 @@ impl RuleHierarchicalGroupingSetsToUnion { visitor.visit(&mut scalar_item.scalar)?; } + if !eval_scalar + .items + .iter() + .any(|item| item.index == grouping_id_index) + { + eval_scalar.items.push(ScalarItem { + index: grouping_id_index, + scalar: ScalarExpr::ConstantExpr(ConstantExpr { + value: Scalar::Number(NumberScalar::UInt32(grouping_id)), + span: None, + }), + }); + } + Ok(()) } /// Create UNION ALL combining all final branches - fn create_union_all(&self, branches: &[SExpr], eval_scalar: &EvalScalar) -> Result { + fn create_union_all( + &self, + branches: &[SExpr], + eval_scalar: &EvalScalar, + grouping_id_index: Symbol, + ) -> Result { if branches.is_empty() { return Err(databend_common_exception::ErrorCode::Internal( "No branches for union".to_string(), )); } + let mut output_indexes: Vec = eval_scalar.items.iter().map(|x| x.index).collect(); + if !output_indexes.contains(&grouping_id_index) { + output_indexes.push(grouping_id_index); + } + let mut result = branches[0].clone(); for branch in branches.iter().skip(1) { let left_outputs: Vec<(Symbol, Option)> = - eval_scalar.items.iter().map(|x| (x.index, None)).collect(); + output_indexes.iter().map(|x| (*x, None)).collect(); let right_outputs = left_outputs.clone(); let union_plan = UnionAll { @@ -815,7 +846,7 @@ impl RuleHierarchicalGroupingSetsToUnion { right_outputs, cte_scan_names: vec![], logical_recursive_cte_id: None, - output_indexes: eval_scalar.items.iter().map(|x| x.index).collect(), + output_indexes: output_indexes.clone(), }; result = SExpr::create_binary(Arc::new(union_plan.into()), result, branch.clone()); } @@ -824,6 +855,15 @@ impl RuleHierarchicalGroupingSetsToUnion { } fn calculate_grouping_id(&self, group_columns: &[Symbol], all_groups: &[ScalarItem]) -> u32 { + let all_groups: Vec<&ScalarItem> = all_groups + .iter() + .filter(|item| { + !matches!( + &item.scalar, + ScalarExpr::BoundColumnRef(col) if col.column.column_name == "_grouping_id" + ) + }) + .collect(); let mask = (1 << all_groups.len()) - 1; let mut id = 0; diff --git a/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test b/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test index 4d25e44286081..3de26b70aae59 100644 --- a/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test +++ b/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test @@ -66,6 +66,48 @@ create or replace table t (a string, b string, c int); statement ok insert into t values ('a','A',1),('a','A',2),('a','B',1),('a','B',3),('b','A',1),('b','A',4),('b','B',1),('b','B',5); +statement error +SELECT GROUPING() + +statement error +SELECT GROUPING() FROM t + +statement error +SELECT GROUPING(NULL) FROM t + +statement error +SELECT GROUPING(a) FROM t GROUP BY () + +statement error +SELECT count() FROM t WHERE GROUPING() = 0 GROUP BY a + +statement error +SELECT count() OVER () FROM t GROUP BY a QUALIFY GROUPING() = 0 + +statement error +SELECT count() FROM t t1 JOIN t t2 ON GROUPING() = 0 GROUP BY t1.a + +statement error +SELECT 1 FROM t GROUP BY GROUPING SETS ((GROUPING())) + +query TII rowsort +SELECT a, GROUPING(a) AS g, count() OVER () AS w +FROM t +GROUP BY GROUPING SETS ((a), ()) +QUALIFY GROUPING(a) = 0 +---- +a 0 3 +b 0 3 + +query TII rowsort +SELECT a, GROUPING(a) AS g, count() OVER () AS w +FROM t +GROUP BY GROUPING SETS ((a), ()) +QUALIFY g = 0 +---- +a 0 3 +b 0 3 + query TTI select a, b, sum(c) as sc from t group by grouping sets ((a,b),(),(b),(a)) order by sc; ----