Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,154 @@ SELECT department, count(*)
""")
.assertErrorMessage("Encountered");
}

@Test
public void testSqlWindowFunctionWithOrderBy() {
givenQuery(
"""
SELECT name, SUM(age) OVER (PARTITION BY department ORDER BY id) AS running_sum
FROM catalog.employees\
""")
.assertPlan(
"""
LogicalProject(name=[$1], running_sum=[SUM($2) OVER (PARTITION BY $3 ORDER BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlWindowRowNumber() {
givenQuery(
"""
SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn
FROM catalog.employees\
""")
.assertPlan(
"""
LogicalProject(name=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $0)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlWindowDistinctAggregate() {
givenQuery(
"""
SELECT name, COUNT(DISTINCT department) OVER (PARTITION BY department) AS dist_cnt
FROM catalog.employees\
""")
.assertPlan(
"""
LogicalProject(name=[$1], dist_cnt=[COUNT(DISTINCT $3) OVER (PARTITION BY $3)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlIsNullFunction() {
// ISNULL(field) — exercises the ISNULL alias registration in PPLFuncImpTable.
// Calcite constant-folds to false since test schema columns are NOT NULL.
givenQuery(
"""
SELECT ISNULL(department) AS is_null
FROM catalog.employees\
""")
.assertPlan(
"""
LogicalProject(is_null=[false])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlLimitOffset() {
givenQuery(
"""
SELECT name
FROM catalog.employees
LIMIT 10 OFFSET 5\
""")
.assertPlan(
"""
LogicalProject(name=[$1])
LogicalSort(offset=[5], fetch=[10])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlAggregateWithAlias() {
givenQuery(
"""
SELECT department, COUNT(*) AS cnt
FROM catalog.employees
GROUP BY department\
""")
.assertPlan(
"""
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(department=[$3])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlGroupByWithoutBucketNullable() {
givenQuery(
"""
SELECT age, COUNT(*) AS cnt
FROM catalog.employees
GROUP BY age\
""")
.assertPlan(
"""
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(age=[$2])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlSelectWithAlias() {
givenQuery(
"""
SELECT age AS employee_age, name AS employee_name
FROM catalog.employees\
""")
.assertPlan(
"""
LogicalProject(employee_age=[$2], employee_name=[$1])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlDerivedTableInFromClause() {
// SELECT ... FROM (SELECT ...) AS t — exercises visitRelationSubquery override.
givenQuery(
"""
SELECT t.id
FROM (SELECT id, name FROM catalog.employees WHERE age > 30) AS t\
""")
.assertPlan(
"""
LogicalProject(t.id=[$0])
LogicalFilter(condition=[>($2, 30)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testSqlSelectWithoutFromClause() {
// SELECT 1 — exercises visitValues dual-table case (single empty row).
givenQuery(
"""
SELECT 1\
""")
.assertPlan(
"""
LogicalSort(sort0=[$0], dir0=[ASC])
LogicalValues(tuples=[[]])
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.ast.tree.Lookup.OutputStrategy;
import org.opensearch.sql.ast.tree.ML;
Expand All @@ -146,6 +147,7 @@
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Regex;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.RelationSubquery;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Replace;
import org.opensearch.sql.ast.tree.ReplacePair;
Expand Down Expand Up @@ -542,7 +544,18 @@ private List<RexNode> expandProjectFields(
.forEach(field -> expandedFields.add(context.relBuilder.field(field)));
}
case Alias alias -> {
expandedFields.add(rexVisitor.analyze(alias, context));
// SQL aggregate aliases (e.g., COUNT(*) AS cnt): reference the already-computed field
// and rebind under the user's alias, since re-analyzing the alias returns null.
if (alias.getDelegated() instanceof AggregateFunction
&& alias.getName() != null
&& currentFields.contains(alias.getName())) {
String displayName =
Strings.isNullOrEmpty(alias.getAlias()) ? alias.getName() : alias.getAlias();
expandedFields.add(
context.relBuilder.alias(context.relBuilder.field(alias.getName()), displayName));
} else {
expandedFields.add(rexVisitor.analyze(alias, context));
}
}
default ->
throw new IllegalStateException(
Expand Down Expand Up @@ -766,6 +779,13 @@ public RelNode visitHead(Head node, CalcitePlanContext context) {
return context.relBuilder.peek();
}

@Override
public RelNode visitLimit(Limit node, CalcitePlanContext context) {
visitChildren(node, context);
context.relBuilder.limit(node.getOffset(), node.getLimit());
return context.relBuilder.peek();
}

/**
* Insert a reversed sort node after finding the original sort in the tree. This rebuilds the tree
* with the reversed sort inserted right after the original sort.
Expand Down Expand Up @@ -1624,7 +1644,9 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
@Override
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
Argument.ArgumentMap statsArgs = Argument.ArgumentMap.of(node.getArgExprList());
Boolean bucketNullable = (Boolean) statsArgs.get(Argument.BUCKET_NULLABLE).getValue();
// SQL aggregations don't carry the PPL-only BUCKET_NULLABLE argument; default to true.
boolean bucketNullable =
(Boolean) statsArgs.getOrDefault(Argument.BUCKET_NULLABLE, Literal.TRUE).getValue();
int nGroup = node.getGroupExprList().size() + (Objects.nonNull(node.getSpan()) ? 1 : 0);
BitSet nonNullGroupMask = new BitSet(nGroup);
if (!bucketNullable) {
Expand Down Expand Up @@ -1931,6 +1953,14 @@ public RelNode visitSubqueryAlias(SubqueryAlias node, CalcitePlanContext context
return context.relBuilder.peek();
}

@Override
public RelNode visitRelationSubquery(RelationSubquery node, CalcitePlanContext context) {
// Handle SQL derived tables in FROM clause: SELECT ... FROM (SELECT ...) AS t.
visitChildren(node, context);
context.relBuilder.as(node.getAliasAsTableName());
return context.relBuilder.peek();
}

@Override
public RelNode visitLookup(Lookup node, CalcitePlanContext context) {
// 1. resolve source side
Expand Down Expand Up @@ -4125,12 +4155,13 @@ public RelNode visitMvExpand(MvExpand mvExpand, CalcitePlanContext context) {

@Override
public RelNode visitValues(Values values, CalcitePlanContext context) {
if (values.getValues() == null || values.getValues().isEmpty()) {
// Accept SQL SELECT without FROM (dual table), encoded as Values([[]]) — one row, zero columns.
List<List<Literal>> rows = values.getValues();
if (rows == null || rows.isEmpty() || (rows.size() == 1 && rows.get(0).isEmpty())) {
context.relBuilder.values(context.relBuilder.getTypeFactory().builder().build());
return context.relBuilder.peek();
} else {
throw new CalciteUnsupportedException("Explicit values node is unsupported in Calcite");
}
throw new CalciteUnsupportedException("Inline VALUES with literal rows is unsupported");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Between;
Expand Down Expand Up @@ -72,6 +75,8 @@
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.Sort.SortOrder;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.plan.rel.LogicalSystemLimit;
import org.opensearch.sql.calcite.plan.rel.LogicalSystemLimit.SystemLimitType;
Expand Down Expand Up @@ -563,47 +568,96 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {

@Override
public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext context) {
Function windowFunction = (Function) node.getFunction();
List<RexNode> arguments =
windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
// SQL emits AggregateFunction for aggregate-as-window (e.g., SUM(x) OVER); PPL emits Function.
final String funcName;
final List<RexNode> arguments;
final boolean isDistinct;
if (node.getFunction() instanceof AggregateFunction aggFunc) {
funcName = aggFunc.getFuncName();
isDistinct = Boolean.TRUE.equals(aggFunc.getDistinct());
List<UnresolvedExpression> argExprs = new ArrayList<>();
if (aggFunc.getField() != null) {
argExprs.add(aggFunc.getField());
}
argExprs.addAll(aggFunc.getArgList());
arguments = argExprs.stream().map(arg -> analyze(arg, context)).toList();
} else {
Function windowFunction = (Function) node.getFunction();
funcName = windowFunction.getFuncName();
isDistinct = false;
arguments = windowFunction.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList();
}
List<RexNode> partitions =
node.getPartitionByList().stream()
.map(arg -> analyze(arg, context))
.map(this::extractRexNodeFromAlias)
.toList();
return BuiltinFunctionName.ofWindowFunction(windowFunction.getFuncName())
List<RexNode> orderKeys = translateOrderKeys(node.getSortList(), context);
return BuiltinFunctionName.ofWindowFunction(funcName)
.map(
functionName -> {
RexNode field = arguments.isEmpty() ? null : arguments.getFirst();
List<RexNode> args =
(arguments.isEmpty() || arguments.size() == 1)
? Collections.emptyList()
: arguments.subList(1, arguments.size());
// ROW_NUMBER takes no field/args and isn't in aggFunctionRegistry,
// so skip aggregate signature validation.
if (functionName == BuiltinFunctionName.ROW_NUMBER) {
return PlanUtils.makeOver(
context,
functionName,
field,
args,
partitions,
orderKeys,
node.getWindowFrame());
}
List<RexNode> nodes =
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
functionName, field, args, context.rexBuilder);
return nodes != null
? PlanUtils.makeOver(
context,
functionName,
isDistinct,
nodes.getFirst(),
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
partitions,
List.of(),
orderKeys,
node.getWindowFrame())
: PlanUtils.makeOver(
context,
functionName,
isDistinct,
field,
args,
partitions,
List.of(),
orderKeys,
node.getWindowFrame());
})
.orElseThrow(
() ->
new UnsupportedOperationException(
"Unexpected window function: " + windowFunction.getFuncName()));
() -> new UnsupportedOperationException("Unexpected window function: " + funcName));
}

private List<RexNode> translateOrderKeys(
List<Pair<SortOption, UnresolvedExpression>> sortList, CalcitePlanContext context) {
RelBuilder b = context.relBuilder;
return sortList.stream()
.map(
p -> {
SortOption opt = p.getLeft();
RexNode field = analyze(p.getRight(), context);
if (opt.getSortOrder() == SortOrder.DESC) {
field = b.desc(field);
}
return switch (opt.getNullOrder()) {
case NULL_LAST -> b.nullsLast(field);
case NULL_FIRST -> b.nullsFirst(field);
default -> field;
};
})
.toList();
}

/** extract the expression of Alias from a node */
Expand Down
Loading
Loading