diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index 44c6d87da1..3c37a11ba5 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -835,6 +835,7 @@ public RexNode visitCast(Cast node, CalcitePlanContext context) { @Override public RexNode visitCase(Case node, CalcitePlanContext context) { List caseOperands = new ArrayList<>(); + List resultTypes = new ArrayList<>(); for (When when : node.getWhenClauses()) { RexNode condition = analyze(when.getCondition(), context); if (!SqlTypeUtil.isBoolean(condition.getType())) { @@ -843,11 +844,22 @@ public RexNode visitCase(Case node, CalcitePlanContext context) { "Condition expected a boolean type, but got %s", condition.getType())); } caseOperands.add(condition); - caseOperands.add(analyze(when.getResult(), context)); + RexNode result = analyze(when.getResult(), context); + caseOperands.add(result); + resultTypes.add(result.getType()); } RexNode elseExpr = node.getElseClause().map(e -> analyze(e, context)).orElse(context.relBuilder.literal(null)); caseOperands.add(elseExpr); + resultTypes.add(elseExpr.getType()); + + // Pre-validate the THEN/ELSE result types so an unsupertyped mix surfaces as a clean + // 400 here instead of an opaque NPE deep in Calcite's makeCall return-type inference. + RelDataType commonType = context.rexBuilder.getTypeFactory().leastRestrictive(resultTypes); + if (commonType == null) { + throw new ExpressionEvaluationException( + StringUtils.format("case branches must have a common type, but got %s", resultTypes)); + } return context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands); } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java index bc9d4388d5..8bba907e2e 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java @@ -537,4 +537,21 @@ public void testNestedCaseAggWithAutoDateHistogram() throws IOException { schema("flags", "bigint")); verifyNumOfRows(actual2, 32); } + + /** Case with no common branch supertype must return a clean 4xx, not a 500. */ + @Test + public void testCaseWithIncompatibleBranchTypesRejectsCleanly() { + org.opensearch.client.ResponseException e = + org.junit.Assert.assertThrows( + org.opensearch.client.ResponseException.class, + () -> + executeQuery( + String.format( + "source=%s | eval x = case(age > 30, 'old', age > 20, 1 else 0.0) | fields" + + " x", + TEST_INDEX_BANK))); + org.junit.Assert.assertTrue( + "expected 400 status, got: " + e.getMessage(), + e.getMessage().contains("status line [HTTP/1.1 400")); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCaseFunctionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCaseFunctionTest.java index 5cc257b0b0..2788f7a9cf 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCaseFunctionTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCaseFunctionTest.java @@ -5,9 +5,13 @@ package org.opensearch.sql.ppl.calcite; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + import org.apache.calcite.rel.RelNode; import org.apache.calcite.test.CalciteAssert; import org.junit.Test; +import org.opensearch.sql.exception.ExpressionEvaluationException; public class CalcitePPLCaseFunctionTest extends CalcitePPLAbstractTest { @@ -101,4 +105,16 @@ public void testCaseWhenInSubquery() { + "FROM `scott`.`EMP`)"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + /** Case branches with no common supertype must be rejected cleanly, not NPE. */ + @Test + public void testCaseWithIncompatibleBranchTypesRejectsCleanly() { + String ppl = + "source=EMP | eval x = case(DEPTNO > 20, 'big'," + " DEPTNO > 10, 1 else 0.0) | fields x"; + ExpressionEvaluationException e = + assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); + assertTrue( + "expected message to list incompatible types, got: " + e.getMessage(), + e.getMessage().contains("case branches must have a common type")); + } }