Skip to content

Commit a37d949

Browse files
committed
Fix unix test for flatten command
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 70521a5 commit a37d949

3 files changed

Lines changed: 98 additions & 15 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,41 @@ private void tryToRemoveNestedFields(CalcitePlanContext context) {
207207
.map(field -> (RexNode) context.relBuilder.field(field))
208208
.toList();
209209
if (!duplicatedNestedFields.isEmpty()) {
210-
context.relBuilder.projectExcept(duplicatedNestedFields);
210+
// This is a workaround to avoid the bug in Calcite:
211+
// In {@link RelBuilder#project_(Iterable, Iterable, Iterable, boolean, Iterable)},
212+
// the check `RexUtil.isIdentity(nodeList, inputRowType)` will pass when the input
213+
// and the output nodeList refer to the same fields, even if the field name list
214+
// is different. As a result, renaming operation will not be applied. This makes
215+
// the logical plan for the flatten command incorrect, where the operation is
216+
// equivalent to renaming the flattened sub-fields. E.g. emp.name -> name.
217+
forceProjectExcept(context.relBuilder, duplicatedNestedFields);
211218
}
212219
}
213220

221+
/**
222+
* Project except with force.
223+
*
224+
* <p>This method is copied from {@link RelBuilder#projectExcept(Iterable)} and modified with the
225+
* force flag in project set to true. It is subject to future changes in Calcite.
226+
*
227+
* @param relBuilder RelBuilder
228+
* @param expressions Expressions to exclude from the project
229+
*/
230+
private static void forceProjectExcept(RelBuilder relBuilder, Iterable<RexNode> expressions) {
231+
List<RexNode> allExpressions = new ArrayList<>(relBuilder.fields());
232+
Set<RexNode> excludeExpressions = new HashSet<>();
233+
for (RexNode excludeExp : expressions) {
234+
if (!excludeExpressions.add(excludeExp)) {
235+
throw new IllegalArgumentException(
236+
"Input list contains duplicates. Expression " + excludeExp + " exists multiple times.");
237+
}
238+
if (!allExpressions.remove(excludeExp)) {
239+
throw new IllegalArgumentException("Expression " + excludeExp.toString() + " not found.");
240+
}
241+
}
242+
relBuilder.project(allExpressions, ImmutableList.of(), true);
243+
}
244+
214245
/**
215246
* Try to remove metadata fields in two cases:
216247
*

docs/user/ppl/cmd/flatten.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ flatten
1010

1111
Description
1212
===========
13-
From 3.1.0
1413

1514
Use ``flatten`` command to flatten a nested struct / object field into separate
1615
fields in a document.
@@ -23,6 +22,10 @@ Note that ``flatten`` does not work on arrays. Please use ``expand`` command
2322
to expand an array field into multiple rows instead. If the field is an nested
2423
array of structs, only the first element of the array will be flattened.
2524

25+
Version
26+
=======
27+
Since 3.1.0
28+
2629
Syntax
2730
======
2831

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFlattenTest.java

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@
99

1010
import com.google.common.collect.ImmutableList;
1111
import java.util.List;
12+
import lombok.RequiredArgsConstructor;
13+
import org.apache.calcite.DataContext;
1214
import org.apache.calcite.config.CalciteConnectionConfig;
15+
import org.apache.calcite.linq4j.Enumerable;
16+
import org.apache.calcite.linq4j.Linq4j;
1317
import org.apache.calcite.plan.RelTraitDef;
1418
import org.apache.calcite.rel.RelCollations;
1519
import org.apache.calcite.rel.RelNode;
1620
import org.apache.calcite.rel.type.RelDataType;
1721
import org.apache.calcite.rel.type.RelDataTypeFactory;
1822
import org.apache.calcite.rel.type.RelProtoDataType;
23+
import org.apache.calcite.schema.ScannableTable;
1924
import org.apache.calcite.schema.Schema;
2025
import org.apache.calcite.schema.SchemaPlus;
2126
import org.apache.calcite.schema.Statistic;
2227
import org.apache.calcite.schema.Statistics;
23-
import org.apache.calcite.schema.Table;
2428
import org.apache.calcite.sql.SqlCall;
2529
import org.apache.calcite.sql.SqlNode;
2630
import org.apache.calcite.sql.parser.SqlParser;
@@ -32,9 +36,7 @@
3236
import org.junit.Assert;
3337
import org.junit.Test;
3438

35-
/**
36-
* Unit tests for {@code flatten} command in PPL.
37-
*/
39+
/** Unit tests for {@code flatten} command in PPL. */
3840
public class CalcitePPLFlattenTest extends CalcitePPLAbstractTest {
3941
public CalcitePPLFlattenTest() {
4042
super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
@@ -45,7 +47,12 @@ protected Frameworks.ConfigBuilder config(CalciteAssert.SchemaSpec... schemaSpec
4547
final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
4648
final SchemaPlus schema = CalciteAssert.addSchema(rootSchema, schemaSpecs);
4749
// Add an empty table with name DEPT for test purpose
48-
schema.add("DEPT", new TableWithStruct());
50+
ImmutableList<Object[]> rows =
51+
ImmutableList.of(
52+
new Object[] {10, ImmutableList.of(7369, "ALLEN"), "SMITH", 7369},
53+
new Object[] {20, ImmutableList.of(7499, "ALLEN"), "ALLEN", 7499},
54+
new Object[] {30, ImmutableList.of(7521, "WARD"), "WARD", 7521});
55+
schema.add("DEPT", new TableWithStruct(rows));
4956
return Frameworks.newConfigBuilder()
5057
.parserConfig(SqlParser.Config.DEFAULT)
5158
.defaultSchema(schema)
@@ -58,20 +65,54 @@ public void testFlatten() {
5865
String ppl = "source=DEPT | flatten EMP";
5966
RelNode root = getRelNode(ppl);
6067
// Regarded as an identity scan. See RelBuilder#L2801
61-
String expectedLogical = "LogicalTableScan(table=[[scott, DEPT]])\n";
68+
String expectedLogical =
69+
"LogicalProject(DEPTNO=[$0], EMP=[$1], EMPNAME=[$2], EMPNO=[$3])\n"
70+
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
6271
verifyLogical(root, expectedLogical);
63-
String expectedSparkSql = "SELECT *\nFROM `scott`.`DEPT`";
72+
String expectedSparkSql =
73+
"SELECT `DEPTNO`, `EMP`, `EMP.EMPNAME` `EMPNAME`, `EMP.EMPNO` `EMPNO`\nFROM `scott`.`DEPT`";
6474
verifyPPLToSparkSQL(root, expectedSparkSql);
75+
String expectedResult =
76+
"DEPTNO=10; EMP={7369, ALLEN}; EMPNAME=SMITH; EMPNO=7369\n"
77+
+ "DEPTNO=20; EMP={7499, ALLEN}; EMPNAME=ALLEN; EMPNO=7499\n"
78+
+ "DEPTNO=30; EMP={7521, WARD}; EMPNAME=WARD; EMPNO=7521\n";
79+
verifyResult(root, expectedResult);
6580
}
6681

6782
@Test
6883
public void testFlattenWithAliases() {
6984
String ppl = "source=DEPT | flatten EMP as name, number";
7085
RelNode root = getRelNode(ppl);
71-
String expectedLogical = "LogicalTableScan(table=[[scott, DEPT]])\n";
86+
String expectedLogical =
87+
"LogicalProject(DEPTNO=[$0], EMP=[$1], name=[$2], number=[$3])\n"
88+
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
7289
verifyLogical(root, expectedLogical);
73-
String expectedSparkSql = "SELECT *\nFROM `scott`.`DEPT`";
90+
String expectedSparkSql =
91+
"SELECT `DEPTNO`, `EMP`, `EMP.EMPNAME` `name`, `EMP.EMPNO` `number`\nFROM `scott`.`DEPT`";
7492
verifyPPLToSparkSQL(root, expectedSparkSql);
93+
String expectedResult =
94+
"DEPTNO=10; EMP={7369, ALLEN}; name=SMITH; number=7369\n"
95+
+ "DEPTNO=20; EMP={7499, ALLEN}; name=ALLEN; number=7499\n"
96+
+ "DEPTNO=30; EMP={7521, WARD}; name=WARD; number=7521\n";
97+
verifyResult(root, expectedResult);
98+
}
99+
100+
/**
101+
* This validates that the created table is scannable and the nested fields are removed from the
102+
* result.
103+
*/
104+
@Test
105+
public void testProject() {
106+
String ppl = "source=DEPT";
107+
RelNode root = getRelNode(ppl);
108+
String expectedLogical =
109+
"LogicalProject(DEPTNO=[$0], EMP=[$1])\n LogicalTableScan(table=[[scott, DEPT]])\n";
110+
verifyLogical(root, expectedLogical);
111+
String expectedResult =
112+
"DEPTNO=10; EMP={7369, ALLEN}\n"
113+
+ "DEPTNO=20; EMP={7499, ALLEN}\n"
114+
+ "DEPTNO=30; EMP={7521, WARD}\n";
115+
verifyResult(root, expectedResult);
75116
}
76117

77118
@Test
@@ -80,12 +121,15 @@ public void testFlattenWithMismatchedNumberOfAliasesShouldThrow() {
80121
Throwable t = Assert.assertThrows(IllegalArgumentException.class, () -> getRelNode(ppl));
81122
verifyErrorMessageContains(
82123
t,
83-
"The number of aliases has to match the number of flattened fields. Expected 2 (EMP.EMPNO,"
84-
+ " EMP.EMPNAME), got 1 (name)");
124+
"The number of aliases has to match the number of flattened fields. Expected 2"
125+
+ " (EMP.EMPNAME, EMP.EMPNO), got 1 (name)");
85126
}
86127

87128
// There is no existing table with arrays. We create one for test purpose.
88-
public static class TableWithStruct implements Table {
129+
@RequiredArgsConstructor
130+
public static class TableWithStruct implements ScannableTable {
131+
private final ImmutableList<Object[]> rows;
132+
89133
protected final RelProtoDataType protoRowType =
90134
factory ->
91135
factory
@@ -102,10 +146,15 @@ public static class TableWithStruct implements Table {
102146
// E.g. struct emp will always hava emp.empno and emp.empname in its
103147
// logical projection. We add these two fields to simulate this behavior
104148
// in opensearch.
105-
.add("EMP.EMPNO", SqlTypeName.INTEGER)
106149
.add("EMP.EMPNAME", SqlTypeName.VARCHAR)
150+
.add("EMP.EMPNO", SqlTypeName.INTEGER)
107151
.build();
108152

153+
@Override
154+
public Enumerable<@Nullable Object[]> scan(DataContext root) {
155+
return Linq4j.asEnumerable(rows);
156+
}
157+
109158
@Override
110159
public RelDataType getRowType(RelDataTypeFactory typeFactory) {
111160
return protoRowType.apply(typeFactory);

0 commit comments

Comments
 (0)