Skip to content

Commit 5392d62

Browse files
authored
[Error Enhancement] Fix NPE on case() with incompatible branch types (#5575)
Signed-off-by: Jialiang Liang <jiallian@amazon.com>
1 parent 4c4166b commit 5392d62

3 files changed

Lines changed: 46 additions & 1 deletion

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ public RexNode visitCast(Cast node, CalcitePlanContext context) {
835835
@Override
836836
public RexNode visitCase(Case node, CalcitePlanContext context) {
837837
List<RexNode> caseOperands = new ArrayList<>();
838+
List<RelDataType> resultTypes = new ArrayList<>();
838839
for (When when : node.getWhenClauses()) {
839840
RexNode condition = analyze(when.getCondition(), context);
840841
if (!SqlTypeUtil.isBoolean(condition.getType())) {
@@ -843,11 +844,22 @@ public RexNode visitCase(Case node, CalcitePlanContext context) {
843844
"Condition expected a boolean type, but got %s", condition.getType()));
844845
}
845846
caseOperands.add(condition);
846-
caseOperands.add(analyze(when.getResult(), context));
847+
RexNode result = analyze(when.getResult(), context);
848+
caseOperands.add(result);
849+
resultTypes.add(result.getType());
847850
}
848851
RexNode elseExpr =
849852
node.getElseClause().map(e -> analyze(e, context)).orElse(context.relBuilder.literal(null));
850853
caseOperands.add(elseExpr);
854+
resultTypes.add(elseExpr.getType());
855+
856+
// Pre-validate the THEN/ELSE result types so an unsupertyped mix surfaces as a clean
857+
// 400 here instead of an opaque NPE deep in Calcite's makeCall return-type inference.
858+
RelDataType commonType = context.rexBuilder.getTypeFactory().leastRestrictive(resultTypes);
859+
if (commonType == null) {
860+
throw new ExpressionEvaluationException(
861+
StringUtils.format("case branches must have a common type, but got %s", resultTypes));
862+
}
851863
return context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
852864
}
853865

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,4 +537,21 @@ public void testNestedCaseAggWithAutoDateHistogram() throws IOException {
537537
schema("flags", "bigint"));
538538
verifyNumOfRows(actual2, 32);
539539
}
540+
541+
/** Case with no common branch supertype must return a clean 4xx, not a 500. */
542+
@Test
543+
public void testCaseWithIncompatibleBranchTypesRejectsCleanly() {
544+
org.opensearch.client.ResponseException e =
545+
org.junit.Assert.assertThrows(
546+
org.opensearch.client.ResponseException.class,
547+
() ->
548+
executeQuery(
549+
String.format(
550+
"source=%s | eval x = case(age > 30, 'old', age > 20, 1 else 0.0) | fields"
551+
+ " x",
552+
TEST_INDEX_BANK)));
553+
org.junit.Assert.assertTrue(
554+
"expected 400 status, got: " + e.getMessage(),
555+
e.getMessage().contains("status line [HTTP/1.1 400"));
556+
}
540557
}

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCaseFunctionTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55

66
package org.opensearch.sql.ppl.calcite;
77

8+
import static org.junit.Assert.assertThrows;
9+
import static org.junit.Assert.assertTrue;
10+
811
import org.apache.calcite.rel.RelNode;
912
import org.apache.calcite.test.CalciteAssert;
1013
import org.junit.Test;
14+
import org.opensearch.sql.exception.ExpressionEvaluationException;
1115

1216
public class CalcitePPLCaseFunctionTest extends CalcitePPLAbstractTest {
1317

@@ -101,4 +105,16 @@ public void testCaseWhenInSubquery() {
101105
+ "FROM `scott`.`EMP`)";
102106
verifyPPLToSparkSQL(root, expectedSparkSql);
103107
}
108+
109+
/** Case branches with no common supertype must be rejected cleanly, not NPE. */
110+
@Test
111+
public void testCaseWithIncompatibleBranchTypesRejectsCleanly() {
112+
String ppl =
113+
"source=EMP | eval x = case(DEPTNO > 20, 'big'," + " DEPTNO > 10, 1 else 0.0) | fields x";
114+
ExpressionEvaluationException e =
115+
assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl));
116+
assertTrue(
117+
"expected message to list incompatible types, got: " + e.getMessage(),
118+
e.getMessage().contains("case branches must have a common type"));
119+
}
104120
}

0 commit comments

Comments
 (0)