Skip to content

Commit 3211f0b

Browse files
committed
feat(calcite): support window functions, ROW_NUMBER, and register ISNULL
Add AggregateFunction handling in visitWindowFunction to support aggregate-based window expressions with DISTINCT and ORDER BY keys. Add translateOrderKeys utility for window ORDER BY translation. Register row_number in WINDOW_FUNC_MAPPING and skip aggregate signature validation for it (it has no field/args). Pass distinct flag through makeOver call chain. RANK and DENSE_RANK are deferred to a follow-up alongside the open PPL eventstats/streamstats issue (#5168) which involves the same function registration and a separate ORDER BY semantics question. Register ISNULL as alias for IS_NULL in PPLFuncImpTable. Add integration tests for window functions with ORDER BY, ROW_NUMBER, COUNT DISTINCT OVER, and ISNULL. Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 498b25f commit 3211f0b

5 files changed

Lines changed: 139 additions & 10 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlTest.java

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,64 @@ SELECT department, count(*)
260260
.assertErrorMessage("Encountered");
261261
}
262262

263+
// Checks that SUM(age) OVER (PARTITION BY department ORDER BY id) is planned as a single
// LogicalProject over the table scan, with both the partition key and the order key carried
// into the OVER clause.
// NOTE(review): the expected frame is UNBOUNDED PRECEDING .. UNBOUNDED FOLLOWING, so despite the
// "running_sum" alias the plan sums the whole partition rather than accumulating row by row.
// Standard SQL defaults to UNBOUNDED PRECEDING .. CURRENT ROW when ORDER BY is present;
// confirm the whole-partition frame is intended.
@Test
264+
public void testSqlWindowFunctionWithOrderBy() {
265+
givenQuery(
266+
"""
267+
SELECT name, SUM(age) OVER (PARTITION BY department ORDER BY id) AS running_sum
268+
FROM catalog.employees\
269+
""")
270+
.assertPlan(
271+
"""
272+
LogicalProject(name=[$1], running_sum=[SUM($2) OVER (PARTITION BY $3 ORDER BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)])
273+
LogicalTableScan(table=[[catalog, employees]])
274+
""");
275+
}
276+
277+
// Checks that ROW_NUMBER() OVER (ORDER BY id), which has no arguments and no PARTITION BY,
// is planned as ROW_NUMBER() OVER (ORDER BY $0) inside a LogicalProject over the table scan.
@Test
278+
public void testSqlWindowRowNumber() {
279+
givenQuery(
280+
"""
281+
SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn
282+
FROM catalog.employees\
283+
""")
284+
.assertPlan(
285+
"""
286+
LogicalProject(name=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $0)])
287+
LogicalTableScan(table=[[catalog, employees]])
288+
""");
289+
}
290+
291+
// Checks that the DISTINCT qualifier of an aggregate used as a window function survives
// planning: COUNT(DISTINCT department) OVER (PARTITION BY department) must appear as
// COUNT(DISTINCT $3) OVER (PARTITION BY $3) in the resulting LogicalProject.
@Test
292+
public void testSqlWindowDistinctAggregate() {
293+
givenQuery(
294+
"""
295+
SELECT name, COUNT(DISTINCT department) OVER (PARTITION BY department) AS dist_cnt
296+
FROM catalog.employees\
297+
""")
298+
.assertPlan(
299+
"""
300+
LogicalProject(name=[$1], dist_cnt=[COUNT(DISTINCT $3) OVER (PARTITION BY $3)])
301+
LogicalTableScan(table=[[catalog, employees]])
302+
""");
303+
}
304+
305+
// Checks that a query calling ISNULL(...) plans successfully and yields the expected plan.
// NOTE(review): the expected plan only contains the folded literal `false`, so this test
// shows that ISNULL resolves without error but not which operator it resolved to — confirm
// whether a nullable column should be used to make the IS NULL call visible in the plan.
@Test
306+
public void testSqlIsNullFunction() {
307+
// ISNULL(field) — exercises the ISNULL alias registration in PPLFuncImpTable.
308+
// Calcite constant-folds to false since test schema columns are NOT NULL.
309+
givenQuery(
310+
"""
311+
SELECT ISNULL(department) AS is_null
312+
FROM catalog.employees\
313+
""")
314+
.assertPlan(
315+
"""
316+
LogicalProject(is_null=[false])
317+
LogicalTableScan(table=[[catalog, employees]])
318+
""");
319+
}
320+
263321
@Test
264322
public void testSqlLimitOffset() {
265323
givenQuery(

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@
3737
import org.apache.calcite.sql.type.ArraySqlType;
3838
import org.apache.calcite.sql.type.SqlTypeName;
3939
import org.apache.calcite.sql.type.SqlTypeUtil;
40+
import org.apache.calcite.tools.RelBuilder;
4041
import org.apache.calcite.util.DateString;
4142
import org.apache.calcite.util.TimeString;
4243
import org.apache.calcite.util.TimestampString;
44+
import org.apache.commons.lang3.tuple.Pair;
4345
import org.apache.logging.log4j.util.Strings;
4446
import org.opensearch.sql.ast.AbstractNodeVisitor;
47+
import org.opensearch.sql.ast.expression.AggregateFunction;
4548
import org.opensearch.sql.ast.expression.Alias;
4649
import org.opensearch.sql.ast.expression.And;
4750
import org.opensearch.sql.ast.expression.Between;
@@ -72,6 +75,8 @@
7275
import org.opensearch.sql.ast.expression.subquery.InSubquery;
7376
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
7477
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
78+
import org.opensearch.sql.ast.tree.Sort.SortOption;
79+
import org.opensearch.sql.ast.tree.Sort.SortOrder;
7580
import org.opensearch.sql.ast.tree.UnresolvedPlan;
7681
import org.opensearch.sql.calcite.plan.rel.LogicalSystemLimit;
7782
import org.opensearch.sql.calcite.plan.rel.LogicalSystemLimit.SystemLimitType;
@@ -563,47 +568,96 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
563568

564569
/**
 * Translates a {@link WindowFunction} AST node into a Calcite OVER expression.
 *
 * <p>Steps performed by the post-commit code shown here:
 *
 * <ol>
 *   <li>Resolve the function name, its analyzed arguments and the DISTINCT flag. When the wrapped
 *       function is an {@code AggregateFunction}, the arguments are its field (if non-null)
 *       followed by its arg list; otherwise the node is cast to {@code Function}, its function
 *       arguments are used, and DISTINCT is false.
 *   <li>Analyze the PARTITION BY expressions and unwrap each through
 *       {@code extractRexNodeFromAlias}.
 *   <li>Translate the sort list into order keys via {@code translateOrderKeys}.
 *   <li>Look the name up with {@code BuiltinFunctionName.ofWindowFunction}. ROW_NUMBER goes
 *       straight to {@code PlanUtils.makeOver}; every other function is first passed through
 *       {@code PPLFuncImpTable.validateAggFunctionSignature}, and the nodes it returns (when
 *       non-null) replace the original field and arguments.
 * </ol>
 *
 * @param node the window function AST node
 * @param context the planning context supplying the builders used during translation
 * @return the expression produced by {@code PlanUtils.makeOver}
 * @throws UnsupportedOperationException if the name is not a known window function
 */
@Override
565570
public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext context) {
566-
Function windowFunction = (Function) node.getFunction();
567-
List<RexNode> arguments =
568-
windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
571+
// SQL emits AggregateFunction for aggregate-as-window (e.g., SUM(x) OVER); PPL emits Function.
572+
final String funcName;
573+
final List<RexNode> arguments;
574+
final boolean isDistinct;
575+
if (node.getFunction() instanceof AggregateFunction aggFunc) {
576+
funcName = aggFunc.getFuncName();
577+
isDistinct = Boolean.TRUE.equals(aggFunc.getDistinct());
578+
List<UnresolvedExpression> argExprs = new ArrayList<>();
579+
if (aggFunc.getField() != null) {
580+
argExprs.add(aggFunc.getField());
581+
}
582+
argExprs.addAll(aggFunc.getArgList());
583+
arguments = argExprs.stream().map(arg -> analyze(arg, context)).toList();
584+
} else {
585+
// NOTE(review): unchecked cast — presumably every non-aggregate window function is a
// Function; any other expression type would fail here with ClassCastException. Confirm.
Function windowFunction = (Function) node.getFunction();
586+
funcName = windowFunction.getFuncName();
587+
isDistinct = false;
588+
arguments = windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
589+
}
569590
List<RexNode> partitions =
570591
node.getPartitionByList().stream()
571592
.map(arg -> analyze(arg, context))
572593
.map(this::extractRexNodeFromAlias)
573594
.toList();
574-
return BuiltinFunctionName.ofWindowFunction(windowFunction.getFuncName())
595+
List<RexNode> orderKeys = translateOrderKeys(node.getSortList(), context);
596+
return BuiltinFunctionName.ofWindowFunction(funcName)
575597
.map(
576598
functionName -> {
577599
RexNode field = arguments.isEmpty() ? null : arguments.getFirst();
578600
List<RexNode> args =
579601
(arguments.isEmpty() || arguments.size() == 1)
580602
? Collections.emptyList()
581603
: arguments.subList(1, arguments.size());
604+
// ROW_NUMBER takes no field/args and isn't in aggFunctionRegistry,
605+
// so skip aggregate signature validation.
// NOTE(review): this call uses the makeOver overload without a distinct parameter, so
// isDistinct is not forwarded on the ROW_NUMBER path.
606+
if (functionName == BuiltinFunctionName.ROW_NUMBER) {
607+
return PlanUtils.makeOver(
608+
context,
609+
functionName,
610+
field,
611+
args,
612+
partitions,
613+
orderKeys,
614+
node.getWindowFrame());
615+
}
582616
List<RexNode> nodes =
583617
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
584618
functionName, field, args, context.rexBuilder);
585619
return nodes != null
586620
? PlanUtils.makeOver(
587621
context,
588622
functionName,
623+
isDistinct,
589624
nodes.getFirst(),
590625
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
591626
partitions,
592-
List.of(),
627+
orderKeys,
593628
node.getWindowFrame())
594629
: PlanUtils.makeOver(
595630
context,
596631
functionName,
632+
isDistinct,
597633
field,
598634
args,
599635
partitions,
600-
List.of(),
636+
orderKeys,
601637
node.getWindowFrame());
602638
})
603639
.orElseThrow(
604-
() ->
605-
new UnsupportedOperationException(
606-
"Unexpected window function: " + windowFunction.getFuncName()));
640+
() -> new UnsupportedOperationException("Unexpected window function: " + funcName));
641+
}
642+
643+
private List<RexNode> translateOrderKeys(
644+
List<Pair<SortOption, UnresolvedExpression>> sortList, CalcitePlanContext context) {
645+
RelBuilder b = context.relBuilder;
646+
return sortList.stream()
647+
.map(
648+
p -> {
649+
SortOption opt = p.getLeft();
650+
RexNode field = analyze(p.getRight(), context);
651+
if (opt.getSortOrder() == SortOrder.DESC) {
652+
field = b.desc(field);
653+
}
654+
return switch (opt.getNullOrder()) {
655+
case NULL_LAST -> b.nullsLast(field);
656+
case NULL_FIRST -> b.nullsFirst(field);
657+
default -> field;
658+
};
659+
})
660+
.toList();
607661
}
608662

609663
/** extract the expression of Alias from a node */

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ static RexNode makeOver(
171171
List<RexNode> partitions,
172172
List<RexNode> orderKeys,
173173
@Nullable WindowFrame windowFrame) {
174+
return makeOver(
175+
context, functionName, false, field, argList, partitions, orderKeys, windowFrame);
176+
}
177+
178+
static RexNode makeOver(
179+
CalcitePlanContext context,
180+
BuiltinFunctionName functionName,
181+
boolean distinct,
182+
RexNode field,
183+
List<RexNode> argList,
184+
List<RexNode> partitions,
185+
List<RexNode> orderKeys,
186+
@Nullable WindowFrame windowFrame) {
174187
if (windowFrame == null) {
175188
windowFrame = WindowFrame.rowsUnbounded();
176189
}
@@ -226,7 +239,7 @@ static RexNode makeOver(
226239
upperBound);
227240
default:
228241
return withOver(
229-
makeAggCall(context, functionName, false, field, argList),
242+
makeAggCall(context, functionName, distinct, field, argList),
230243
partitions,
231244
orderKeys,
232245
rows,

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ public enum BuiltinFunctionName {
425425
.put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
426426
.put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
427427
.put("pattern", BuiltinFunctionName.INTERNAL_PATTERN)
428+
.put("row_number", BuiltinFunctionName.ROW_NUMBER)
428429
.build();
429430

430431
public static Optional<BuiltinFunctionName> of(String str) {

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_5;
9393
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_PG_4;
9494
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_TRANSLATE3;
95+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ISNULL;
9596
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_BLANK;
9697
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_EMPTY;
9798
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL;
@@ -1192,6 +1193,8 @@ void populate() {
11921193
IS_PRESENT, SqlStdOperatorTable.IS_NOT_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
11931194
registerOperator(
11941195
IS_NULL, SqlStdOperatorTable.IS_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
1196+
registerOperator(
1197+
ISNULL, SqlStdOperatorTable.IS_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
11951198

11961199
// Register implementation.
11971200
// Note, make the implementation an individual class if too complex.

0 commit comments

Comments
 (0)