Skip to content

Commit 097efae

Browse files
authored
fix(substrait): dedupe names of aggregate measures, not just groupings (#22453)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change When the substrait consumer hits an `Aggregate` with two identical measures (e.g. `sum(a)` present twice), planning fails with `Schema contains duplicate unqualified field name`. Substrait carries column names at the plan root rather than on the measures themselves, so the measures arrive at `Aggregate` schema construction without aliases -- and two identical exprs produce two identical field names. PR #20539 fixed the `NameTracker` to dedupe duplicate names in the consumer, but it was only applied to grouping expressions, not to the measures. The planner sees: ``` field 1: (qualifier: None, name: "sum(data.a)") field 2: (qualifier: None, name: "sum(data.a)") ``` which is rejected when constructing the Aggregate's output schema. ## What changes are included in this PR? Run aggregate measures through the same `NameTracker` like the grouping expressions in `from_aggregate_rel` ## Are these changes tested? Yes -- added a roundtrip test `aggregate_identical_measures`. Without the fix it produces `Error: SchemaError(DuplicateUnqualifiedField { name: "sum(data.a)" }, Some(""))` ## Are there any user-facing changes? No.
1 parent 9986525 commit 097efae

3 files changed

Lines changed: 133 additions & 3 deletions

File tree

datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,17 @@ pub async fn from_aggregate_rel(
109109
aggr_exprs.push(std::sync::Arc::unwrap_or_clone(agg_func?));
110110
}
111111

112-
// Ensure that all expressions have a unique name
112+
// Ensure that all expressions have a unique name. Both grouping and
113+
// aggregate expressions become fields in the aggregate's output schema,
114+
// so they share a single namespace.
113115
let mut name_tracker = NameTracker::new();
114116
let group_exprs = group_exprs
115-
.iter()
116-
.map(|e| name_tracker.get_uniquely_named_expr(e.clone()))
117+
.into_iter()
118+
.map(|e| name_tracker.get_uniquely_named_expr(e))
119+
.collect::<Result<Vec<Expr>, _>>()?;
120+
let aggr_exprs = aggr_exprs
121+
.into_iter()
122+
.map(|e| name_tracker.get_uniquely_named_expr(e))
117123
.collect::<Result<Vec<Expr>, _>>()?;
118124

119125
input.aggregate(group_exprs, aggr_exprs)?.build()

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,27 @@ async fn aggregate_identical_grouping_expressions() -> Result<()> {
11161116
Ok(())
11171117
}
11181118

1119+
#[tokio::test]
1120+
async fn aggregate_identical_measures() -> Result<()> {
1121+
// Two identical aggregate measures share the same schema_name; without
1122+
// NameTracker dedup over measures, building the Aggregate's output
1123+
// DFSchema fails with "Schema contains duplicate unqualified field name".
1124+
let proto_plan = read_json(
1125+
"tests/testdata/test_plans/aggregate_identical_measures.substrait.json",
1126+
);
1127+
1128+
let plan = generate_plan_from_substrait(proto_plan).await?;
1129+
assert_snapshot!(
1130+
plan,
1131+
@r"
1132+
Projection: __common_expr_1 AS sum_a_1, __common_expr_1 AS sum(data.a)__temp__0 AS sum_a_2
1133+
Aggregate: groupBy=[[]], aggr=[[sum(data.a) AS __common_expr_1]]
1134+
TableScan: data projection=[a]
1135+
"
1136+
);
1137+
Ok(())
1138+
}
1139+
11191140
#[tokio::test]
11201141
async fn simple_intersect_consume() -> Result<()> {
11211142
let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json");
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
{
2+
"extensionUris": [{
3+
"extensionUriAnchor": 1,
4+
"uri": "/functions_arithmetic.yaml"
5+
}],
6+
"extensions": [{
7+
"extensionFunction": {
8+
"extensionUriReference": 1,
9+
"functionAnchor": 0,
10+
"name": "sum:i64"
11+
}
12+
}],
13+
"relations": [{
14+
"root": {
15+
"input": {
16+
"aggregate": {
17+
"common": {
18+
"direct": {}
19+
},
20+
"input": {
21+
"read": {
22+
"common": {
23+
"direct": {}
24+
},
25+
"baseSchema": {
26+
"names": ["a"],
27+
"struct": {
28+
"types": [{
29+
"i64": {
30+
"nullability": "NULLABILITY_NULLABLE"
31+
}
32+
}],
33+
"nullability": "NULLABILITY_REQUIRED"
34+
}
35+
},
36+
"namedTable": {
37+
"names": ["data"]
38+
}
39+
}
40+
},
41+
"groupings": [{
42+
"groupingExpressions": []
43+
}],
44+
"measures": [
45+
{
46+
"measure": {
47+
"functionReference": 0,
48+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
49+
"outputType": {
50+
"i64": {
51+
"nullability": "NULLABILITY_NULLABLE"
52+
}
53+
},
54+
"invocation": "AGGREGATION_INVOCATION_ALL",
55+
"arguments": [{
56+
"value": {
57+
"selection": {
58+
"directReference": {
59+
"structField": {
60+
"field": 0
61+
}
62+
},
63+
"rootReference": {}
64+
}
65+
}
66+
}]
67+
}
68+
},
69+
{
70+
"measure": {
71+
"functionReference": 0,
72+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
73+
"outputType": {
74+
"i64": {
75+
"nullability": "NULLABILITY_NULLABLE"
76+
}
77+
},
78+
"invocation": "AGGREGATION_INVOCATION_ALL",
79+
"arguments": [{
80+
"value": {
81+
"selection": {
82+
"directReference": {
83+
"structField": {
84+
"field": 0
85+
}
86+
},
87+
"rootReference": {}
88+
}
89+
}
90+
}]
91+
}
92+
}
93+
]
94+
}
95+
},
96+
"names": ["sum_a_1", "sum_a_2"]
97+
}
98+
}],
99+
"version": {
100+
"minorNumber": 54,
101+
"producer": "manual"
102+
}
103+
}

0 commit comments

Comments
 (0)