Skip to content

Commit a797906

Browse files
committed
Unconditionally rewrite COUNT() to COUNT(*) in sql level to allow type inference (1701/2015)
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 73f2f1b commit a797906

1 file changed

Lines changed: 23 additions & 2 deletions

File tree

core/src/main/java/org/opensearch/sql/executor/QueryService.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@
3333
import org.apache.calcite.rel.rules.FilterMergeRule;
3434
import org.apache.calcite.runtime.CalciteContextException;
3535
import org.apache.calcite.schema.SchemaPlus;
36+
import org.apache.calcite.sql.SqlBasicCall;
37+
import org.apache.calcite.sql.SqlCall;
3638
import org.apache.calcite.sql.SqlIdentifier;
3739
import org.apache.calcite.sql.SqlNode;
38-
import org.apache.calcite.sql.dialect.MysqlSqlDialect;
40+
import org.apache.calcite.sql.dialect.SparkSqlDialect;
41+
import org.apache.calcite.sql.fun.SqlCountAggFunction;
42+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3943
import org.apache.calcite.sql.parser.SqlParser;
4044
import org.apache.calcite.sql.util.SqlShuttle;
4145
import org.apache.calcite.sql.validate.SqlValidator;
@@ -82,7 +86,7 @@ public class QueryService {
8286
private DataSourceService dataSourceService;
8387
private Settings settings;
8488
private static final PplRelToSqlNodeConverter rel2sql =
85-
new PplRelToSqlNodeConverter(MysqlSqlDialect.DEFAULT);
89+
new PplRelToSqlNodeConverter(SparkSqlDialect.DEFAULT);
8690

8791
@Getter(lazy = true)
8892
private final CalciteRelNodeVisitor relNodeVisitor = new CalciteRelNodeVisitor(dataSourceService);
@@ -336,6 +340,23 @@ public SqlNode visit(SqlIdentifier id) {
336340
}
337341
return id;
338342
}
343+
344+
@Override
345+
public @org.checkerframework.checker.nullness.qual.Nullable SqlNode visit(
346+
SqlCall call) {
347+
if (call.getOperator() instanceof SqlCountAggFunction
348+
&& call.getOperandList().isEmpty()) {
349+
// Convert COUNT() to COUNT(*) so that SqlCall.isCountStar() resolves to True
350+
// This is useful when deriving the return types in SqlCountAggFunction#deriveType
351+
call =
352+
new SqlBasicCall(
353+
SqlStdOperatorTable.COUNT,
354+
List.of(SqlIdentifier.STAR),
355+
call.getParserPosition(),
356+
call.getFunctionQuantifier());
357+
}
358+
return super.visit(call);
359+
}
339360
});
340361

341362
SqlValidator validator = context.getValidator();

0 commit comments

Comments
 (0)