|
5 | 5 |
|
6 | 6 | package org.opensearch.sql.api.spec.search; |
7 | 7 |
|
| 8 | +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR; |
| 9 | +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST; |
| 10 | +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR; |
| 11 | +import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR; |
| 12 | + |
| 13 | +import java.util.ArrayList; |
8 | 14 | import java.util.List; |
9 | 15 | import lombok.AccessLevel; |
10 | 16 | import lombok.NoArgsConstructor; |
| 17 | +import org.apache.calcite.sql.SqlBasicTypeNameSpec; |
11 | 18 | import org.apache.calcite.sql.SqlCall; |
| 19 | +import org.apache.calcite.sql.SqlDataTypeSpec; |
12 | 20 | import org.apache.calcite.sql.SqlIdentifier; |
13 | 21 | import org.apache.calcite.sql.SqlKind; |
14 | 22 | import org.apache.calcite.sql.SqlLiteral; |
15 | 23 | import org.apache.calcite.sql.SqlNode; |
16 | | -import org.apache.calcite.sql.fun.SqlStdOperatorTable; |
17 | 24 | import org.apache.calcite.sql.parser.SqlParserPos; |
| 25 | +import org.apache.calcite.sql.type.SqlTypeName; |
18 | 26 | import org.apache.calcite.sql.util.SqlShuttle; |
19 | 27 | import org.checkerframework.checker.nullness.qual.Nullable; |
20 | 28 | import org.opensearch.sql.api.spec.UnifiedFunctionSpec; |
@@ -42,34 +50,69 @@ public final class NamedArgRewriter extends SqlShuttle { |
42 | 50 |
|
43 | 51 | /** |
44 | 52 | * Rewrites each argument into a MAP entry. For match(name, 'John', operator='AND'): |
45 | | - * <li>Positional arg: name → MAP('field', name) |
46 | 53 | * <li>Named arg: operator='AND' → MAP('operator', 'AND') |
| 54 | + * <li>Positional arg: name → MAP('field', name) |
| 55 | + * <li>ARRAY arg: ARRAY['f1','f2'] → MAP('fields', MAP(CAST('f1' AS VARCHAR), 1, ...)) |
47 | 56 | */ |
48 | 57 | private static SqlCall rewriteToMaps(SqlCall call, List<String> paramNames) { |
49 | 58 | List<SqlNode> operands = call.getOperandList(); |
50 | 59 | SqlNode[] maps = new SqlNode[operands.size()]; |
51 | 60 | for (int i = 0; i < operands.size(); i++) { |
52 | 61 | SqlNode op = operands.get(i); |
53 | | - if (op instanceof SqlCall eq && op.getKind() == SqlKind.EQUALS) { |
54 | | - SqlNode key = eq.operand(0); |
55 | | - String name = |
56 | | - key instanceof SqlIdentifier ident |
57 | | - ? ident.getSimple() |
58 | | - : key.toString(); // avoid backtick-decorated keys for reserved words |
59 | | - maps[i] = toMap(name, eq.operand(1)); |
60 | | - } else { |
| 62 | + if (isNamedArg(op)) { |
| 63 | + maps[i] = namedArgToMap((SqlCall) op); |
| 64 | + } else { // Positional arg |
61 | 65 | if (i >= paramNames.size()) { |
62 | 66 | throw new IllegalArgumentException( |
63 | 67 | String.format("Invalid arguments for function '%s'", call.getOperator().getName())); |
| 68 | + } else if (isArrayArg(op)) { |
| 69 | + maps[i] = map(paramNames.get(i), arrayArgToMap((SqlCall) op)); |
| 70 | + } else { |
| 71 | + maps[i] = map(paramNames.get(i), op); |
64 | 72 | } |
65 | | - maps[i] = toMap(paramNames.get(i), op); |
66 | 73 | } |
67 | 74 | } |
68 | 75 | return call.getOperator().createCall(call.getParserPosition(), maps); |
69 | 76 | } |
70 | 77 |
|
71 | | - private static SqlNode toMap(String key, SqlNode value) { |
72 | | - return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall( |
| 78 | + private static boolean isNamedArg(SqlNode node) { |
| 79 | + return node instanceof SqlCall && node.getKind() == SqlKind.EQUALS; |
| 80 | + } |
| 81 | + |
| 82 | + private static boolean isArrayArg(SqlNode node) { |
| 83 | + return node instanceof SqlCall call && call.getOperator() == ARRAY_VALUE_CONSTRUCTOR; |
| 84 | + } |
| 85 | + |
| 86 | + private static SqlNode namedArgToMap(SqlCall eq) { |
| 87 | + SqlNode key = eq.operand(0); |
| 88 | + String name = |
| 89 | + key instanceof SqlIdentifier ident |
| 90 | + ? ident.getSimple() |
| 91 | + : key.toString(); // avoid backtick-decorated keys for reserved words |
| 92 | + return map(name, eq.operand(1)); |
| 93 | + } |
| 94 | + |
| 95 | + private static SqlNode arrayArgToMap(SqlCall arrayCall) { |
| 96 | + List<SqlNode> mapArgs = new ArrayList<>(); |
| 97 | + for (SqlNode element : arrayCall.getOperandList()) { |
| 98 | + mapArgs.add(cast(element, VARCHAR)); |
| 99 | + mapArgs.add(SqlLiteral.createApproxNumeric("1.0", SqlParserPos.ZERO)); |
| 100 | + } |
| 101 | + return map(mapArgs); |
| 102 | + } |
| 103 | + |
| 104 | + private static SqlNode cast(SqlNode node, SqlTypeName type) { |
| 105 | + SqlDataTypeSpec typeSpec = |
| 106 | + new SqlDataTypeSpec(new SqlBasicTypeNameSpec(type, SqlParserPos.ZERO), SqlParserPos.ZERO); |
| 107 | + return CAST.createCall(SqlParserPos.ZERO, node, typeSpec); |
| 108 | + } |
| 109 | + |
| 110 | + private static SqlNode map(String key, SqlNode value) { |
| 111 | + return MAP_VALUE_CONSTRUCTOR.createCall( |
73 | 112 | SqlParserPos.ZERO, SqlLiteral.createCharString(key, SqlParserPos.ZERO), value); |
74 | 113 | } |
| 114 | + |
| 115 | + private static SqlNode map(List<SqlNode> kvPairs) { |
| 116 | + return MAP_VALUE_CONSTRUCTOR.createCall(SqlParserPos.ZERO, kvPairs.toArray(SqlNode[]::new)); |
| 117 | + } |
75 | 118 | } |
0 commit comments