Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public static Optional<SqlNode> convert(final ExpressionSegment segment) {
return Optional.of(ParameterMarkerExpressionConverter.convert((ParameterMarkerExpressionSegment) segment));
}
if (segment instanceof FunctionSegment) {
return FunctionConverter.convert((FunctionSegment) segment);
return Optional.of(FunctionConverter.convert((FunctionSegment) segment));
}
if (segment instanceof AggregationProjectionSegment) {
return AggregationProjectionConverter.convert((AggregationProjectionSegment) segment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

/**
* Function converter.
Expand All @@ -55,29 +54,29 @@ public final class FunctionConverter {
* @param segment function segment
* @return SQL node
*/
public static Optional<SqlNode> convert(final FunctionSegment segment) {
public static SqlNode convert(final FunctionSegment segment) {
SqlIdentifier functionName = new SqlIdentifier(getQualifiedFunctionNames(segment), SqlParserPos.ZERO);
// TODO optimize sql parse logic for select current_user.
// TODO optimize SQL parse logic for select current_user
if ("CURRENT_USER".equalsIgnoreCase(functionName.getSimple())) {
return Optional.of(functionName);
return functionName;
}
if ("TRIM".equalsIgnoreCase(functionName.getSimple())) {
return Optional.of(TrimFunctionConverter.convert(segment));
return TrimFunctionConverter.convert(segment);
}
if ("OVER".equalsIgnoreCase(functionName.getSimple())) {
return Optional.of(WindowFunctionConverter.convert(segment));
return WindowFunctionConverter.convert(segment);
}
List<SqlOperator> functions = new LinkedList<>();
SqlStdOperatorTable.instance().lookupOperatorOverloads(functionName, null, SqlSyntax.FUNCTION, functions, SqlNameMatchers.withCaseSensitive(false));
if (!functions.isEmpty() && segment.getWindow().isPresent()) {
SqlBasicCall functionCall = new SqlBasicCall(functions.iterator().next(), getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO);
SqlWindow sqlWindow = WindowConverter.convertWindowItem(segment.getWindow().get());
return Optional.of(new SqlBasicCall(SqlStdOperatorTable.OVER, new SqlNode[]{functionCall, sqlWindow}, SqlParserPos.ZERO));
return new SqlBasicCall(SqlStdOperatorTable.OVER, new SqlNode[]{functionCall, sqlWindow}, SqlParserPos.ZERO);
}
return Optional.of(functions.isEmpty()
? new SqlBasicCall(new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION), getFunctionParameters(segment.getParameters()),
SqlParserPos.ZERO)
: new SqlBasicCall(functions.iterator().next(), getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO));
SqlOperator operator = functions.isEmpty()
? new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION)
: functions.iterator().next();
return new SqlBasicCall(operator, getFunctionParameters(segment.getParameters()), SqlParserPos.ZERO);
}

private static List<String> getQualifiedFunctionNames(final FunctionSegment segment) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void assertConvertDelegatesToAllSupportedConverters() {
when(ParameterMarkerExpressionConverter.convert(parameterSegment)).thenReturn(expectedParameterNode);
SqlNode expectedFunctionNode = mock(SqlNode.class);
FunctionSegment functionSegment = new FunctionSegment(0, 0, "func", "func_text");
when(FunctionConverter.convert(functionSegment)).thenReturn(Optional.of(expectedFunctionNode));
when(FunctionConverter.convert(functionSegment)).thenReturn(expectedFunctionNode);
SqlNode expectedAggregationNode = mock(SqlNode.class);
AggregationProjectionSegment aggregationSegment = new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "count(expr)");
when(AggregationProjectionConverter.convert(aggregationSegment)).thenReturn(Optional.of(expectedAggregationNode));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -58,8 +57,7 @@ class FunctionConverterTest {
@Test
void assertConvertReturnsCurrentUserIdentifier() {
FunctionSegment segment = new FunctionSegment(0, 0, "CURRENT_USER", "CURRENT_USER");
SqlIdentifier actual = (SqlIdentifier) FunctionConverter.convert(segment).orElse(null);
assertNotNull(actual);
SqlIdentifier actual = (SqlIdentifier) FunctionConverter.convert(segment);
assertThat(actual.getSimple(), is("CURRENT_USER"));
}

Expand All @@ -68,17 +66,15 @@ void assertConvertDelegatesToTrimFunctionConverter() {
FunctionSegment segment = new FunctionSegment(0, 0, "TRIM", "TRIM");
SqlBasicCall expected = mock(SqlBasicCall.class);
when(TrimFunctionConverter.convert(segment)).thenReturn(expected);
SqlNode actual = FunctionConverter.convert(segment).orElse(null);
assertThat(actual, is(expected));
assertThat(FunctionConverter.convert(segment), is(expected));
}

@Test
void assertConvertDelegatesToWindowFunctionConverter() {
FunctionSegment segment = new FunctionSegment(0, 0, "OVER", "OVER");
SqlBasicCall expected = mock(SqlBasicCall.class);
when(WindowFunctionConverter.convert(segment)).thenReturn(expected);
SqlNode actual = FunctionConverter.convert(segment).orElse(null);
assertThat(actual, is(expected));
assertThat(FunctionConverter.convert(segment), is(expected));
}

@Test
Expand All @@ -92,8 +88,7 @@ void assertConvertResolvedFunctionWithWindow() {
SqlWindow windowNode = mock(SqlWindow.class);
when(ExpressionConverter.convert(param)).thenReturn(Optional.of(paramNode));
when(WindowConverter.convertWindowItem(windowItemSegment)).thenReturn(windowNode);
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null);
assertNotNull(actual);
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment);
assertThat(actual.getOperator(), is(SqlStdOperatorTable.OVER));
SqlBasicCall functionCall = (SqlBasicCall) actual.getOperandList().get(0);
assertThat(functionCall.getOperator().getName(), is("COUNT"));
Expand All @@ -113,8 +108,7 @@ void assertConvertResolvedFunctionWithoutWindowFlattensParameters() {
SqlNode secondNode = mock(SqlNode.class);
when(ExpressionConverter.convert(firstParam)).thenReturn(Optional.of(new SqlNodeList(Arrays.asList(nodeInList, listSecondNode), SqlParserPos.ZERO)));
when(ExpressionConverter.convert(secondParam)).thenReturn(Optional.of(secondNode));
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null);
assertNotNull(actual);
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment);
assertThat(actual.getOperator().getName(), is("SUM"));
assertThat(actual.getOperandList(), is(Arrays.asList(nodeInList, listSecondNode, secondNode)));
}
Expand All @@ -129,8 +123,7 @@ void assertConvertUnresolvedFunctionWithOwner() {
SqlNode paramNode = mock(SqlNode.class);
when(OwnerConverter.convert(owner)).thenReturn(new ArrayList<>());
when(ExpressionConverter.convert(param)).thenReturn(Optional.of(paramNode));
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment).orElse(null);
assertNotNull(actual);
SqlBasicCall actual = (SqlBasicCall) FunctionConverter.convert(segment);
assertThat(actual.getOperator(), instanceOf(SqlUnresolvedFunction.class));
SqlIdentifier functionName = actual.getOperator().getNameAsId();
assertThat(functionName.names, is(Collections.singletonList("custom_func")));
Expand Down