diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java index a6d57ea01f6..ffd1dc3c229 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java @@ -63,6 +63,7 @@ import org.apache.calcite.plan.RelOptSchema; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelOptTable.ViewExpander; +import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; @@ -74,6 +75,7 @@ import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.RelShuttle; import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.hint.HintStrategyTable; import org.apache.calcite.rel.logical.LogicalTableScan; import org.apache.calcite.rel.rules.FilterMergeRule; import org.apache.calcite.rel.type.RelDataType; @@ -367,6 +369,36 @@ protected SqlToRelConverter getSqlToRelConverter( return new OpenSearchSqlToRelConverter( this, validator, catalogReader, this.cluster, convertletTable, config); } + + @Override + protected RelRoot trimUnusedFields(RelRoot root) { + final SqlToRelConverter.Config config = + SqlToRelConverter.config() + .withTrimUnusedFields(shouldTrim(root.rel)) + .withExpand(THREAD_EXPAND.get()) + .withInSubQueryThreshold(requireNonNull(THREAD_INSUBQUERY_THRESHOLD.get())); + // PPL analyzes into a pre-built RelNode before prepareStatement(rel). Reuse the incoming + // RelNode's cluster here so prepare-time trimming does not create replacement nodes under a + // different planner than the rest of the tree. + final SqlToRelConverter converter = + new OpenSearchSqlToRelConverter( + this, + getSqlValidator(), + catalogReader, + root.rel.getCluster(), + convertletTable, + config); + final boolean ordered = !root.collation.getFieldCollations().isEmpty(); + final boolean dml = SqlKind.DML.contains(root.kind); + return root.withRel(converter.trimUnusedFields(dml || ordered, root.rel)); + } + + private static boolean shouldTrim(RelNode rootRel) { + // For now, don't trim if there are more than 3 joins. The projects + // near the leaves created by trim migrate past joins and seem to + // prevent join-reordering. + return THREAD_TRIM.get() || RelOptUtil.countJoins(rootRel) < 2; + } } public static class OpenSearchSqlToRelConverter extends SqlToRelConverter { @@ -379,22 +411,51 @@ public OpenSearchSqlToRelConverter( RelOptCluster cluster, SqlRexConvertletTable convertletTable, Config config) { - super(viewExpander, validator, catalogReader, cluster, convertletTable, config); + this( + viewExpander, + validator, + catalogReader, + cluster, + convertletTable, + preserveHintStrategies(cluster, config), + true); + } + + private OpenSearchSqlToRelConverter( + ViewExpander viewExpander, + @Nullable SqlValidator validator, + CatalogReader catalogReader, + RelOptCluster cluster, + SqlRexConvertletTable convertletTable, + Config effectiveConfig, + boolean ignored) { + super(viewExpander, validator, catalogReader, cluster, convertletTable, effectiveConfig); this.relBuilder = - config + effectiveConfig .getRelBuilderFactory() .create( cluster, validator != null ? validator.getCatalogReader().unwrap(RelOptSchema.class) : null) - .transform(config.getRelBuilderConfigTransform()); + .transform(effectiveConfig.getRelBuilderConfigTransform()); } @Override protected RelFieldTrimmer newFieldTrimmer() { return new OpenSearchRelFieldTrimmer(validator, this.relBuilder); } + + // SqlToRelConverter always installs the hint strategy table from its config onto the cluster. + // When prepare-time trimming reuses an incoming RelNode cluster, preserve any PPL-specific + // aggregate hint strategies that were already registered during analysis. + private static Config preserveHintStrategies(RelOptCluster cluster, Config config) { + if (config.getHintStrategyTable() == HintStrategyTable.EMPTY + && cluster.getHintStrategies() != HintStrategyTable.EMPTY) { + return config.withHintStrategyTable(cluster.getHintStrategies()); + } + return config; + } } public static class OpenSearchRelRunners { diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendPipeCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendPipeCommandIT.java index d25d3ca80db..6ae37a027ba 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendPipeCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAppendPipeCommandIT.java @@ -87,4 +87,74 @@ public void testAppendpipeWithConflictTypeColumn() throws IOException { TEST_INDEX_ACCOUNT))); assertTrue(exception.getMessage().contains("due to incompatible types")); } + + /** Regression test: double appendpipe with different aggregations (issue #5173). */ + @Test + public void testDoubleAppendPipe() throws IOException { + JSONObject actual = + executeQuery( + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age by gender" + + " | appendpipe [ stats avg(sum_age) as avg_sum_age ]" + + " | appendpipe [ stats max(sum_age) as max_sum_age ]", + TEST_INDEX_ACCOUNT)); + verifySchemaInOrder( + actual, + schema("sum_age", "bigint"), + schema("gender", "string"), + schema("avg_sum_age", "double"), + schema("max_sum_age", "bigint")); + // 2 original rows + 1 avg row + 1 max row + verifyDataRows( + actual, + rows(14947, "F", null, null), + rows(15224, "M", null, null), + rows(null, null, 15085.5, null), + rows(null, null, null, 15224)); + } + + /** Regression test: triple appendpipe with different aggregations (issue #5173). */ + @Test + public void testTripleAppendPipe() throws IOException { + JSONObject actual = + executeQuery( + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age by gender" + + " | appendpipe [ stats avg(sum_age) as avg_sum_age ]" + + " | appendpipe [ stats max(sum_age) as max_sum_age ]" + + " | appendpipe [ stats min(sum_age) as min_sum_age ]", + TEST_INDEX_ACCOUNT)); + verifySchemaInOrder( + actual, + schema("sum_age", "bigint"), + schema("gender", "string"), + schema("avg_sum_age", "double"), + schema("max_sum_age", "bigint"), + schema("min_sum_age", "bigint")); + // 2 original rows + 1 avg + 1 max + 1 min + verifyDataRows( + actual, + rows(14947, "F", null, null, null), + rows(15224, "M", null, null, null), + rows(null, null, 15085.5, null, null), + rows(null, null, null, 15224, null), + rows(null, null, null, null, 14947)); + } + + /** Regression test: double appendpipe with non-aggregation (filter) subpipeline. */ + @Test + public void testDoubleAppendPipeWithFilter() throws IOException { + JSONObject actual = + executeQuery( + String.format( + Locale.ROOT, + "source=%s | stats sum(age) as sum_age by gender" + + " | appendpipe [ where gender = 'F' ]" + + " | appendpipe [ where gender = 'M' ]", + TEST_INDEX_ACCOUNT)); + // 2 original + 1 (F filter from original) + 1 (M filter from cumulative 3 rows) + verifyDataRows(actual, rows(14947, "F"), rows(15224, "M"), rows(14947, "F"), rows(15224, "M")); + } } diff --git a/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5173.yml b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5173.yml new file mode 100644 index 00000000000..3db25e24f56 --- /dev/null +++ b/integ-test/src/yamlRestTest/resources/rest-api-spec/test/issues/5173.yml @@ -0,0 +1,99 @@ +setup: + - do: + query.settings: + body: + transient: + plugins.calcite.enabled: true + + - do: + indices.create: + index: issue5173 + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + gender: + type: keyword + age: + type: integer + + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "issue5173", "_id": "1"}}' + - '{"gender": "F", "age": 10}' + - '{"index": {"_index": "issue5173", "_id": "2"}}' + - '{"gender": "F", "age": 20}' + - '{"index": {"_index": "issue5173", "_id": "3"}}' + - '{"gender": "M", "age": 30}' + - '{"index": {"_index": "issue5173", "_id": "4"}}' + - '{"gender": "M", "age": 40}' + +--- +teardown: + - do: + indices.delete: + index: issue5173 + ignore_unavailable: true + - do: + query.settings: + body: + transient: + plugins.calcite.enabled: false + +--- +"Issue 5173: double appendpipe with different aggregations should succeed": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: "source=issue5173 | stats sum(age) as sum_age by gender | appendpipe [ stats avg(sum_age) as avg_sum_age ] | appendpipe [ stats max(sum_age) as max_sum_age ]" + + - match: { total: 4 } + - match: + schema: + - { name: sum_age, type: bigint } + - { name: gender, type: string } + - { name: avg_sum_age, type: double } + - { name: max_sum_age, type: bigint } + - match: + datarows: + - [ 30, "F", null, null ] + - [ 70, "M", null, null ] + - [ null, null, 50.0, null ] + - [ null, null, null, 70 ] + +--- +"Issue 5173: triple appendpipe with different aggregations should succeed": + - skip: + features: + - headers + - do: + headers: + Content-Type: 'application/json' + ppl: + body: + query: "source=issue5173 | stats sum(age) as sum_age by gender | appendpipe [ stats avg(sum_age) as avg_sum_age ] | appendpipe [ stats max(sum_age) as max_sum_age ] | appendpipe [ stats min(sum_age) as min_sum_age ]" + + - match: { total: 5 } + - match: + schema: + - { name: sum_age, type: bigint } + - { name: gender, type: string } + - { name: avg_sum_age, type: double } + - { name: max_sum_age, type: bigint } + - { name: min_sum_age, type: bigint } + - match: + datarows: + - [ 30, "F", null, null, null ] + - [ 70, "M", null, null, null ] + - [ null, null, 50.0, null, null ] + - [ null, null, null, 70, null ] + - [ null, null, null, null, 30 ] diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendPipeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendPipeTest.java index faf944da4a0..56ed409b4d7 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendPipeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendPipeTest.java @@ -59,4 +59,152 @@ public void testAppendPipeWithMergedColumns() { + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + /** + * Regression test: double appendpipe with different aggregations. Result count (16 = 14 + 1 avg + + * 1 max) is verified in integration tests only because RelRunners.run() creates a new planner + * that conflicts with shared RelNode subtrees — a test framework limitation that does not affect + * the production path. + */ + @Test + public void testDoubleAppendPipe() { + String ppl = + "source=EMP | appendpipe [stats avg(SAL) as avg_sal] | appendpipe [stats max(SAL) as" + + " max_sal]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], avg_sal=[$8], max_sal=[null:DECIMAL(7, 2)])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[null:DECIMAL(11, 6)], max_sal=[$0])\n" + + " LogicalAggregate(group=[{}], max_sal=[MAX($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } + + /** + * Regression test: triple appendpipe with different aggregations. Result count (17 = 14 + 1 avg + + * 1 max + 1 min) is verified in integration tests only — see testDoubleAppendPipe for rationale. + */ + @Test + public void testTripleAppendPipe() { + String ppl = + "source=EMP | appendpipe [stats avg(SAL) as avg_sal] | appendpipe [stats max(SAL) as" + + " max_sal] | appendpipe [stats min(SAL) as min_sal]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], avg_sal=[$8], max_sal=[$9]," + + " min_sal=[null:DECIMAL(7, 2)])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], avg_sal=[$8]," + + " max_sal=[null:DECIMAL(7, 2)])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[null:DECIMAL(11, 6)], max_sal=[$0])\n" + + " LogicalAggregate(group=[{}], max_sal=[MAX($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[null:DECIMAL(11, 6)], max_sal=[null:DECIMAL(7, 2)], min_sal=[$0])\n" + + " LogicalAggregate(group=[{}], min_sal=[MIN($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], avg_sal=[$8]," + + " max_sal=[null:DECIMAL(7, 2)])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[null:DECIMAL(11, 6)], max_sal=[$0])\n" + + " LogicalAggregate(group=[{}], max_sal=[MAX($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3]," + + " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7]," + + " avg_sal=[null:DECIMAL(11, 6)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT]," + + " avg_sal=[$0])\n" + + " LogicalAggregate(group=[{}], avg_sal=[AVG($0)])\n" + + " LogicalProject(SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } + + /** Regression test: double appendpipe with non-aggregation (filter) subpipeline. */ + @Test + public void testDoubleAppendPipeWithFilter() { + String ppl = "source=EMP | appendpipe [where DEPTNO = 20] | appendpipe [where DEPTNO = 30]"; + RelNode root = getRelNode(ppl); + verifyResultCount(root, 25); // 14 original + 5 (dept 20) + 6 (dept 30) + } }