Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/query/expression/src/constant_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,13 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
);
}

// `grouping` is a placeholder before the aggregate rewriter rewrites it to
// `grouping<...>(_grouping_id)`. Folding it here can reach the dummy
// implementation and panic on invalid queries.
if function.signature.name == "grouping" {
return (func_expr, func_domain);
Comment thread
sundy-li marked this conversation as resolved.
}

if all_args_is_scalar {
let block = DataBlock::empty_with_rows(1);
let evaluator = Evaluator::new(&block, self.func_ctx, self.fn_registry);
Expand Down
117 changes: 117 additions & 0 deletions src/query/service/tests/it/sql/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
use databend_common_catalog::table_context::TableContext;
use databend_common_expression::ColumnIndex;
use databend_common_expression::Expr;
use databend_common_sql::Planner;
use databend_common_sql::parse_exprs;
use databend_common_sql::plans::Plan;
use databend_query::physical_plans::PhysicalPlanBuilder;
use databend_query::test_kits::TestFixture;

#[tokio::test(flavor = "multi_thread")]
Expand Down Expand Up @@ -85,6 +88,120 @@ async fn test_inlist_with_null_builds_shallow_or_tree() -> anyhow::Result<()> {
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_invalid_grouping_returns_semantic_error() -> anyhow::Result<()> {
let fixture = TestFixture::setup().await?;
let ctx = fixture.new_query_ctx().await?;
let mut planner = Planner::new(ctx.clone());
fixture
.execute_command("CREATE TABLE students(course STRING, type STRING)")
.await?;

for (sql, expected) in [
(
"SELECT GROUPING()",
"grouping requires at least one argument",
),
(
"SELECT GROUPING() FROM students",
"grouping requires at least one argument",
),
(
"SELECT count() FROM students WHERE GROUPING() = 0 GROUP BY course",
"grouping requires at least one argument",
),
(
"SELECT count() OVER () FROM students GROUP BY course QUALIFY GROUPING() = 0",
"grouping requires at least one argument",
),
(
"SELECT count() \
FROM students s1 \
JOIN students s2 ON GROUPING() = 0 \
GROUP BY s1.course",
"grouping requires at least one argument",
),
(
"SELECT 1 FROM students GROUP BY GROUPING SETS ((GROUPING()))",
"grouping requires at least one argument",
),
] {
let err = planner
.plan_sql(sql)
.await
.expect_err("invalid grouping() should return a semantic error");
assert!(
err.message().contains(expected),
"unexpected error for `{sql}`: {err}",
);
}

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_grouping_qualify_rewrites_before_semantic_checks() -> anyhow::Result<()> {
let fixture = TestFixture::setup().await?;
let ctx = fixture.new_query_ctx().await?;
let mut planner = Planner::new(ctx.clone());
fixture
.execute_command("CREATE TABLE students(course STRING, type STRING)")
.await?;

for sql in [
"SELECT count() OVER () \
FROM students \
GROUP BY GROUPING SETS ((course), ()) \
QUALIFY GROUPING(course) = 0",
"SELECT GROUPING(course) AS g, count() OVER () \
FROM students \
GROUP BY GROUPING SETS ((course), ()) \
QUALIFY g = 0",
] {
planner
.plan_sql(sql)
.await
.unwrap_or_else(|err| panic!("expected valid grouping QUALIFY for `{sql}`: {err}"));
}

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_grouping_sets_to_union_keeps_grouping_id_for_qualify_windows() -> anyhow::Result<()> {
let fixture = TestFixture::setup().await?;
let ctx = fixture.new_query_ctx().await?;
ctx.get_settings()
.set_setting("grouping_sets_to_union".to_string(), "1".to_string())?;

fixture
.execute_command("CREATE TABLE students(course STRING, type STRING)")
.await?;

let sql = "SELECT course, GROUPING(course) AS g, count() OVER () AS w \
FROM students \
GROUP BY GROUPING SETS ((course), ()) \
QUALIFY GROUPING(course) = 0";

let mut planner = Planner::new(ctx.clone());
let (plan, _) = planner.plan_sql(sql).await?;

let Plan::Query {
s_expr,
metadata,
bind_context,
..
} = plan
else {
panic!("expected query plan");
};

let mut builder = PhysicalPlanBuilder::new(metadata, ctx, false);
builder.build(&s_expr, bind_context.column_set()).await?;

Ok(())
}

fn max_or_depth<I: ColumnIndex>(expr: &Expr<I>) -> usize {
match expr {
Expr::Cast(cast) => max_or_depth(&cast.expr),
Expand Down
40 changes: 30 additions & 10 deletions src/query/sql/src/planner/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use itertools::Itertools;
use super::ExprContext;
use super::Finder;
use super::prune_by_children;
use super::reject_grouping_functions;
use crate::BindContext;
use crate::MetadataRef;
use crate::Symbol;
Expand Down Expand Up @@ -660,6 +661,18 @@ impl AggregateInfo {
}

fn replace_grouping(&self, function: &FunctionCall) -> Result<FunctionCall> {
// `grouping<...>(_grouping_id)` is the internal rewritten form. Alias-expanded
// QUALIFY expressions can bind to it directly and must not be rewritten again.
if !function.params.is_empty() {
return Ok(function.clone());
}

if function.arguments.is_empty() {
return Err(ErrorCode::BadArguments(
"grouping requires at least one argument",
));
}

if self.grouping_sets.is_none() {
return Err(ErrorCode::SemanticError(
"grouping can only be called in GROUP BY GROUPING SETS clauses",
Expand Down Expand Up @@ -1120,16 +1133,14 @@ impl Binder {
column: grouping_id_column.clone(),
};

if !self.ctx.get_settings().get_grouping_sets_to_union()? {
agg_info.group_items_map.insert(
bound_grouping_id_col.clone().into(),
agg_info.group_items.len(),
);
agg_info.group_items.push(ScalarItem {
index: grouping_id_column.index,
scalar: bound_grouping_id_col.into(),
});
}
agg_info.group_items_map.insert(
bound_grouping_id_col.clone().into(),
agg_info.group_items.len(),
);
agg_info.group_items.push(ScalarItem {
index: grouping_id_column.index,
scalar: bound_grouping_id_col.into(),
});

let grouping_sets_info = GroupingSetsInfo {
grouping_id_column,
Expand Down Expand Up @@ -1264,6 +1275,15 @@ impl Binder {
scalar.is_aggregate() || matches!(scalar, ScalarExpr::WindowFunction(_))
};

reject_grouping_functions(
bind_context
.aggregate_info
.group_items
.iter()
.map(|item| &item.scalar),
"GROUP BY items",
)?;

for item in bind_context.aggregate_info.group_items.iter() {
let mut finder = Finder::new(&f);
finder.visit(&item.scalar)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::MetadataRef;
use crate::binder::Finder;
use crate::binder::JoinPredicate;
use crate::binder::Visibility;
use crate::binder::reject_grouping_functions;
use crate::binder::wrap_nullable;
use crate::normalize_identifier;
use crate::optimizer::OptimizerContext;
Expand Down Expand Up @@ -759,6 +760,8 @@ impl<'a> JoinConditionResolver<'a> {
}

fn check_join_allowed_scalar_expr(&mut self, scalars: &Vec<ScalarExpr>) -> Result<()> {
reject_grouping_functions(scalars.iter(), "Join condition")?;

let f = |scalar: &ScalarExpr| {
matches!(
scalar,
Expand Down
6 changes: 6 additions & 0 deletions src/query/sql/src/planner/binder/qualify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::BindContext;
use crate::Binder;
use crate::binder::ExprContext;
use crate::binder::ScalarBinder;
use crate::binder::aggregate::AggregateRewriter;
use crate::binder::split_conjunctions;
use crate::binder::window::WindowRewriter;
use crate::binder::window::find_replaced_window_function;
Expand Down Expand Up @@ -52,6 +53,11 @@ impl Binder {
aliases,
);
let (mut scalar, _) = scalar_binder.bind(qualify)?;
AggregateRewriter::rewrite_existing_expr(
&bind_context.aggregate_info,
&mut scalar,
"Qualify clause must not contain aggregate functions",
)?;
let mut rewriter = WindowRewriter::new(bind_context, self.metadata.clone());
rewriter.visit(&mut scalar)?;
Ok(scalar)
Expand Down
43 changes: 43 additions & 0 deletions src/query/sql/src/planner/binder/scalar_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::borrow::Cow;
use std::collections::HashSet;

use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::Scalar;
use databend_common_expression::types::DataType;
Expand All @@ -28,6 +29,8 @@ use crate::plans::ScalarExpr;
use crate::plans::Visitor;
use crate::plans::walk_expr;

pub const GROUPING_FUNC_NAME: &str = "grouping";

// Visitor that find Expressions that match a particular predicate
pub struct Finder<'a, F>
where F: Fn(&ScalarExpr) -> bool
Expand Down Expand Up @@ -75,6 +78,46 @@ where F: Fn(&ScalarExpr) -> bool
}
}

pub fn is_grouping_function(scalar: &ScalarExpr) -> bool {
matches!(
scalar,
ScalarExpr::FunctionCall(func) if func.func_name.eq_ignore_ascii_case(GROUPING_FUNC_NAME)
)
}

pub fn is_raw_grouping_function(scalar: &ScalarExpr) -> bool {
matches!(
scalar,
ScalarExpr::FunctionCall(func)
if func.func_name.eq_ignore_ascii_case(GROUPING_FUNC_NAME) && func.params.is_empty()
)
}

pub fn grouping_clause_error(function: &FunctionCall, clause_name: &str) -> ErrorCode {
let err = if function.params.is_empty() && function.arguments.is_empty() {
ErrorCode::BadArguments("grouping requires at least one argument")
} else {
ErrorCode::SemanticError(format!("{clause_name} can't contain grouping functions"))
};

err.set_span(function.span)
}

pub fn reject_grouping_functions<'a>(
scalars: impl IntoIterator<Item = &'a ScalarExpr>,
clause_name: &str,
) -> Result<()> {
for scalar in scalars {
let mut grouping_finder = Finder::new(&is_grouping_function);
grouping_finder.visit(scalar)?;
if let Some(ScalarExpr::FunctionCall(func)) = grouping_finder.scalars().first() {
return Err(grouping_clause_error(func, clause_name));
}
}

Ok(())
}

pub fn split_conjunctions(scalar: &ScalarExpr) -> Vec<ScalarExpr> {
match scalar {
ScalarExpr::FunctionCall(func) if func.func_name == "and" => [
Expand Down
3 changes: 3 additions & 0 deletions src/query/sql/src/planner/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use databend_common_expression::types::DataType;
use databend_common_functions::BUILTIN_FUNCTIONS;

use super::Finder;
use super::reject_grouping_functions;
use super::sort::OrderItem;
use crate::ColumnEntry;
use crate::ColumnSet;
Expand Down Expand Up @@ -100,6 +101,8 @@ impl Binder {
.set_span(scalar.span()));
}

reject_grouping_functions(std::iter::once(&scalar), "Where clause")?;

let filter_plan = Filter {
predicates: split_conjunctions(&scalar),
};
Expand Down
Loading
Loading