From 2eca86daca37b8f75eab3d4a4c9b6b0095fdb01c Mon Sep 17 00:00:00 2001 From: zhangliang Date: Fri, 12 Dec 2025 19:23:52 +0800 Subject: [PATCH] Refactor FunctionConverter --- .../expression/ExpressionConverter.java | 2 +- .../expression/impl/FunctionConverter.java | 21 +++++++++---------- .../expression/ExpressionConverterTest.java | 2 +- .../impl/FunctionConverterTest.java | 19 ++++++----------- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverter.java b/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverter.java index 0d07bfe97708d..a2329f481b778 100644 --- a/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverter.java +++ b/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverter.java @@ -121,7 +121,7 @@ public static Optional convert(final ExpressionSegment segment) { return Optional.of(ParameterMarkerExpressionConverter.convert((ParameterMarkerExpressionSegment) segment)); } if (segment instanceof FunctionSegment) { - return FunctionConverter.convert((FunctionSegment) segment); + return Optional.of(FunctionConverter.convert((FunctionSegment) segment)); } if (segment instanceof AggregationProjectionSegment) { return AggregationProjectionConverter.convert((AggregationProjectionSegment) segment); diff --git a/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverter.java b/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverter.java index f7e4478641b92..1a69b1faad197 100644 --- a/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverter.java +++ b/kernel/sql-federation/compiler/src/main/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverter.java @@ -41,7 +41,6 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; -import java.util.Optional; /** * Function converter. @@ -55,29 +54,29 @@ public final class FunctionConverter { * @param segment function segment * @return SQL node */ - public static Optional convert(final FunctionSegment segment) { + public static SqlNode convert(final FunctionSegment segment) { SqlIdentifier functionName = new SqlIdentifier(getQualifiedFunctionNames(segment), SqlParserPos.ZERO); - // TODO optimize sql parse logic for select current_user. + // TODO optimize SQL parse logic for select current_user if ("CURRENT_USER".equalsIgnoreCase(functionName.getSimple())) { - return Optional.of(functionName); + return functionName; } if ("TRIM".equalsIgnoreCase(functionName.getSimple())) { - return Optional.of(TrimFunctionConverter.convert(segment)); + return TrimFunctionConverter.convert(segment); } if ("OVER".equalsIgnoreCase(functionName.getSimple())) { - return Optional.of(WindowFunctionConverter.convert(segment)); + return WindowFunctionConverter.convert(segment); } List functions = new LinkedList<>(); SqlStdOperatorTable.instance().lookupOperatorOverloads(functionName, null, SqlSyntax.FUNCTION, functions, SqlNameMatchers.withCaseSensitive(false)); if (!functions.isEmpty() && segment.getWindow().isPresent()) { SqlBasicCall functionCall = new SqlBasicCall(functions.iterator().next(), getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO); SqlWindow sqlWindow = WindowConverter.convertWindowItem(segment.getWindow().get()); - return Optional.of(new SqlBasicCall(SqlStdOperatorTable.OVER, new SqlNode[]{functionCall, sqlWindow}, SqlParserPos.ZERO)); + return new SqlBasicCall(SqlStdOperatorTable.OVER, new SqlNode[]{functionCall, sqlWindow}, SqlParserPos.ZERO); } - return Optional.of(functions.isEmpty() - ? new SqlBasicCall(new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION), getFunctionParameters(segment.getParameters()), - SqlParserPos.ZERO) - : new SqlBasicCall(functions.iterator().next(), getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO)); + SqlOperator operator = functions.isEmpty() + ? new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION) + : functions.iterator().next(); + return new SqlBasicCall(operator, getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO); } private static List getQualifiedFunctionNames(final FunctionSegment segment) { diff --git a/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverterTest.java b/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverterTest.java index c52a8d468abd6..7fce651ac09bd 100644 --- a/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverterTest.java +++ b/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/ExpressionConverterTest.java @@ -142,7 +142,7 @@ void assertConvertDelegatesToAllSupportedConverters() { when(ParameterMarkerExpressionConverter.convert(parameterSegment)).thenReturn(expectedParameterNode); SqlNode expectedFunctionNode = mock(SqlNode.class); FunctionSegment functionSegment = new FunctionSegment(0, 0, "func", "func_text"); - when(FunctionConverter.convert(functionSegment)).thenReturn(Optional.of(expectedFunctionNode)); + when(FunctionConverter.convert(functionSegment)).thenReturn(expectedFunctionNode); SqlNode expectedAggregationNode = mock(SqlNode.class); AggregationProjectionSegment aggregationSegment = new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "count(expr)"); when(AggregationProjectionConverter.convert(aggregationSegment)).thenReturn(Optional.of(expectedAggregationNode)); diff --git a/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverterTest.java b/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverterTest.java index 39ef00d5efbc9..53b2e76612ebd 100644 --- a/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverterTest.java +++ b/kernel/sql-federation/compiler/src/test/java/org/apache/shardingsphere/sqlfederation/compiler/sql/ast/converter/segment/expression/impl/FunctionConverterTest.java @@ -47,7 +47,6 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -58,8 +57,7 @@ class FunctionConverterTest { @Test void assertConvertReturnsCurrentUserIdentifier() { FunctionSegment segment = new FunctionSegment(0, 0, "CURRENT_USER", "CURRENT_USER"); - SqlIdentifier actual = (SqlIdentifier) FunctionConverter.convert(segment).orElse(null); - assertNotNull(actual); + SqlIdentifier actual = (SqlIdentifier) FunctionConverter.convert(segment); assertThat(actual.getSimple(), is("CURRENT_USER")); } @@ -68,8 +66,7 @@ void assertConvertDelegatesToTrimFunctionConverter() { FunctionSegment segment = new FunctionSegment(0, 0, "TRIM", "TRIM"); SqlBasicCall expected = mock(SqlBasicCall.class); when(TrimFunctionConverter.convert(segment)).thenReturn(expected); - SqlNode actual = FunctionConverter.convert(segment).orElse(null); - assertThat(actual, is(expected)); + assertThat(FunctionConverter.convert(segment), is(expected)); } @Test @@ -77,8 +74,7 @@ void assertConvertDelegatesToWindowFunctionConverter() { FunctionSegment segment = new FunctionSegment(0, 0, "OVER", "OVER"); SqlBasicCall expected = mock(SqlBasicCall.class); when(WindowFunctionConverter.convert(segment)).thenReturn(expected); - SqlNode actual = FunctionConverter.convert(segment).orElse(null); - assertThat(actual, is(expected)); + assertThat(FunctionConverter.convert(segment), is(expected)); } @Test @@ -92,8 +88,7 @@ void assertConvertResolvedFunctionWithWindow() { SqlWindow windowNode = mock(SqlWindow.class); when(ExpressionConverter.convert(param)).thenReturn(Optional.of(paramNode)); when(WindowConverter.convertWindowItem(windowItemSegment)).thenReturn(windowNode); - SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null); - assertNotNull(actual); + SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment); assertThat(actual.getOperator(), is(SqlStdOperatorTable.OVER)); SqlBasicCall functionCall = (SqlBasicCall) actual.getOperandList().get(0); assertThat(functionCall.getOperator().getName(), is("COUNT")); @@ -113,8 +108,7 @@ void assertConvertResolvedFunctionWithoutWindowFlattensParameters() { SqlNode secondNode = mock(SqlNode.class); when(ExpressionConverter.convert(firstParam)).thenReturn(Optional.of(new SqlNodeList(Arrays.asList(nodeInList, listSecondNode), SqlParserPos.ZERO))); when(ExpressionConverter.convert(secondParam)).thenReturn(Optional.of(secondNode)); - SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null); - assertNotNull(actual); + SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment); assertThat(actual.getOperator().getName(), is("SUM")); assertThat(actual.getOperandList(), is(Arrays.asList(nodeInList, listSecondNode, secondNode))); } @@ -129,8 +123,7 @@ void assertConvertUnresolvedFunctionWithOwner() { SqlNode paramNode = mock(SqlNode.class); when(OwnerConverter.convert(owner)).thenReturn(new ArrayList<>()); when(ExpressionConverter.convert(param)).thenReturn(Optional.of(paramNode)); - SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null); - assertNotNull(actual); + SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment); assertThat(actual.getOperator(), instanceOf(SqlUnresolvedFunction.class)); SqlIdentifier functionName = actual.getOperator().getNameAsId(); assertThat(functionName.names, is(Collections.singletonList("custom_func")));