Skip to content

Commit 2ddff58

Browse files
committed
fix tests
Signed-off-by: Jialiang Liang <jiallian@amazon.com>
1 parent 259097d commit 2ddff58

5 files changed

Lines changed: 99 additions & 30 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ public RelNode analyze(UnresolvedPlan unresolved, CalcitePlanContext context) {
183183
context.enableFilterAccumulation();
184184
try {
185185
unresolved.accept(this, context);
186-
context.flushFilterConditions(); // Flush accumulated conditions before returning
187-
return context.relBuilder.peek(); // Get the result after flushing
186+
context.flushFilterConditions();
187+
return context.relBuilder.peek();
188188
} finally {
189189
context.disableFilterAccumulation();
190190
}
@@ -193,6 +193,17 @@ public RelNode analyze(UnresolvedPlan unresolved, CalcitePlanContext context) {
193193
}
194194
}
195195

196+
/**
197+
* Flushes accumulated filter conditions before schema-changing operations. This prevents
198+
* RexInputRef index mismatches that occur when filters reference field indices from the old
199+
* schema.
200+
*/
201+
private void flushFiltersBeforeSchemaChange(CalcitePlanContext context) {
202+
if (context.isFilterAccumulationEnabled() && context.hasPendingFilterConditions()) {
203+
context.flushFilterConditions();
204+
}
205+
}
206+
196207
@Override
197208
public RelNode visitRelation(Relation node, CalcitePlanContext context) {
198209
DataSourceSchemaIdentifierNameResolver nameResolver =
@@ -404,10 +415,7 @@ private boolean containsSubqueryExpression(Node expr) {
404415
public RelNode visitProject(Project node, CalcitePlanContext context) {
405416
visitChildren(node, context);
406417

407-
// Flush accumulated filter conditions before schema-changing operations
408-
if (context.isFilterAccumulationEnabled() && context.hasPendingFilterConditions()) {
409-
context.flushFilterConditions();
410-
}
418+
flushFiltersBeforeSchemaChange(context);
411419

412420
if (isSingleAllFieldsProject(node)) {
413421
return handleAllFieldsProject(node, context);
@@ -883,6 +891,9 @@ public RelNode visitPatterns(Patterns node, CalcitePlanContext context) {
883891
@Override
884892
public RelNode visitEval(Eval node, CalcitePlanContext context) {
885893
visitChildren(node, context);
894+
895+
flushFiltersBeforeSchemaChange(context);
896+
886897
node.getExpressionList()
887898
.forEach(
888899
expr -> {
@@ -1152,6 +1163,9 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
11521163
/** Visits an aggregation for stats command */
11531164
@Override
11541165
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
1166+
// Flush accumulated filter conditions before schema-changing aggregation operations
1167+
flushFiltersBeforeSchemaChange(context);
1168+
11551169
Argument.ArgumentMap statsArgs = Argument.ArgumentMap.of(node.getArgExprList());
11561170
Boolean bucketNullable =
11571171
(Boolean) statsArgs.getOrDefault(Argument.BUCKET_NULLABLE, Literal.TRUE).getValue();
@@ -2252,10 +2266,26 @@ private RelNode mergeTableAndResolveColumnConflict(
22522266
@Override
22532267
public RelNode visitMultisearch(Multisearch node, CalcitePlanContext context) {
22542268
List<RelNode> subsearchNodes = new ArrayList<>();
2269+
// Save the current filter accumulation state - we'll process each subsearch independently
2270+
boolean wasFilterAccumulationEnabled = context.isFilterAccumulationEnabled();
2271+
22552272
for (UnresolvedPlan subsearch : node.getSubsearches()) {
22562273
UnresolvedPlan prunedSubSearch = subsearch.accept(new EmptySourcePropagateVisitor(), null);
2257-
prunedSubSearch.accept(this, context);
2274+
2275+
// Temporarily disable filter accumulation so each subsearch gets its own independent
2276+
// lifecycle via analyze(). This prevents filter state from bleeding across branches.
2277+
if (wasFilterAccumulationEnabled) {
2278+
context.disableFilterAccumulation();
2279+
}
2280+
2281+
// Use analyze() to let each subsearch determine its own filter accumulation needs
2282+
analyze(prunedSubSearch, context);
22582283
subsearchNodes.add(context.relBuilder.build());
2284+
2285+
// analyze() disables accumulation in its finally block, so restore the caller's state here
2286+
if (wasFilterAccumulationEnabled) {
2287+
context.enableFilterAccumulation();
2288+
}
22592289
}
22602290

22612291
// Use shared schema merging logic that handles type conflicts via field renaming
@@ -3271,8 +3301,12 @@ private RexNode createOptimizedTransliteration(
32713301
* RelNodes. This is used to detect queries with multiple regex/filter operations that could cause
32723302
* deep Filter RelNode chains and memory exhaustion.
32733303
*
3304+
* <p>Stops counting at schema-changing operations (Aggregation, Eval, Window, StreamWindow) to
3305+
* avoid enabling filter accumulation across schema boundaries, which would cause RexInputRef
3306+
* index mismatches. Filter nodes whose condition contains a function call are not counted.
3307+
*
32743308
* @param plan the UnresolvedPlan to analyze
3275-
* @return the count of filtering operations found
3309+
* @return the count of filtering operations found before the first schema-changing operation
32763310
*/
32773311
private int countFilteringOperations(UnresolvedPlan plan) {
32783312
if (plan == null) {
@@ -3282,8 +3316,25 @@ private int countFilteringOperations(UnresolvedPlan plan) {
32823316
int count = 0;
32833317

32843318
// Count this node if it's a filtering operation
3285-
if (plan instanceof Regex || plan instanceof Filter) {
3319+
// BUT: Don't count Filter nodes that contain function calls, as they can cause
3320+
// type mismatches when accumulated and flushed later
3321+
if (plan instanceof Regex) {
32863322
count = 1;
3323+
} else if (plan instanceof Filter) {
3324+
Filter filterNode = (Filter) plan;
3325+
if (!containsFunctionCall(filterNode.getCondition())) {
3326+
count = 1;
3327+
}
3328+
}
3329+
3330+
// Stop counting at schema-changing operations to prevent accumulation across schema boundaries
3331+
// The boundaries checked here are Aggregation, Eval, Window and StreamWindow; Project is not
3332+
// in this list.
3333+
if (plan instanceof Aggregation
3334+
|| plan instanceof Eval
3335+
|| plan instanceof Window
3336+
|| plan instanceof StreamWindow) {
3337+
return count; // Don't recurse into children beyond schema changes
32873338
}
32883339

32893340
// Recursively count filtering operations in children
@@ -3297,4 +3348,29 @@ private int countFilteringOperations(UnresolvedPlan plan) {
32973348

32983349
return count;
32993350
}
3351+
3352+
/**
3353+
* Checks if an expression contains any function calls. Filter expressions with function calls can
3354+
* cause type mismatches when accumulated and flushed later, so we exclude them from filter
3355+
* accumulation.
3356+
*/
3357+
private boolean containsFunctionCall(UnresolvedExpression expr) {
3358+
if (expr == null) {
3359+
return false;
3360+
}
3361+
3362+
if (expr instanceof org.opensearch.sql.ast.expression.Function) {
3363+
return true;
3364+
}
3365+
3366+
// Check children recursively
3367+
for (Node child : expr.getChild()) {
3368+
if (child instanceof UnresolvedExpression
3369+
&& containsFunctionCall((UnresolvedExpression) child)) {
3370+
return true;
3371+
}
3372+
}
3373+
3374+
return false;
3375+
}
33003376
}

integ-test/src/test/resources/expectedOutput/calcite/explain_filter_push.yaml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ calcite:
22
logical: |
33
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
44
LogicalProject(age=[$8])
5-
LogicalFilter(condition=[>($3, 10000)])
6-
LogicalFilter(condition=[<($8, 40)])
7-
LogicalFilter(condition=[>($8, 30)])
8-
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
5+
LogicalFilter(condition=[AND(SEARCH($8, Sarg[(30..40)]), >($3, 10000))])
6+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
97
physical: |
108
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[balance, age], FILTER->AND(SEARCH($1, Sarg[(30..40)]), >($0, 10000)), PROJECT->[age], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":10000,"timeout":"1m","query":{"bool":{"must":[{"range":{"age":{"from":30.0,"to":40.0,"include_lower":false,"include_upper":false,"boost":1.0}}},{"range":{"balance":{"from":10000,"to":null,"include_lower":false,"include_upper":true,"boost":1.0}}}],"adjust_pure_negative":true,"boost":1.0}},"_source":{"includes":["age"],"excludes":[]}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
calcite:
22
logical: |
33
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4-
LogicalFilter(condition=[<($0, DATE('2018-11-09 00:00:00.000000000':VARCHAR))])
5-
LogicalFilter(condition=[>($0, DATE('2016-12-08 00:00:00.123456789':VARCHAR))])
6-
LogicalProject(yyyy-MM-dd=[$83])
7-
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_date_formats]])
4+
LogicalFilter(condition=[AND(>($0, DATE('2016-12-08 00:00:00.123456789':VARCHAR)), <($0, DATE('2018-11-09 00:00:00.000000000':VARCHAR)))])
5+
LogicalProject(yyyy-MM-dd=[$83])
6+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_date_formats]])
87
physical: |
98
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_date_formats]], PushDownContext=[[PROJECT->[yyyy-MM-dd], FILTER->SEARCH($0, Sarg[('2016-12-08':VARCHAR..'2018-11-09':VARCHAR)]:VARCHAR), LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":10000,"timeout":"1m","query":{"range":{"yyyy-MM-dd":{"from":"2016-12-08","to":"2018-11-09","include_lower":false,"include_upper":false,"boost":1.0}}},"_source":{"includes":["yyyy-MM-dd"],"excludes":[]}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])

integ-test/src/test/resources/expectedOutput/calcite/explain_filter_push_compare_timestamp_string.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ calcite:
22
logical: |
33
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
44
LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12])
5-
LogicalFilter(condition=[<($3, TIMESTAMP('2018-11-09 00:00:00.000000000':VARCHAR))])
6-
LogicalFilter(condition=[>($3, TIMESTAMP('2016-12-08 00:00:00.000000000':VARCHAR))])
7-
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])
5+
LogicalFilter(condition=[AND(>($3, TIMESTAMP('2016-12-08 00:00:00.000000000':VARCHAR)), <($3, TIMESTAMP('2018-11-09 00:00:00.000000000':VARCHAR)))])
6+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])
87
physical: |
98
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[PROJECT->[account_number, firstname, address, birthdate, gender, city, lastname, balance, employer, state, age, email, male], FILTER->SEARCH($3, Sarg[('2016-12-08 00:00:00':VARCHAR..'2018-11-09 00:00:00':VARCHAR)]:VARCHAR), LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":10000,"timeout":"1m","query":{"range":{"birthdate":{"from":"2016-12-08T00:00:00.000Z","to":"2018-11-09T00:00:00.000Z","include_lower":false,"include_upper":false,"format":"date_time","boost":1.0}}},"_source":{"includes":["account_number","firstname","address","birthdate","gender","city","lastname","balance","employer","state","age","email","male"],"excludes":[]}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLMultisearchTest.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,29 +183,26 @@ public void testMultisearchWithStats() {
183183
+ " LogicalAggregate(group=[{0}], count=[COUNT()])\n"
184184
+ " LogicalProject(type=[$8])\n"
185185
+ " LogicalUnion(all=[true])\n"
186-
+ " LogicalFilter(condition=[=($7, 10)])\n"
187-
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
186+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
188187
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], type=['accounting':VARCHAR])\n"
188+
+ " LogicalFilter(condition=[=($7, 10)])\n"
189189
+ " LogicalTableScan(table=[[scott, EMP]])\n"
190-
+ " LogicalFilter(condition=[=($7, 20)])\n"
191-
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
190+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
192191
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], type=['research':VARCHAR])\n"
192+
+ " LogicalFilter(condition=[=($7, 20)])\n"
193193
+ " LogicalTableScan(table=[[scott, EMP]])\n";
194194
verifyLogical(root, expectedLogical);
195195

196-
// SparkSQL reflects Filter above Project due to flush logic
197196
String expectedSparkSql =
198197
"SELECT COUNT(*) `count`, `type`\n"
199-
+ "FROM (SELECT *\n"
200198
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
201199
+ " 'accounting' `type`\n"
202-
+ "FROM `scott`.`EMP`) `t`\n"
200+
+ "FROM `scott`.`EMP`\n"
203201
+ "WHERE `DEPTNO` = 10\n"
204202
+ "UNION ALL\n"
205-
+ "SELECT *\n"
206-
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
203+
+ "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
207204
+ " 'research' `type`\n"
208-
+ "FROM `scott`.`EMP`) `t1`\n"
205+
+ "FROM `scott`.`EMP`\n"
209206
+ "WHERE `DEPTNO` = 20) `t3`\n"
210207
+ "GROUP BY `type`";
211208
verifyPPLToSparkSQL(root, expectedSparkSql);

0 commit comments

Comments
 (0)