Skip to content

Commit ebf6089

Browse files
committed
feat(calcite): fix visitor issues for V2 SQL AST integration
Fix four CalciteRelNodeVisitor/CalciteRexNodeVisitor gaps surfaced when routing the V2 SQL AST through the unified query path:

- Add an Alias case in project expansion: SQL `SELECT COUNT(*) AS cnt ... GROUP BY ...` previously crashed with "Unexpected expression type: Alias". Reference already-computed aggregates by schema name instead of re-analyzing them.
- Add a visitLimit override for SQL LIMIT/OFFSET. PPL is unaffected (it uses the Head node).
- Fix the bucketNullable NPE when SQL GROUP BY produces an Aggregation without a BUCKET_NULLABLE arg. Use `getOrDefault(..., Literal.TRUE)` to match Analyzer.java.
- Rewrite visitWindowFunction to handle AggregateFunction (V2 SQL) + DISTINCT + OVER ORDER BY (the order keys were previously hardcoded to List.of()). Register ROW_NUMBER/RANK/DENSE_RANK in WINDOW_FUNC_MAPPING and bypass aggregate validation for these pure window functions. Add a makeOver overload with a distinct parameter (the old signature delegates with false to preserve PPL behavior).

Add seven tests in UnifiedQueryPlannerSqlTest covering these fixes. Verified zero PPL regressions across 2300+ existing tests.

Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 6421658 commit ebf6089

5 files changed

Lines changed: 201 additions & 11 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlTest.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,4 +259,82 @@ SELECT department, count(*)
259259
""")
260260
.assertErrorMessage("Encountered");
261261
}
262+
263+
@Test
264+
public void testSqlLimitOffset() {
265+
givenQuery(
266+
"""
267+
SELECT name
268+
FROM catalog.employees
269+
LIMIT 10 OFFSET 5\
270+
""")
271+
.assertPlanContains("LogicalSort(offset=[5], fetch=[10])");
272+
}
273+
274+
@Test
275+
public void testSqlAggregateWithAlias() {
276+
givenQuery(
277+
"""
278+
SELECT department, COUNT(*) AS cnt
279+
FROM catalog.employees
280+
GROUP BY department\
281+
""")
282+
.assertPlanContains("LogicalAggregate(group=[{0}]")
283+
.assertPlanContains("COUNT()");
284+
}
285+
286+
@Test
287+
public void testSqlWindowFunctionWithOrderBy() {
288+
givenQuery(
289+
"""
290+
SELECT name, SUM(age) OVER (PARTITION BY department ORDER BY id) AS running_sum
291+
FROM catalog.employees\
292+
""")
293+
.assertPlanContains("OVER")
294+
.assertPlanContains("ORDER BY");
295+
}
296+
297+
@Test
298+
public void testSqlWindowRankFunction() {
299+
givenQuery(
300+
"""
301+
SELECT name, RANK() OVER (ORDER BY age DESC) AS rnk
302+
FROM catalog.employees\
303+
""")
304+
.assertPlanContains("RANK() OVER");
305+
}
306+
307+
@Test
308+
public void testSqlWindowDistinctAggregate() {
309+
givenQuery(
310+
"""
311+
SELECT name, COUNT(DISTINCT department) OVER (PARTITION BY department) AS dist_cnt
312+
FROM catalog.employees\
313+
""")
314+
.assertPlanContains("OVER")
315+
.assertPlanContains("COUNT");
316+
}
317+
318+
@Test
319+
public void testSqlGroupByWithoutBucketNullable() {
320+
givenQuery(
321+
"""
322+
SELECT age, COUNT(*) AS cnt
323+
FROM catalog.employees
324+
GROUP BY age\
325+
""")
326+
.assertPlanContains("LogicalAggregate(group=[{0}]")
327+
.assertPlanContains("COUNT()");
328+
}
329+
330+
@Test
331+
public void testSqlSelectWithAlias() {
332+
givenQuery(
333+
"""
334+
SELECT age AS employee_age, name AS employee_name
335+
FROM catalog.employees\
336+
""")
337+
.assertPlanContains("LogicalProject")
338+
.assertPlanContains("LogicalTableScan");
339+
}
262340
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
import org.opensearch.sql.ast.tree.Head;
133133
import org.opensearch.sql.ast.tree.Join;
134134
import org.opensearch.sql.ast.tree.Kmeans;
135+
import org.opensearch.sql.ast.tree.Limit;
135136
import org.opensearch.sql.ast.tree.Lookup;
136137
import org.opensearch.sql.ast.tree.Lookup.OutputStrategy;
137138
import org.opensearch.sql.ast.tree.ML;
@@ -541,6 +542,25 @@ private List<RexNode> expandProjectFields(
541542
.filter(addedFields::add)
542543
.forEach(field -> expandedFields.add(context.relBuilder.field(field)));
543544
}
545+
case Alias alias -> {
546+
String aliasName =
547+
Strings.isNullOrEmpty(alias.getAlias()) ? alias.getName() : alias.getAlias();
548+
// When an aggregate was already computed (its output name matches a current field),
549+
// reference the existing field instead of re-analyzing (which would return null).
550+
if (alias.getDelegated() instanceof AggregateFunction) {
551+
String aggFieldName = alias.getName();
552+
if (currentFields.contains(aliasName)) {
553+
expandedFields.add(context.relBuilder.field(aliasName));
554+
} else if (aggFieldName != null && currentFields.contains(aggFieldName)) {
555+
expandedFields.add(
556+
context.relBuilder.alias(context.relBuilder.field(aggFieldName), aliasName));
557+
} else {
558+
expandedFields.add(rexVisitor.analyze(alias, context));
559+
}
560+
} else {
561+
expandedFields.add(rexVisitor.analyze(alias, context));
562+
}
563+
}
544564
default ->
545565
throw new IllegalStateException(
546566
"Unexpected expression type in project list: " + expr.getClass().getSimpleName());
@@ -763,6 +783,13 @@ public RelNode visitHead(Head node, CalcitePlanContext context) {
763783
return context.relBuilder.peek();
764784
}
765785

786+
@Override
787+
public RelNode visitLimit(Limit node, CalcitePlanContext context) {
788+
visitChildren(node, context);
789+
context.relBuilder.limit(node.getOffset(), node.getLimit());
790+
return context.relBuilder.peek();
791+
}
792+
766793
/**
767794
* Insert a reversed sort node after finding the original sort in the tree. This rebuilds the tree
768795
* with the reversed sort inserted right after the original sort.
@@ -1621,7 +1648,8 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
16211648
@Override
16221649
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
16231650
Argument.ArgumentMap statsArgs = Argument.ArgumentMap.of(node.getArgExprList());
1624-
Boolean bucketNullable = (Boolean) statsArgs.get(Argument.BUCKET_NULLABLE).getValue();
1651+
boolean bucketNullable =
1652+
(Boolean) statsArgs.getOrDefault(Argument.BUCKET_NULLABLE, Literal.TRUE).getValue();
16251653
int nGroup = node.getGroupExprList().size() + (Objects.nonNull(node.getSpan()) ? 1 : 0);
16261654
BitSet nonNullGroupMask = new BitSet(nGroup);
16271655
if (!bucketNullable) {

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Locale;
2020
import java.util.Map;
21+
import java.util.Set;
2122
import java.util.stream.Collectors;
2223
import java.util.stream.IntStream;
2324
import javax.annotation.Nullable;
@@ -42,6 +43,7 @@
4243
import org.apache.calcite.util.TimestampString;
4344
import org.apache.logging.log4j.util.Strings;
4445
import org.opensearch.sql.ast.AbstractNodeVisitor;
46+
import org.opensearch.sql.ast.expression.AggregateFunction;
4547
import org.opensearch.sql.ast.expression.Alias;
4648
import org.opensearch.sql.ast.expression.And;
4749
import org.opensearch.sql.ast.expression.Between;
@@ -90,6 +92,10 @@
9092
public class CalciteRexNodeVisitor extends AbstractNodeVisitor<RexNode, CalcitePlanContext> {
9193
private final CalciteRelNodeVisitor planVisitor;
9294

95+
private static final Set<BuiltinFunctionName> PURE_WINDOW_FUNCTIONS =
96+
Set.of(
97+
BuiltinFunctionName.ROW_NUMBER, BuiltinFunctionName.RANK, BuiltinFunctionName.DENSE_RANK);
98+
9399
public RexNode analyze(UnresolvedExpression unresolved, CalcitePlanContext context) {
94100
return unresolved.accept(this, context);
95101
}
@@ -563,47 +569,93 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
563569

564570
@Override
565571
public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext context) {
566-
Function windowFunction = (Function) node.getFunction();
567-
List<RexNode> arguments =
568-
windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
572+
String funcName;
573+
List<RexNode> arguments;
574+
final boolean isDistinct;
575+
if (node.getFunction() instanceof AggregateFunction aggFunc) {
576+
funcName = aggFunc.getFuncName();
577+
isDistinct = Boolean.TRUE.equals(aggFunc.getDistinct());
578+
List<UnresolvedExpression> argExprs = new java.util.ArrayList<>();
579+
if (aggFunc.getField() != null) {
580+
argExprs.add(aggFunc.getField());
581+
}
582+
argExprs.addAll(aggFunc.getArgList());
583+
arguments = argExprs.stream().map(arg -> analyze(arg, context)).toList();
584+
} else {
585+
Function windowFunction = (Function) node.getFunction();
586+
funcName = windowFunction.getFuncName();
587+
isDistinct = false;
588+
arguments = windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
589+
}
569590
List<RexNode> partitions =
570591
node.getPartitionByList().stream()
571592
.map(arg -> analyze(arg, context))
572593
.map(this::extractRexNodeFromAlias)
573594
.toList();
574-
return BuiltinFunctionName.ofWindowFunction(windowFunction.getFuncName())
595+
List<RexNode> orderKeys =
596+
node.getSortList().stream()
597+
.map(
598+
pair -> {
599+
RexNode sortField = analyze(pair.getRight(), context);
600+
if (pair.getLeft().getSortOrder()
601+
== org.opensearch.sql.ast.tree.Sort.SortOrder.DESC) {
602+
sortField = context.relBuilder.desc(sortField);
603+
}
604+
if (pair.getLeft().getNullOrder()
605+
== org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST) {
606+
sortField = context.relBuilder.nullsLast(sortField);
607+
} else if (pair.getLeft().getNullOrder()
608+
== org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST) {
609+
sortField = context.relBuilder.nullsFirst(sortField);
610+
}
611+
return sortField;
612+
})
613+
.toList();
614+
return BuiltinFunctionName.ofWindowFunction(funcName)
575615
.map(
576616
functionName -> {
577617
RexNode field = arguments.isEmpty() ? null : arguments.getFirst();
578618
List<RexNode> args =
579619
(arguments.isEmpty() || arguments.size() == 1)
580620
? Collections.emptyList()
581621
: arguments.subList(1, arguments.size());
622+
// Pure window functions (ROW_NUMBER, RANK, DENSE_RANK) are not registered
623+
// in aggFunctionRegistry, so skip validation for them.
624+
if (PURE_WINDOW_FUNCTIONS.contains(functionName)) {
625+
return PlanUtils.makeOver(
626+
context,
627+
functionName,
628+
field,
629+
args,
630+
partitions,
631+
orderKeys,
632+
node.getWindowFrame());
633+
}
582634
List<RexNode> nodes =
583635
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
584636
functionName, field, args, context.rexBuilder);
585637
return nodes != null
586638
? PlanUtils.makeOver(
587639
context,
588640
functionName,
641+
isDistinct,
589642
nodes.getFirst(),
590643
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
591644
partitions,
592-
List.of(),
645+
orderKeys,
593646
node.getWindowFrame())
594647
: PlanUtils.makeOver(
595648
context,
596649
functionName,
650+
isDistinct,
597651
field,
598652
args,
599653
partitions,
600-
List.of(),
654+
orderKeys,
601655
node.getWindowFrame());
602656
})
603657
.orElseThrow(
604-
() ->
605-
new UnsupportedOperationException(
606-
"Unexpected window function: " + windowFunction.getFuncName()));
658+
() -> new UnsupportedOperationException("Unexpected window function: " + funcName));
607659
}
608660

609661
/** extract the expression of Alias from a node */

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ static RexNode makeOver(
171171
List<RexNode> partitions,
172172
List<RexNode> orderKeys,
173173
@Nullable WindowFrame windowFrame) {
174+
return makeOver(
175+
context, functionName, false, field, argList, partitions, orderKeys, windowFrame);
176+
}
177+
178+
static RexNode makeOver(
179+
CalcitePlanContext context,
180+
BuiltinFunctionName functionName,
181+
boolean distinct,
182+
RexNode field,
183+
List<RexNode> argList,
184+
List<RexNode> partitions,
185+
List<RexNode> orderKeys,
186+
@Nullable WindowFrame windowFrame) {
174187
if (windowFrame == null) {
175188
windowFrame = WindowFrame.rowsUnbounded();
176189
}
@@ -216,6 +229,22 @@ static RexNode makeOver(
216229
true,
217230
lowerBound,
218231
upperBound);
232+
case RANK:
233+
return withOver(
234+
context.relBuilder.aggregateCall(SqlStdOperatorTable.RANK),
235+
partitions,
236+
orderKeys,
237+
true,
238+
lowerBound,
239+
upperBound);
240+
case DENSE_RANK:
241+
return withOver(
242+
context.relBuilder.aggregateCall(SqlStdOperatorTable.DENSE_RANK),
243+
partitions,
244+
orderKeys,
245+
true,
246+
lowerBound,
247+
upperBound);
219248
case NTH_VALUE:
220249
return withOver(
221250
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
@@ -226,7 +255,7 @@ static RexNode makeOver(
226255
upperBound);
227256
default:
228257
return withOver(
229-
makeAggCall(context, functionName, false, field, argList),
258+
makeAggCall(context, functionName, distinct, field, argList),
230259
partitions,
231260
orderKeys,
232261
rows,

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ public enum BuiltinFunctionName {
425425
.put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
426426
.put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
427427
.put("pattern", BuiltinFunctionName.INTERNAL_PATTERN)
428+
.put("row_number", BuiltinFunctionName.ROW_NUMBER)
429+
.put("rank", BuiltinFunctionName.RANK)
430+
.put("dense_rank", BuiltinFunctionName.DENSE_RANK)
428431
.build();
429432

430433
public static Optional<BuiltinFunctionName> of(String str) {

0 commit comments

Comments
 (0)