Skip to content

Commit 17a19d3

Browse files
committed
add UT
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 7f9c6ec commit 17a19d3

2 files changed

Lines changed: 181 additions & 3 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.Map;
2121
import java.util.stream.Collectors;
2222
import java.util.stream.IntStream;
23-
import java.util.stream.Stream;
2423
import lombok.RequiredArgsConstructor;
2524
import org.apache.calcite.rel.RelNode;
2625
import org.apache.calcite.rel.type.RelDataType;
@@ -378,7 +377,7 @@ public RexNode visitLet(Let node, CalcitePlanContext context) {
378377
* will map type for each lambda argument by the order of previous argument. Also, the function
379378
* will add these variables to the context so they can pass visitQualifiedName
380379
*/
381-
private CalcitePlanContext prepareLambdaContext(
380+
public CalcitePlanContext prepareLambdaContext(
382381
CalcitePlanContext context,
383382
LambdaFunction node,
384383
List<RexNode> previousArgument,
@@ -424,7 +423,8 @@ private CalcitePlanContext prepareLambdaContext(
424423
private List<RelDataType> modifyLambdaTypeByFunction(
425424
String functionName, List<RelDataType> originalType) {
426425
switch (functionName.toUpperCase(Locale.ROOT)) {
427-
case "REDUCE": // For reduce case, the first type is acc should be any since it is the output of accumulator lambda function
426+
case "REDUCE": // For reduce case, the first type is acc should be any since it is the output
427+
// of accumulator lambda function
428428
if (originalType.size() == 2) {
429429
return List.of(originalType.get(1), originalType.get(0));
430430
} else {
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite;
7+
8+
import static org.junit.jupiter.api.Assertions.*;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.when;
11+
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
12+
13+
import java.sql.Connection;
14+
import java.util.List;
15+
import org.apache.calcite.rel.type.RelDataType;
16+
import org.apache.calcite.rex.RexNode;
17+
import org.apache.calcite.sql.type.ArraySqlType;
18+
import org.apache.calcite.sql.type.SqlTypeName;
19+
import org.apache.calcite.tools.FrameworkConfig;
20+
import org.apache.calcite.tools.RelBuilder;
21+
import org.junit.jupiter.api.AfterEach;
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.MockedStatic;
27+
import org.mockito.Mockito;
28+
import org.mockito.junit.jupiter.MockitoExtension;
29+
import org.opensearch.sql.ast.expression.LambdaFunction;
30+
import org.opensearch.sql.ast.expression.QualifiedName;
31+
import org.opensearch.sql.calcite.utils.CalciteToolsHelper;
32+
import org.opensearch.sql.executor.QueryType;
33+
34+
@ExtendWith(MockitoExtension.class)
35+
public class CalciteRexNodeVisitorTest {
36+
@Mock LambdaFunction lambdaFunction;
37+
@Mock RexNode arrayArg;
38+
@Mock RexNode extraArg;
39+
@Mock RexNode accArg;
40+
;
41+
@Mock ArraySqlType arraySqlType;
42+
@Mock RelDataType componentType;
43+
@Mock RelDataType extraType;
44+
@Mock RelDataType accType;
45+
@Mock QualifiedName functionArg1;
46+
@Mock QualifiedName functionArg2;
47+
48+
static CalciteRexNodeVisitor visitor;
49+
static CalciteRelNodeVisitor relNodeVisitor;
50+
51+
@Mock static FrameworkConfig frameworkConfig;
52+
@Mock static Connection connection;
53+
@Mock static RelBuilder relBuilder;
54+
@Mock static ExtendedRexBuilder rexBuilder;
55+
static CalcitePlanContext context;
56+
MockedStatic<CalciteToolsHelper> mockedStatic;
57+
58+
@BeforeEach
59+
public void setUpContext() {
60+
relNodeVisitor = new CalciteRelNodeVisitor();
61+
visitor = new CalciteRexNodeVisitor(relNodeVisitor);
62+
when(relBuilder.getRexBuilder()).thenReturn(rexBuilder);
63+
when(rexBuilder.getTypeFactory()).thenReturn(TYPE_FACTORY);
64+
mockedStatic = Mockito.mockStatic(CalciteToolsHelper.class);
65+
mockedStatic.when(() -> CalciteToolsHelper.connect(any(), any())).thenReturn(connection);
66+
67+
mockedStatic.when(() -> CalciteToolsHelper.create(any(), any(), any())).thenReturn(relBuilder);
68+
69+
context = CalcitePlanContext.create(frameworkConfig, 100, QueryType.PPL);
70+
}
71+
72+
@AfterEach
73+
public void tearDown() {
74+
mockedStatic.close();
75+
}
76+
77+
@Test
78+
public void testPrepareLambdaForBasicLambda() {
79+
when(componentType.getSqlTypeName()).thenReturn(SqlTypeName.DOUBLE);
80+
when(arrayArg.getType()).thenReturn(arraySqlType);
81+
when(arraySqlType.getComponentType()).thenReturn(componentType);
82+
83+
List<RexNode> previousArguments = List.of(arrayArg);
84+
when(functionArg1.toString()).thenReturn("arg1");
85+
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1));
86+
87+
CalcitePlanContext lambdaContext =
88+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "forall");
89+
90+
assertNotNull(lambdaContext);
91+
assertNotNull(lambdaContext.getRexLambdaRefMap());
92+
assertEquals(1, lambdaContext.getRexLambdaRefMap().size());
93+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("arg1"));
94+
assertEquals(
95+
lambdaContext.getRexLambdaRefMap().get("arg1").getType().getSqlTypeName(),
96+
SqlTypeName.DOUBLE);
97+
}
98+
99+
@Test
100+
public void testPrepareLambdaForTransform() {
101+
when(componentType.getSqlTypeName()).thenReturn(SqlTypeName.DOUBLE);
102+
when(arrayArg.getType()).thenReturn(arraySqlType);
103+
when(arraySqlType.getComponentType()).thenReturn(componentType);
104+
105+
List<RexNode> previousArguments = List.of(arrayArg);
106+
when(functionArg1.toString()).thenReturn("arg1");
107+
when(functionArg2.toString()).thenReturn("i");
108+
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1, functionArg2));
109+
110+
CalcitePlanContext lambdaContext =
111+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "transform");
112+
113+
assertNotNull(lambdaContext);
114+
assertNotNull(lambdaContext.getRexLambdaRefMap());
115+
assertEquals(2, lambdaContext.getRexLambdaRefMap().size());
116+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("arg1"));
117+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("i"));
118+
assertEquals(
119+
lambdaContext.getRexLambdaRefMap().get("arg1").getType().getSqlTypeName(),
120+
SqlTypeName.DOUBLE);
121+
assertEquals(
122+
lambdaContext.getRexLambdaRefMap().get("i").getType().getSqlTypeName(),
123+
SqlTypeName.INTEGER);
124+
}
125+
126+
@Test
127+
public void testPrepareLambdaForReduce() {
128+
when(componentType.getSqlTypeName()).thenReturn(SqlTypeName.DOUBLE);
129+
when(arrayArg.getType()).thenReturn(arraySqlType);
130+
when(arraySqlType.getComponentType()).thenReturn(componentType);
131+
when(extraArg.getType()).thenReturn(extraType);
132+
when(extraType.getSqlTypeName()).thenReturn(SqlTypeName.VARCHAR);
133+
134+
List<RexNode> previousArguments = List.of(arrayArg, extraArg);
135+
when(functionArg1.toString()).thenReturn("acc");
136+
when(functionArg2.toString()).thenReturn("arg1");
137+
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1, functionArg2));
138+
139+
CalcitePlanContext lambdaContext =
140+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce");
141+
142+
assertNotNull(lambdaContext);
143+
assertNotNull(lambdaContext.getRexLambdaRefMap());
144+
assertEquals(2, lambdaContext.getRexLambdaRefMap().size());
145+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("arg1"));
146+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("acc"));
147+
assertEquals(
148+
lambdaContext.getRexLambdaRefMap().get("arg1").getType().getSqlTypeName(),
149+
SqlTypeName.DOUBLE);
150+
assertEquals(
151+
lambdaContext.getRexLambdaRefMap().get("acc").getType().getSqlTypeName(),
152+
SqlTypeName.VARCHAR);
153+
}
154+
155+
@Test
156+
public void testPrepareLambdaForReduceFinalizerFunction() {
157+
when(arrayArg.getType()).thenReturn(arraySqlType);
158+
when(arraySqlType.getComponentType()).thenReturn(componentType);
159+
when(extraArg.getType()).thenReturn(extraType);
160+
when(accArg.getType()).thenReturn(accType);
161+
when(accType.getSqlTypeName()).thenReturn(SqlTypeName.FLOAT);
162+
163+
List<RexNode> previousArguments = List.of(arrayArg, extraArg, accArg);
164+
when(functionArg1.toString()).thenReturn("acc");
165+
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1));
166+
167+
CalcitePlanContext lambdaContext =
168+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce");
169+
170+
assertNotNull(lambdaContext);
171+
assertNotNull(lambdaContext.getRexLambdaRefMap());
172+
assertEquals(1, lambdaContext.getRexLambdaRefMap().size());
173+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("acc"));
174+
assertEquals(
175+
lambdaContext.getRexLambdaRefMap().get("acc").getType().getSqlTypeName(),
176+
SqlTypeName.FLOAT);
177+
}
178+
}

0 commit comments

Comments
 (0)