Skip to content

Commit 485c038

Browse files
committed
Merge remote-tracking branch 'origin/main' into issues/4636 (1929/2069)
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
2 parents 0667fa7 + 11727a4 commit 485c038

18 files changed

Lines changed: 1185 additions & 12 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import java.sql.Connection;
1111
import java.sql.SQLException;
12+
import java.util.ArrayList;
1213
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
@@ -70,6 +71,16 @@ public class CalcitePlanContext {
7071

7172
@Getter public Map<String, RexLambdaRef> rexLambdaRefMap;
7273

74+
/**
75+
* List of captured variables from outer scope for lambda functions. When a lambda body references
76+
* a field that is not a lambda parameter, it gets captured and stored here. The captured
77+
* variables are passed as additional arguments to the transform function.
78+
*/
79+
@Getter private List<RexNode> capturedVariables;
80+
81+
/** Whether we're currently inside a lambda context. */
82+
@Getter @Setter private boolean inLambdaContext = false;
83+
7384
/**
7485
* -- SETTER -- Sets the SQL operator table provider. This must be called during initialization by
7586
* the opensearch module.
@@ -90,6 +101,24 @@ private CalcitePlanContext(FrameworkConfig config, SysLimit sysLimit, QueryType
90101
this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
91102
this.functionProperties = new FunctionProperties(QueryType.PPL);
92103
this.rexLambdaRefMap = new HashMap<>();
104+
this.capturedVariables = new ArrayList<>();
105+
}
106+
107+
/**
108+
* Private constructor for creating a context that shares relBuilder with parent. Used by clone()
109+
* to create lambda contexts that can resolve fields from the parent context.
110+
*/
111+
private CalcitePlanContext(CalcitePlanContext parent) {
112+
this.config = parent.config;
113+
this.sysLimit = parent.sysLimit;
114+
this.queryType = parent.queryType;
115+
this.connection = parent.connection;
116+
this.relBuilder = parent.relBuilder; // Share the same relBuilder
117+
this.rexBuilder = parent.rexBuilder; // Share the same rexBuilder
118+
this.functionProperties = parent.functionProperties;
119+
this.rexLambdaRefMap = new HashMap<>(); // New map for lambda variables
120+
this.capturedVariables = new ArrayList<>(); // New list for captured variables
121+
this.inLambdaContext = true; // Mark that we're inside a lambda
93122
}
94123

95124
/**
@@ -157,8 +186,13 @@ public Optional<RexCorrelVariable> peekCorrelVar() {
157186
}
158187
}
159188

189+
/**
190+
* Creates a clone of this context that shares the relBuilder with the parent. This allows lambda
191+
* expressions to reference fields from the current row while having their own lambda variable
192+
* mappings.
193+
*/
160194
public CalcitePlanContext clone() {
161-
return new CalcitePlanContext(config, sysLimit, queryType);
195+
return new CalcitePlanContext(this);
162196
}
163197

164198
public static CalcitePlanContext create(
@@ -190,4 +224,42 @@ public static boolean isLegacyPreferred() {
190224
public void putRexLambdaRefMap(Map<String, RexLambdaRef> candidateMap) {
191225
this.rexLambdaRefMap.putAll(candidateMap);
192226
}
227+
228+
/**
229+
* Captures an external variable for use inside a lambda. Returns a RexLambdaRef that references
230+
* the captured variable by its index in the captured variables list. The actual RexNode value is
231+
* stored in capturedVariables and will be passed as additional arguments to the transform
232+
* function.
233+
*
234+
* @param fieldRef The RexInputRef representing the external field
235+
* @param fieldName The name of the field being captured
236+
* @return A RexLambdaRef that can be used inside the lambda to reference the captured value
237+
*/
238+
public RexLambdaRef captureVariable(RexNode fieldRef, String fieldName) {
239+
// Check if this variable is already captured
240+
for (int i = 0; i < capturedVariables.size(); i++) {
241+
if (capturedVariables.get(i).equals(fieldRef)) {
242+
// Return existing reference - offset by number of lambda params (1 for array element)
243+
return rexLambdaRefMap.get("__captured_" + i);
244+
}
245+
}
246+
247+
// Add to captured variables list
248+
int captureIndex = capturedVariables.size();
249+
capturedVariables.add(fieldRef);
250+
251+
// Create a lambda ref for this captured variable
252+
// The index is offset by the number of lambda parameters (1 for single-param lambda)
253+
// Count only actual lambda parameters, not captured variables
254+
int lambdaParamCount =
255+
(int)
256+
rexLambdaRefMap.keySet().stream().filter(key -> !key.startsWith("__captured_")).count();
257+
RexLambdaRef lambdaRef =
258+
new RexLambdaRef(lambdaParamCount + captureIndex, fieldName, fieldRef.getType());
259+
260+
// Store it so we can find it again if the same field is referenced multiple times
261+
rexLambdaRefMap.put("__captured_" + captureIndex, lambdaRef);
262+
263+
return lambdaRef;
264+
}
193265
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,20 @@ public RexNode visitLambdaFunction(LambdaFunction node, CalcitePlanContext conte
297297
TYPE_FACTORY.createSqlType(SqlTypeName.ANY))))
298298
.collect(Collectors.toList());
299299
RexNode body = node.getFunction().accept(this, context);
300+
301+
// Add captured variables as additional lambda parameters
302+
// They are stored with keys like "__captured_0", "__captured_1", etc.
303+
List<RexNode> capturedVars = context.getCapturedVariables();
304+
if (capturedVars != null && !capturedVars.isEmpty()) {
305+
args = new ArrayList<>(args);
306+
for (int i = 0; i < capturedVars.size(); i++) {
307+
RexLambdaRef capturedRef = context.getRexLambdaRefMap().get("__captured_" + i);
308+
if (capturedRef != null) {
309+
args.add(capturedRef);
310+
}
311+
}
312+
}
313+
300314
RexNode lambdaNode = context.rexBuilder.makeLambdaCall(body, args);
301315
return lambdaNode;
302316
} catch (Exception e) {
@@ -390,6 +404,7 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
390404
context.setInCoalesceFunction(true);
391405
}
392406

407+
List<RexNode> capturedVars = null;
393408
try {
394409
for (UnresolvedExpression arg : args) {
395410
if (arg instanceof LambdaFunction) {
@@ -408,6 +423,8 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
408423
lambdaNode = analyze(arg, lambdaContext);
409424
}
410425
arguments.add(lambdaNode);
426+
// Capture any external variables that were referenced in the lambda
427+
capturedVars = lambdaContext.getCapturedVariables();
411428
} else {
412429
arguments.add(analyze(arg, context));
413430
}
@@ -418,6 +435,15 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
418435
}
419436
}
420437

438+
// For transform/mvmap functions with captured variables, add them as additional arguments
439+
if (capturedVars != null && !capturedVars.isEmpty()) {
440+
if (node.getFuncName().equalsIgnoreCase("mvmap")
441+
|| node.getFuncName().equalsIgnoreCase("transform")) {
442+
arguments = new ArrayList<>(arguments);
443+
arguments.addAll(capturedVars);
444+
}
445+
}
446+
421447
if ("LIKE".equalsIgnoreCase(node.getFuncName()) && arguments.size() == 2) {
422448
RexNode defaultCaseSensitive =
423449
CalcitePlanContext.isLegacyPreferred()

core/src/main/java/org/opensearch/sql/calcite/QualifiedNameResolver.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,30 @@ private static RexNode resolveInNonJoinCondition(
6464
QualifiedName nameNode, CalcitePlanContext context) {
6565
log.debug("resolveInNonJoinCondition() called with nameNode={}", nameNode);
6666

67-
return resolveLambdaVariable(nameNode, context)
68-
.or(() -> resolveFieldDirectly(nameNode, context, 1))
69-
.or(() -> resolveFieldWithAlias(nameNode, context, 1))
70-
.or(() -> resolveFieldWithoutAlias(nameNode, context, 1))
71-
.or(() -> resolveRenamedField(nameNode, context))
72-
.or(() -> resolveCorrelationField(nameNode, context))
67+
// First try to resolve as lambda variable
68+
Optional<RexNode> lambdaVar = resolveLambdaVariable(nameNode, context);
69+
if (lambdaVar.isPresent()) {
70+
return lambdaVar.get();
71+
}
72+
73+
// Try to resolve as regular field
74+
Optional<RexNode> fieldRef =
75+
resolveFieldDirectly(nameNode, context, 1)
76+
.or(() -> resolveFieldWithAlias(nameNode, context, 1))
77+
.or(() -> resolveFieldWithoutAlias(nameNode, context, 1))
78+
.or(() -> resolveRenamedField(nameNode, context));
79+
80+
if (fieldRef.isPresent()) {
81+
// If we're in a lambda context and this is not a lambda variable,
82+
// we need to capture it as an external variable
83+
if (context.isInLambdaContext()) {
84+
log.debug("Capturing external field {} in lambda context", nameNode);
85+
return context.captureVariable(fieldRef.get(), nameNode.toString());
86+
}
87+
return fieldRef.get();
88+
}
89+
90+
return resolveCorrelationField(nameNode, context)
7391
.or(() -> replaceWithNullLiteralInCoalesce(context))
7492
.orElseThrow(() -> getNotFoundException(nameNode));
7593
}

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@ public enum BuiltinFunctionName {
7575
MVAPPEND(FunctionName.of("mvappend")),
7676
MVJOIN(FunctionName.of("mvjoin")),
7777
MVINDEX(FunctionName.of("mvindex")),
78+
MVFIND(FunctionName.of("mvfind")),
7879
MVZIP(FunctionName.of("mvzip")),
7980
SPLIT(FunctionName.of("split")),
8081
MVDEDUP(FunctionName.of("mvdedup")),
82+
MVMAP(FunctionName.of("mvmap")),
8183
FORALL(FunctionName.of("forall")),
8284
EXISTS(FunctionName.of("exists")),
8385
FILTER(FunctionName.of("filter")),
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.expression.function.CollectionUDF;
7+
8+
import java.util.List;
9+
import java.util.regex.Pattern;
10+
import java.util.regex.PatternSyntaxException;
11+
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
12+
import org.apache.calcite.adapter.enumerable.NullPolicy;
13+
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
14+
import org.apache.calcite.linq4j.tree.Expression;
15+
import org.apache.calcite.linq4j.tree.Expressions;
16+
import org.apache.calcite.linq4j.tree.Types;
17+
import org.apache.calcite.rex.RexCall;
18+
import org.apache.calcite.rex.RexLiteral;
19+
import org.apache.calcite.sql.type.OperandTypes;
20+
import org.apache.calcite.sql.type.ReturnTypes;
21+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
22+
import org.apache.calcite.sql.type.SqlTypeFamily;
23+
import org.opensearch.sql.expression.function.ImplementorUDF;
24+
import org.opensearch.sql.expression.function.UDFOperandMetadata;
25+
26+
/**
27+
* MVFIND function implementation that finds the index of the first element in a multivalue array
28+
* that matches a regular expression.
29+
*
30+
* <p>Usage: mvfind(array, regex)
31+
*
32+
* <p>Returns the 0-based index of the first array element matching the regex pattern, or NULL if no
33+
* match is found.
34+
*
35+
* <p>Example: mvfind(array('apple', 'banana', 'apricot'), 'ban.*') returns 1
36+
*/
37+
public class MVFindFunctionImpl extends ImplementorUDF {
38+
public MVFindFunctionImpl() {
39+
super(new MVFindImplementor(), NullPolicy.ANY);
40+
}
41+
42+
@Override
43+
public SqlReturnTypeInference getReturnTypeInference() {
44+
return ReturnTypes.INTEGER_NULLABLE;
45+
}
46+
47+
@Override
48+
public UDFOperandMetadata getOperandMetadata() {
49+
// Accept ARRAY and STRING for the regex pattern
50+
return UDFOperandMetadata.wrap(
51+
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER));
52+
}
53+
54+
public static class MVFindImplementor implements NotNullImplementor {
55+
@Override
56+
public Expression implement(
57+
RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
58+
Expression arrayExpr = translatedOperands.get(0);
59+
Expression patternExpr = translatedOperands.get(1);
60+
61+
// Check if regex pattern is a literal - compile at planning time
62+
if (call.operands.size() >= 2 && call.operands.get(1) instanceof RexLiteral) {
63+
RexLiteral patternLiteral = (RexLiteral) call.operands.get(1);
64+
Expression literalPatternExpr = tryCompileLiteralPattern(patternLiteral, arrayExpr);
65+
if (literalPatternExpr != null) {
66+
return literalPatternExpr;
67+
}
68+
}
69+
70+
// For dynamic patterns, use evalWithString
71+
return Expressions.call(
72+
Types.lookupMethod(MVFindFunctionImpl.class, "evalWithString", List.class, String.class),
73+
arrayExpr,
74+
patternExpr);
75+
}
76+
77+
private static Expression tryCompileLiteralPattern(
78+
RexLiteral patternLiteral, Expression arrayExpr) {
79+
// Use getValueAs(String.class) to correctly unwrap Calcite NlsString
80+
String patternString = patternLiteral.getValueAs(String.class);
81+
if (patternString == null) {
82+
return null;
83+
}
84+
try {
85+
// Compile pattern at planning time and validate
86+
Pattern compiledPattern = Pattern.compile(patternString);
87+
// Generate code that uses the pre-compiled pattern
88+
return Expressions.call(
89+
Types.lookupMethod(
90+
MVFindFunctionImpl.class, "evalWithPattern", List.class, Pattern.class),
91+
arrayExpr,
92+
Expressions.constant(compiledPattern, Pattern.class));
93+
} catch (PatternSyntaxException e) {
94+
// Convert to IllegalArgumentException so it's treated as a client error (400)
95+
throw new IllegalArgumentException(
96+
String.format("Invalid regex pattern '%s': %s", patternString, e.getDescription()), e);
97+
}
98+
}
99+
}
100+
101+
private static Integer mvfindCore(List<Object> array, Pattern pattern) {
102+
for (int i = 0; i < array.size(); i++) {
103+
Object element = array.get(i);
104+
if (element != null) {
105+
String strValue = element.toString();
106+
if (pattern.matcher(strValue).find()) {
107+
return i; // Return 0-based index
108+
}
109+
}
110+
}
111+
return null; // No match found
112+
}
113+
114+
/**
115+
* Evaluates mvfind with a pre-compiled Pattern (for literal patterns compiled at planning time).
116+
* Any runtime exceptions from mvfindCore will propagate unchanged.
117+
*
118+
* @param array The array to search
119+
* @param pattern The pre-compiled regex pattern
120+
* @return The 0-based index of the first matching element, or null if no match
121+
*/
122+
public static Integer evalWithPattern(List<Object> array, Pattern pattern) {
123+
if (array == null || pattern == null) {
124+
return null;
125+
}
126+
return mvfindCore(array, pattern);
127+
}
128+
129+
/**
130+
* Evaluates mvfind with a string pattern (for dynamic patterns at runtime).
131+
*
132+
* @param array The array to search
133+
* @param regex The regex pattern string
134+
* @return The 0-based index of the first matching element, or null if no match
135+
*/
136+
public static Integer evalWithString(List<Object> array, String regex) {
137+
if (array == null || regex == null) {
138+
return null;
139+
}
140+
return mvfind(array, regex);
141+
}
142+
143+
/**
144+
* Evaluates mvfind with a String pattern. Compiles the regex pattern and executes search. Throws
145+
* IllegalArgumentException for invalid regex patterns; other runtime exceptions propagate
146+
* unchanged.
147+
*
148+
* @param array The array to search
149+
* @param regex The regex pattern string
150+
* @return The 0-based index of the first matching element, or null if no match
151+
* @throws IllegalArgumentException if the regex pattern is invalid
152+
*/
153+
private static Integer mvfind(List<Object> array, String regex) {
154+
if (array == null || regex == null) {
155+
return null;
156+
}
157+
158+
Pattern pattern;
159+
try {
160+
pattern = Pattern.compile(regex);
161+
} catch (PatternSyntaxException e) {
162+
// Invalid regex is a client error (400)
163+
throw new IllegalArgumentException(
164+
String.format("Invalid regex pattern '%s': %s", regex, e.getDescription()), e);
165+
}
166+
return mvfindCore(array, pattern);
167+
}
168+
}

0 commit comments

Comments
 (0)