1515import static org .opensearch .sql .ast .tree .Sort .SortOrder .ASC ;
1616import static org .opensearch .sql .ast .tree .Sort .SortOrder .DESC ;
1717import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_FOR_DEDUP ;
18- import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME ;
1918import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME_MAIN ;
2019import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME_SUBSEARCH ;
20+ import static org .opensearch .sql .calcite .utils .PlanUtils .ROW_NUMBER_COLUMN_NAME_TOP_RARE ;
2121import static org .opensearch .sql .calcite .utils .PlanUtils .getRelation ;
2222import static org .opensearch .sql .calcite .utils .PlanUtils .getRexCall ;
2323import static org .opensearch .sql .calcite .utils .PlanUtils .transformPlanToAttachChild ;
@@ -910,14 +910,12 @@ private boolean isCountField(RexCall call) {
910910 * @param groupExprList group by expression list
911911 * @param aggExprList aggregate expression list
912912 * @param context CalcitePlanContext
913- * @param hintBucketNonNull adda bucket nullable hint on LogicalAggregate if set
914913 * @return Pair of (group-by list, field list, aggregate list)
915914 */
916915 private Pair <List <RexNode >, List <AggCall >> aggregateWithTrimming (
917916 List <UnresolvedExpression > groupExprList ,
918917 List <UnresolvedExpression > aggExprList ,
919- CalcitePlanContext context ,
920- boolean hintBucketNonNull ) {
918+ CalcitePlanContext context ) {
921919 Pair <List <RexNode >, List <AggCall >> resolved =
922920 resolveAttributesForAggregation (groupExprList , aggExprList , context );
923921 List <RexNode > resolvedGroupByList = resolved .getLeft ();
@@ -1021,7 +1019,6 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10211019 List <String > intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation (reResolved .getLeft ());
10221020 context .relBuilder .aggregate (
10231021 context .relBuilder .groupKey (reResolved .getLeft ()), reResolved .getRight ());
1024- if (hintBucketNonNull ) hintBucketNonNullOnAggregate (context .relBuilder );
10251022 // During aggregation, Calcite projects both input dependencies and output group-by fields.
10261023 // When names conflict, Calcite adds numeric suffixes (e.g., "value0").
10271024 // Apply explicit renaming to restore the intended aliases.
@@ -1030,24 +1027,6 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10301027 return Pair .of (reResolved .getLeft (), reResolved .getRight ());
10311028 }
10321029
1033- private void hintBucketNonNullOnAggregate (RelBuilder relBuilder ) {
1034- final RelHint statHits =
1035- RelHint .builder ("stats_args" ).hintOption (Argument .BUCKET_NULLABLE , "false" ).build ();
1036- assert relBuilder .peek () instanceof LogicalAggregate
1037- : "Stats hits should be added to LogicalAggregate" ;
1038- relBuilder .hints (statHits );
1039- relBuilder
1040- .getCluster ()
1041- .setHintStrategies (
1042- HintStrategyTable .builder ()
1043- .hintStrategy (
1044- "stats_args" ,
1045- (hint , rel ) -> {
1046- return rel instanceof LogicalAggregate ;
1047- })
1048- .build ());
1049- }
1050-
10511030 /**
10521031 * Imitates {@code Registrar.registerExpression} of {@link RelBuilder} to derive the output order
10531032 * of group-by keys after aggregation.
@@ -1162,7 +1141,10 @@ private void visitAggregation(Aggregation node, CalcitePlanContext context, bool
11621141 }
11631142
11641143 Pair <List <RexNode >, List <AggCall >> aggregationAttributes =
1165- aggregateWithTrimming (groupExprList , aggExprList , context , toAddHintsOnAggregate );
1144+ aggregateWithTrimming (groupExprList , aggExprList , context );
1145+ if (toAddHintsOnAggregate ) {
1146+ addIgnoreNullBucketHintToAggregate (context );
1147+ }
11661148
11671149 // schema reordering
11681150 List <RexNode > outputFields = context .relBuilder .fields ();
@@ -1883,9 +1865,8 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) {
18831865 @ Override
18841866 public RelNode visitRareTopN (RareTopN node , CalcitePlanContext context ) {
18851867 visitChildren (node , context );
1886-
1887- ArgumentMap arguments = ArgumentMap .of (node .getArguments ());
1888- String countFieldName = (String ) arguments .get ("countField" ).getValue ();
1868+ ArgumentMap argumentMap = ArgumentMap .of (node .getArguments ());
1869+ String countFieldName = (String ) argumentMap .get (RareTopN .Option .countField .name ()).getValue ();
18891870 if (context .relBuilder .peek ().getRowType ().getFieldNames ().contains (countFieldName )) {
18901871 throw new IllegalArgumentException (
18911872 "Field `"
@@ -1900,7 +1881,26 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
19001881 groupExprList .addAll (fieldList );
19011882 List <UnresolvedExpression > aggExprList =
19021883 List .of (AstDSL .alias (countFieldName , AstDSL .aggregate ("count" , null )));
1903- aggregateWithTrimming (groupExprList , aggExprList , context , false );
1884+
1885+ // if usenull=false, add a isNotNull before Aggregate and the hint to this Aggregate
1886+ Boolean bucketNullable = (Boolean ) argumentMap .get (RareTopN .Option .useNull .name ()).getValue ();
1887+ boolean toAddHintsOnAggregate = false ;
1888+ if (!bucketNullable && !groupExprList .isEmpty ()) {
1889+ toAddHintsOnAggregate = true ;
1890+ // add isNotNull filter before aggregation to filter out null bucket
1891+ List <RexNode > groupByList =
1892+ groupExprList .stream ().map (expr -> rexVisitor .analyze (expr , context )).toList ();
1893+ context .relBuilder .filter (
1894+ PlanUtils .getSelectColumns (groupByList ).stream ()
1895+ .map (context .relBuilder ::field )
1896+ .map (context .relBuilder ::isNotNull )
1897+ .toList ());
1898+ }
1899+ aggregateWithTrimming (groupExprList , aggExprList , context );
1900+
1901+ if (toAddHintsOnAggregate ) {
1902+ addIgnoreNullBucketHintToAggregate (context );
1903+ }
19041904
19051905 // 2. add a window column
19061906 List <RexNode > partitionKeys = rexVisitor .analyze (node .getGroupExprList (), context );
@@ -1920,26 +1920,46 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
19201920 List .of (countField ),
19211921 WindowFrame .toCurrentRow ());
19221922 context .relBuilder .projectPlus (
1923- context .relBuilder .alias (rowNumberWindowOver , ROW_NUMBER_COLUMN_NAME ));
1923+ context .relBuilder .alias (rowNumberWindowOver , ROW_NUMBER_COLUMN_NAME_TOP_RARE ));
19241924
19251925 // 3. filter row_number() <= k in each partition
1926- Integer N = ( Integer ) arguments . get ( "noOfResults" ). getValue ();
1926+ int k = node . getNoOfResults ();
19271927 context .relBuilder .filter (
19281928 context .relBuilder .lessThanOrEqual (
1929- context .relBuilder .field (ROW_NUMBER_COLUMN_NAME ), context .relBuilder .literal (N )));
1929+ context .relBuilder .field (ROW_NUMBER_COLUMN_NAME_TOP_RARE ),
1930+ context .relBuilder .literal (k )));
19301931
19311932 // 4. project final output. the default output is group by list + field list
1932- Boolean showCount = (Boolean ) arguments .get (" showCount" ).getValue ();
1933+ Boolean showCount = (Boolean ) argumentMap .get (RareTopN . Option . showCount . name () ).getValue ();
19331934 if (showCount ) {
1934- context .relBuilder .projectExcept (context .relBuilder .field (ROW_NUMBER_COLUMN_NAME ));
1935+ context .relBuilder .projectExcept (context .relBuilder .field (ROW_NUMBER_COLUMN_NAME_TOP_RARE ));
19351936 } else {
19361937 context .relBuilder .projectExcept (
1937- context .relBuilder .field (ROW_NUMBER_COLUMN_NAME ),
1938+ context .relBuilder .field (ROW_NUMBER_COLUMN_NAME_TOP_RARE ),
19381939 context .relBuilder .field (countFieldName ));
19391940 }
19401941 return context .relBuilder .peek ();
19411942 }
19421943
1944+ private static void addIgnoreNullBucketHintToAggregate (CalcitePlanContext context ) {
1945+ final RelHint statHits =
1946+ RelHint .builder ("stats_args" ).hintOption (Argument .BUCKET_NULLABLE , "false" ).build ();
1947+ assert context .relBuilder .peek () instanceof LogicalAggregate
1948+ : "Stats hits should be added to LogicalAggregate" ;
1949+ context .relBuilder .hints (statHits );
1950+ context
1951+ .relBuilder
1952+ .getCluster ()
1953+ .setHintStrategies (
1954+ HintStrategyTable .builder ()
1955+ .hintStrategy (
1956+ "stats_args" ,
1957+ (hint , rel ) -> {
1958+ return rel instanceof LogicalAggregate ;
1959+ })
1960+ .build ());
1961+ }
1962+
19431963 @ Override
19441964 public RelNode visitTableFunction (TableFunction node , CalcitePlanContext context ) {
19451965 throw new CalciteUnsupportedException ("Table function is unsupported in Calcite" );
@@ -2242,7 +2262,7 @@ public RelNode visitTimechart(
22422262 try {
22432263 // Step 1: Initial aggregation - IMPORTANT: order is [spanExpr, byField]
22442264 groupExprList = Arrays .asList (spanExpr , byField );
2245- aggregateWithTrimming (groupExprList , List .of (node .getAggregateFunction ()), context , false );
2265+ aggregateWithTrimming (groupExprList , List .of (node .getAggregateFunction ()), context );
22462266
22472267 // First rename the timestamp field (2nd to last) to @timestamp
22482268 List <String > fieldNames = context .relBuilder .peek ().getRowType ().getFieldNames ();
0 commit comments