Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ public RexNode visitCast(Cast node, CalcitePlanContext context) {
@Override
public RexNode visitCase(Case node, CalcitePlanContext context) {
List<RexNode> caseOperands = new ArrayList<>();
List<RelDataType> resultTypes = new ArrayList<>();
for (When when : node.getWhenClauses()) {
RexNode condition = analyze(when.getCondition(), context);
if (!SqlTypeUtil.isBoolean(condition.getType())) {
Expand All @@ -843,11 +844,22 @@ public RexNode visitCase(Case node, CalcitePlanContext context) {
"Condition expected a boolean type, but got %s", condition.getType()));
}
caseOperands.add(condition);
caseOperands.add(analyze(when.getResult(), context));
RexNode result = analyze(when.getResult(), context);
caseOperands.add(result);
resultTypes.add(result.getType());
}
RexNode elseExpr =
node.getElseClause().map(e -> analyze(e, context)).orElse(context.relBuilder.literal(null));
caseOperands.add(elseExpr);
resultTypes.add(elseExpr.getType());

// Pre-validate the THEN/ELSE result types so an unsupertyped mix surfaces as a clean
// 400 here instead of an opaque NPE deep in Calcite's makeCall return-type inference.
RelDataType commonType = context.rexBuilder.getTypeFactory().leastRestrictive(resultTypes);
if (commonType == null) {
throw new ExpressionEvaluationException(
StringUtils.format("case branches must have a common type, but got %s", resultTypes));
}
return context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,21 @@ public void testNestedCaseAggWithAutoDateHistogram() throws IOException {
schema("flags", "bigint"));
verifyNumOfRows(actual2, 32);
}

/** Case with no common branch supertype must return a clean 4xx, not a 500. */
@Test
public void testCaseWithIncompatibleBranchTypesRejectsCleanly() {
org.opensearch.client.ResponseException e =
org.junit.Assert.assertThrows(
org.opensearch.client.ResponseException.class,
() ->
executeQuery(
String.format(
"source=%s | eval x = case(age > 30, 'old', age > 20, 1 else 0.0) | fields"
+ " x",
TEST_INDEX_BANK)));
org.junit.Assert.assertTrue(
"expected 400 status, got: " + e.getMessage(),
e.getMessage().contains("status line [HTTP/1.1 400"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

package org.opensearch.sql.ppl.calcite;

import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;
import org.opensearch.sql.exception.ExpressionEvaluationException;

public class CalcitePPLCaseFunctionTest extends CalcitePPLAbstractTest {

Expand Down Expand Up @@ -101,4 +105,16 @@ public void testCaseWhenInSubquery() {
+ "FROM `scott`.`EMP`)";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

/** Case branches with no common supertype must be rejected cleanly, not NPE. */
@Test
public void testCaseWithIncompatibleBranchTypesRejectsCleanly() {
String ppl =
"source=EMP | eval x = case(DEPTNO > 20, 'big'," + " DEPTNO > 10, 1 else 0.0) | fields x";
ExpressionEvaluationException e =
assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl));
assertTrue(
"expected message to list incompatible types, got: " + e.getMessage(),
e.getMessage().contains("case branches must have a common type"));
}
}
Loading