Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6851,3 +6851,50 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {

Ok(())
}

/// Regression test for https://github.com/apache/datafusion/issues/21411
/// grouping() should work when wrapped in an alias via the DataFrame API.
///
/// This bug only manifests through the DataFrame API because `.alias()` wraps
/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression
/// level. The SQL planner handles aliasing separately (via projection), so the
/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper
/// around the aggregate function in SQL queries — making SQL-based tests
/// insufficient to cover this case.
#[tokio::test]
async fn test_grouping_with_alias() -> Result<()> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I confirmed this test fails on main.

use datafusion_functions_aggregate::expr_fn::grouping;

let df = create_test_table("test")
.await?
.aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])?
.sort(vec![Sort::new(col("a"), true, false)])?;

let results = df.collect().await?;

let expected = [
"+-----------+---+",
"| a | g |",
"+-----------+---+",
"| 123AbcDef | 0 |",
"| CBAdef | 0 |",
"| abc123 | 0 |",
"| abcDEF | 0 |",
"+-----------+---+",
];
assert_batches_eq!(expected, &results);

// Also verify that nested aliases (e.g. .alias("x").alias("g")) work correctly
let df = create_test_table("test")
.await?
.aggregate(
vec![col("a")],
vec![grouping(col("a")).alias("x").alias("g")],
)?
.sort(vec![Sort::new(col("a"), true, false)])?;

let results = df.collect().await?;
assert_batches_eq!(expected, &results);

Ok(())
}
51 changes: 43 additions & 8 deletions datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ fn replace_grouping_exprs(
.into_iter()
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
{
let grouping_id_type = is_grouping_set
.then(|| {
schema
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
.map(|f| f.data_type().clone())
})
.transpose()?;
match expr {
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
let grouping_id_type = is_grouping_set
.then(|| {
schema
.field_with_name(None, Aggregate::INTERNAL_GROUPING_ID)
.map(|f| f.data_type().clone())
})
.transpose()?;
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
Expand All @@ -117,6 +117,24 @@ fn replace_grouping_exprs(
column.name,
)));
}
Expr::Alias(Alias {
ref relation,
ref name,
..
}) if is_grouping_function(&expr) => {
let function = unwrap_alias_to_grouping_function(&expr)?;
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
grouping_id_type,
)?;
// Preserve the outermost user-provided alias
projection_exprs.push(Expr::Alias(Alias::new(
grouping_expr,
relation.clone(),
name.clone(),
)));
}
_ => {
projection_exprs.push(Expr::Column(column));
new_agg_expr.push(expr);
Expand Down Expand Up @@ -155,10 +173,27 @@ fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Ok(transformed_plan)
}

/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`.
/// Returns an error if the innermost expression is not an `AggregateFunction`,
/// which should not happen if `is_grouping_function` returned true.
fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> {
match expr {
Expr::AggregateFunction(function) => Ok(function),
Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr),
_ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"),
}
}

fn is_grouping_function(expr: &Expr) -> bool {
// TODO: Do something better than name here should grouping be a built
// in expression?
matches!(expr, Expr::AggregateFunction(AggregateFunction { func, .. }) if func.name() == "grouping")
match expr {
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
func.name() == "grouping"
}
Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
_ => false,
}
}

fn contains_grouping_function(exprs: &[Expr]) -> bool {
Expand Down
Loading