Skip to content

Commit 563a054

Browse files
committed
feat(calcite): support window functions, RANK/DENSE_RANK, and register ISNULL
Add AggregateFunction handling in visitWindowFunction to support aggregate-based window expressions with DISTINCT and ORDER BY keys. Add translateOrderKeys utility for window ORDER BY translation. Add RANK and DENSE_RANK cases in PlanUtils.makeOver. Register row_number, rank, dense_rank in WINDOW_FUNC_MAPPING. Add isPureWindowFunction helper to skip aggregate validation for pure ranking functions. Pass distinct flag through makeOver call chain. Register ISNULL as alias for IS_NULL in PPLFuncImpTable. Add integration tests for window functions with ORDER BY, RANK, COUNT DISTINCT OVER, and ISNULL. Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent eb0e527 commit 563a054

5 files changed

Lines changed: 177 additions & 10 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,64 @@ SELECT department, count(*)
260260
.assertErrorMessage("Encountered");
261261
}
262262

263+
@Test
264+
public void testSqlWindowFunctionWithOrderBy() {
265+
givenQuery(
266+
"""
267+
SELECT name, SUM(age) OVER (PARTITION BY department ORDER BY id) AS running_sum
268+
FROM catalog.employees\
269+
""")
270+
.assertPlan(
271+
"""
272+
LogicalProject(name=[$1], running_sum=[SUM($2) OVER (PARTITION BY $3 ORDER BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)])
273+
LogicalTableScan(table=[[catalog, employees]])
274+
""");
275+
}
276+
277+
@Test
278+
public void testSqlWindowRankFunction() {
279+
givenQuery(
280+
"""
281+
SELECT name, RANK() OVER (ORDER BY age DESC) AS rnk
282+
FROM catalog.employees\
283+
""")
284+
.assertPlan(
285+
"""
286+
LogicalProject(name=[$1], rnk=[RANK() OVER (ORDER BY $2 DESC)])
287+
LogicalTableScan(table=[[catalog, employees]])
288+
""");
289+
}
290+
291+
@Test
292+
public void testSqlWindowDistinctAggregate() {
293+
givenQuery(
294+
"""
295+
SELECT name, COUNT(DISTINCT department) OVER (PARTITION BY department) AS dist_cnt
296+
FROM catalog.employees\
297+
""")
298+
.assertPlan(
299+
"""
300+
LogicalProject(name=[$1], dist_cnt=[COUNT(DISTINCT $3) OVER (PARTITION BY $3)])
301+
LogicalTableScan(table=[[catalog, employees]])
302+
""");
303+
}
304+
305+
@Test
306+
public void testSqlIsNullFunction() {
307+
// ISNULL(field) — exercises the ISNULL alias registration in PPLFuncImpTable.
308+
// Calcite constant-folds to false since test schema columns are NOT NULL.
309+
givenQuery(
310+
"""
311+
SELECT ISNULL(department) AS is_null
312+
FROM catalog.employees\
313+
""")
314+
.assertPlan(
315+
"""
316+
LogicalProject(is_null=[false])
317+
LogicalTableScan(table=[[catalog, employees]])
318+
""");
319+
}
320+
263321
@Test
264322
public void testSqlLimitOffset() {
265323
givenQuery(

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.apache.calcite.util.TimestampString;
4343
import org.apache.logging.log4j.util.Strings;
4444
import org.opensearch.sql.ast.AbstractNodeVisitor;
45+
import org.opensearch.sql.ast.expression.AggregateFunction;
4546
import org.opensearch.sql.ast.expression.Alias;
4647
import org.opensearch.sql.ast.expression.And;
4748
import org.opensearch.sql.ast.expression.Between;
@@ -563,47 +564,76 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
563564

564565
@Override
565566
public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext context) {
566-
Function windowFunction = (Function) node.getFunction();
567-
List<RexNode> arguments =
568-
windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
567+
String funcName;
568+
List<RexNode> arguments;
569+
final boolean isDistinct;
570+
if (node.getFunction() instanceof AggregateFunction aggFunc) {
571+
funcName = aggFunc.getFuncName();
572+
isDistinct = Boolean.TRUE.equals(aggFunc.getDistinct());
573+
List<UnresolvedExpression> argExprs = new java.util.ArrayList<>();
574+
if (aggFunc.getField() != null) {
575+
argExprs.add(aggFunc.getField());
576+
}
577+
argExprs.addAll(aggFunc.getArgList());
578+
arguments = argExprs.stream().map(arg -> analyze(arg, context)).toList();
579+
} else {
580+
Function windowFunction = (Function) node.getFunction();
581+
funcName = windowFunction.getFuncName();
582+
isDistinct = false;
583+
arguments = windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
584+
}
569585
List<RexNode> partitions =
570586
node.getPartitionByList().stream()
571587
.map(arg -> analyze(arg, context))
572588
.map(this::extractRexNodeFromAlias)
573589
.toList();
574-
return BuiltinFunctionName.ofWindowFunction(windowFunction.getFuncName())
590+
List<RexNode> orderKeys =
591+
PlanUtils.translateOrderKeys(node.getSortList(), expr -> analyze(expr, context), context);
592+
return BuiltinFunctionName.ofWindowFunction(funcName)
575593
.map(
576594
functionName -> {
577595
RexNode field = arguments.isEmpty() ? null : arguments.getFirst();
578596
List<RexNode> args =
579597
(arguments.isEmpty() || arguments.size() == 1)
580598
? Collections.emptyList()
581599
: arguments.subList(1, arguments.size());
600+
// Pure window functions (ROW_NUMBER, RANK, DENSE_RANK) are not registered
601+
// in aggFunctionRegistry, so skip validation for them.
602+
if (BuiltinFunctionName.isPureWindowFunction(functionName)) {
603+
return PlanUtils.makeOver(
604+
context,
605+
functionName,
606+
field,
607+
args,
608+
partitions,
609+
orderKeys,
610+
node.getWindowFrame());
611+
}
582612
List<RexNode> nodes =
583613
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
584614
functionName, field, args, context.rexBuilder);
585615
return nodes != null
586616
? PlanUtils.makeOver(
587617
context,
588618
functionName,
619+
isDistinct,
589620
nodes.getFirst(),
590621
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
591622
partitions,
592-
List.of(),
623+
orderKeys,
593624
node.getWindowFrame())
594625
: PlanUtils.makeOver(
595626
context,
596627
functionName,
628+
isDistinct,
597629
field,
598630
args,
599631
partitions,
600-
List.of(),
632+
orderKeys,
601633
node.getWindowFrame());
602634
})
603635
.orElseThrow(
604-
() ->
605-
new UnsupportedOperationException(
606-
"Unexpected window function: " + windowFunction.getFuncName()));
636+
() -> new UnsupportedOperationException("Unexpected window function: " + funcName));
607637
}
608638

609639
/** extract the expression of Alias from a node */

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import org.opensearch.sql.ast.Node;
7070
import org.opensearch.sql.ast.expression.IntervalUnit;
7171
import org.opensearch.sql.ast.expression.SpanUnit;
72+
import org.opensearch.sql.ast.expression.UnresolvedExpression;
7273
import org.opensearch.sql.ast.expression.WindowBound;
7374
import org.opensearch.sql.ast.expression.WindowFrame;
7475
import org.opensearch.sql.ast.tree.Relation;
@@ -163,9 +164,54 @@ static IntervalUnit spanUnitToIntervalUnit(SpanUnit unit) {
163164
}
164165
}
165166

167+
/**
168+
* Translates a list of (SortOption, UnresolvedExpression) pairs into Calcite RexNodes suitable
169+
* for use as window function ORDER BY keys, applying DESC and NULL FIRST/LAST directives via
170+
* RelBuilder.
171+
*/
172+
static List<RexNode> translateOrderKeys(
173+
List<
174+
org.apache.commons.lang3.tuple.Pair<
175+
org.opensearch.sql.ast.tree.Sort.SortOption, UnresolvedExpression>>
176+
sortList,
177+
java.util.function.Function<UnresolvedExpression, RexNode> analyzer,
178+
CalcitePlanContext context) {
179+
return sortList.stream()
180+
.map(
181+
pair -> {
182+
RexNode sortField = analyzer.apply(pair.getRight());
183+
if (pair.getLeft().getSortOrder()
184+
== org.opensearch.sql.ast.tree.Sort.SortOrder.DESC) {
185+
sortField = context.relBuilder.desc(sortField);
186+
}
187+
if (pair.getLeft().getNullOrder()
188+
== org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST) {
189+
sortField = context.relBuilder.nullsLast(sortField);
190+
} else if (pair.getLeft().getNullOrder()
191+
== org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST) {
192+
sortField = context.relBuilder.nullsFirst(sortField);
193+
}
194+
return sortField;
195+
})
196+
.toList();
197+
}
198+
199+
static RexNode makeOver(
200+
CalcitePlanContext context,
201+
BuiltinFunctionName functionName,
202+
RexNode field,
203+
List<RexNode> argList,
204+
List<RexNode> partitions,
205+
List<RexNode> orderKeys,
206+
@Nullable WindowFrame windowFrame) {
207+
return makeOver(
208+
context, functionName, false, field, argList, partitions, orderKeys, windowFrame);
209+
}
210+
166211
static RexNode makeOver(
167212
CalcitePlanContext context,
168213
BuiltinFunctionName functionName,
214+
boolean distinct,
169215
RexNode field,
170216
List<RexNode> argList,
171217
List<RexNode> partitions,
@@ -216,6 +262,22 @@ static RexNode makeOver(
216262
true,
217263
lowerBound,
218264
upperBound);
265+
case RANK:
266+
return withOver(
267+
context.relBuilder.aggregateCall(SqlStdOperatorTable.RANK),
268+
partitions,
269+
orderKeys,
270+
true,
271+
lowerBound,
272+
upperBound);
273+
case DENSE_RANK:
274+
return withOver(
275+
context.relBuilder.aggregateCall(SqlStdOperatorTable.DENSE_RANK),
276+
partitions,
277+
orderKeys,
278+
true,
279+
lowerBound,
280+
upperBound);
219281
case NTH_VALUE:
220282
return withOver(
221283
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
@@ -226,7 +288,7 @@ static RexNode makeOver(
226288
upperBound);
227289
default:
228290
return withOver(
229-
makeAggCall(context, functionName, false, field, argList),
291+
makeAggCall(context, functionName, distinct, field, argList),
230292
partitions,
231293
orderKeys,
232294
rows,

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ public enum BuiltinFunctionName {
425425
.put("dc", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
426426
.put("distinct_count", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
427427
.put("pattern", BuiltinFunctionName.INTERNAL_PATTERN)
428+
.put("row_number", BuiltinFunctionName.ROW_NUMBER)
429+
.put("rank", BuiltinFunctionName.RANK)
430+
.put("dense_rank", BuiltinFunctionName.DENSE_RANK)
428431
.build();
429432

430433
public static Optional<BuiltinFunctionName> of(String str) {
@@ -441,6 +444,17 @@ public static Optional<BuiltinFunctionName> ofWindowFunction(String functionName
441444
WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
442445
}
443446

447+
/**
448+
* Pure window functions (no aggregate semantics, take no field argument). They are not registered
449+
* in the aggregate function registry, so callers must skip aggregate validation.
450+
*/
451+
private static final Set<BuiltinFunctionName> PURE_WINDOW_FUNCTIONS =
452+
Set.of(ROW_NUMBER, RANK, DENSE_RANK);
453+
454+
public static boolean isPureWindowFunction(BuiltinFunctionName functionName) {
455+
return PURE_WINDOW_FUNCTIONS.contains(functionName);
456+
}
457+
444458
public static final Set<BuiltinFunctionName> COMPARATORS =
445459
Set.of(
446460
BuiltinFunctionName.EQUAL,

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_5;
9393
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_REGEXP_REPLACE_PG_4;
9494
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_TRANSLATE3;
95+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ISNULL;
9596
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_BLANK;
9697
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_EMPTY;
9798
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL;
@@ -1192,6 +1193,8 @@ void populate() {
11921193
IS_PRESENT, SqlStdOperatorTable.IS_NOT_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
11931194
registerOperator(
11941195
IS_NULL, SqlStdOperatorTable.IS_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
1196+
registerOperator(
1197+
ISNULL, SqlStdOperatorTable.IS_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
11951198

11961199
// Register implementation.
11971200
// Note, make the implementation an individual class if too complex.

0 commit comments

Comments
 (0)