Skip to content

Commit ac5994e

Browse files
committed
General UDAF pushdown as scripts
Signed-off-by: Songkan Tang <songkant@amazon.com>
1 parent 65baa2a commit ac5994e

25 files changed

Lines changed: 2033 additions & 93 deletions

common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java

Lines changed: 493 additions & 0 deletions
Large diffs are not rendered by default.

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3267,68 +3267,114 @@ private void flattenParsedPattern(
32673267
Boolean showNumberedToken) {
32683268
List<RexNode> fattenedNodes = new ArrayList<>();
32693269
List<String> projectNames = new ArrayList<>();
3270-
// Flatten map struct fields
3270+
3271+
// For aggregation mode with numbered tokens, we need to compute tokens locally
3272+
// using evalAggSamples. The UDAF returns pattern with wildcards and sample_logs,
3273+
// but NOT tokens (to avoid XContent serialization issues with nested Maps).
3274+
RexNode parsedPatternResult = null;
3275+
if (flattenPatternAggResult && showNumberedToken) {
3276+
// Extract pattern string (with wildcards) from UDAF result
3277+
RexNode patternStr =
3278+
PPLFuncImpTable.INSTANCE.resolve(
3279+
context.rexBuilder,
3280+
BuiltinFunctionName.INTERNAL_ITEM,
3281+
parsedNode,
3282+
context.rexBuilder.makeLiteral(PatternUtils.PATTERN));
3283+
// Extract sample_logs from UDAF result
3284+
RexNode sampleLogs =
3285+
PPLFuncImpTable.INSTANCE.resolve(
3286+
context.rexBuilder,
3287+
BuiltinFunctionName.INTERNAL_ITEM,
3288+
explicitMapType(context, parsedNode, SqlTypeName.VARCHAR),
3289+
context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS));
3290+
RexNode showNumberedTokenLiteral = context.rexBuilder.makeLiteral(true);
3291+
3292+
// Call evalAggSamples to transform pattern (wildcards -> numbered tokens) and compute tokens
3293+
parsedPatternResult =
3294+
PPLFuncImpTable.INSTANCE.resolve(
3295+
context.rexBuilder,
3296+
BuiltinFunctionName.INTERNAL_PATTERN_PARSER,
3297+
patternStr,
3298+
sampleLogs,
3299+
showNumberedTokenLiteral);
3300+
}
3301+
3302+
// Flatten map struct fields - pattern
3303+
RelDataType varcharType =
3304+
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR);
3305+
RexNode patternSource = parsedPatternResult != null ? parsedPatternResult : parsedNode;
32713306
RexNode patternExpr =
3272-
context.rexBuilder.makeCast(
3273-
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
3274-
PPLFuncImpTable.INSTANCE.resolve(
3275-
context.rexBuilder,
3276-
BuiltinFunctionName.INTERNAL_ITEM,
3277-
parsedNode,
3278-
context.rexBuilder.makeLiteral(PatternUtils.PATTERN)),
3279-
true,
3280-
true);
3307+
extractAndCastMapField(context, patternSource, PatternUtils.PATTERN, varcharType);
32813308
fattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias));
32823309
projectNames.add(originalPatternResultAlias);
3310+
32833311
if (flattenPatternAggResult) {
3312+
RelDataType bigintType =
3313+
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT);
32843314
RexNode patternCountExpr =
3285-
context.rexBuilder.makeCast(
3286-
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT),
3287-
PPLFuncImpTable.INSTANCE.resolve(
3288-
context.rexBuilder,
3289-
BuiltinFunctionName.INTERNAL_ITEM,
3290-
parsedNode,
3291-
context.rexBuilder.makeLiteral(PatternUtils.PATTERN_COUNT)),
3292-
true,
3293-
true);
3315+
extractAndCastMapField(context, parsedNode, PatternUtils.PATTERN_COUNT, bigintType);
32943316
fattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT));
32953317
projectNames.add(PatternUtils.PATTERN_COUNT);
32963318
}
3319+
32973320
if (showNumberedToken) {
3321+
// Create MAP<VARCHAR, ARRAY<VARCHAR>> type for tokens
3322+
RelDataType tokensType =
3323+
context
3324+
.rexBuilder
3325+
.getTypeFactory()
3326+
.createMapType(
3327+
varcharType,
3328+
context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1));
3329+
RexNode tokensSource = parsedPatternResult != null ? parsedPatternResult : parsedNode;
32983330
RexNode tokensExpr =
3299-
context.rexBuilder.makeCast(
3300-
UserDefinedFunctionUtils.tokensMap,
3301-
PPLFuncImpTable.INSTANCE.resolve(
3302-
context.rexBuilder,
3303-
BuiltinFunctionName.INTERNAL_ITEM,
3304-
parsedNode,
3305-
context.rexBuilder.makeLiteral(PatternUtils.TOKENS)),
3306-
true,
3307-
true);
3331+
extractAndCastMapField(context, tokensSource, PatternUtils.TOKENS, tokensType);
33083332
fattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS));
33093333
projectNames.add(PatternUtils.TOKENS);
33103334
}
3335+
33113336
if (flattenPatternAggResult) {
3337+
RelDataType sampleLogsArrayType =
3338+
context
3339+
.rexBuilder
3340+
.getTypeFactory()
3341+
.createArrayType(
3342+
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1);
33123343
RexNode sampleLogsExpr =
3313-
context.rexBuilder.makeCast(
3314-
context
3315-
.rexBuilder
3316-
.getTypeFactory()
3317-
.createArrayType(
3318-
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1),
3319-
PPLFuncImpTable.INSTANCE.resolve(
3320-
context.rexBuilder,
3321-
BuiltinFunctionName.INTERNAL_ITEM,
3322-
explicitMapType(context, parsedNode, SqlTypeName.VARCHAR),
3323-
context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)),
3324-
true,
3325-
true);
3344+
extractAndCastMapField(
3345+
context,
3346+
explicitMapType(context, parsedNode, SqlTypeName.VARCHAR),
3347+
PatternUtils.SAMPLE_LOGS,
3348+
sampleLogsArrayType);
33263349
fattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS));
33273350
projectNames.add(PatternUtils.SAMPLE_LOGS);
33283351
}
33293352
projectPlusOverriding(fattenedNodes, projectNames, context);
33303353
}
33313354

3355+
/**
3356+
* Helper method to extract a field from a map and cast it to the specified type. Creates a
3357+
* SAFE_CAST (makeCast with matchNullability=true and safe=true) around an INTERNAL_ITEM call.
3358+
*
3359+
* @param context The Calcite plan context
3360+
* @param source The source RexNode containing the map
3361+
* @param fieldName The name of the field to extract from the map
3362+
* @param targetType The target type to cast to
3363+
* @return A RexNode representing SAFE_CAST(INTERNAL_ITEM(source, fieldName))
3364+
*/
3365+
private RexNode extractAndCastMapField(
3366+
CalcitePlanContext context, RexNode source, String fieldName, RelDataType targetType) {
3367+
return context.rexBuilder.makeCast(
3368+
targetType,
3369+
PPLFuncImpTable.INSTANCE.resolve(
3370+
context.rexBuilder,
3371+
BuiltinFunctionName.INTERNAL_ITEM,
3372+
source,
3373+
context.rexBuilder.makeLiteral(fieldName)),
3374+
true,
3375+
true);
3376+
}
3377+
33323378
private void buildExpandRelNode(
33333379
RexInputRef arrayFieldRex, String arrayFieldName, String alias, CalcitePlanContext context) {
33343380
// 3. Capture the outer row in a CorrelationId

core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.opensearch.sql.common.antlr.SyntaxCheckException;
2121
import org.opensearch.sql.common.patterns.BrainLogParser;
2222
import org.opensearch.sql.common.patterns.PatternUtils;
23-
import org.opensearch.sql.common.patterns.PatternUtils.ParseResult;
2423

2524
public class LogPatternAggFunction implements UserDefinedAggFunction<LogParserAccumulator> {
2625
private int bufferLimit = 100000;
@@ -190,7 +189,6 @@ public Object value(Object... argList) {
190189
partialMerge(argList);
191190
clearBuffer();
192191

193-
Boolean showToken = (Boolean) argList[3];
194192
return patternGroupMap.values().stream()
195193
.sorted(
196194
Comparator.comparing(
@@ -201,24 +199,18 @@ public Object value(Object... argList) {
201199
String pattern = (String) m.get(PatternUtils.PATTERN);
202200
Long count = (Long) m.get(PatternUtils.PATTERN_COUNT);
203201
List<String> sampleLogs = (List<String>) m.get(PatternUtils.SAMPLE_LOGS);
204-
Map<String, List<String>> tokensMap = new HashMap<>();
205-
ParseResult parseResult = null;
206-
if (showToken) {
207-
parseResult = PatternUtils.parsePattern(pattern, PatternUtils.WILDCARD_PATTERN);
208-
for (String sampleLog : sampleLogs) {
209-
PatternUtils.extractVariables(
210-
parseResult, sampleLog, tokensMap, PatternUtils.WILDCARD_PREFIX);
211-
}
212-
}
202+
// For aggregation mode, always return pattern with wildcards (<*>, <*IP*>).
203+
// The transformation to numbered tokens (<token1>, <token2>) and token
204+
// extraction is done downstream by evalAggSamples in flattenParsedPattern.
205+
// This ensures consistent behavior between UDAF pushdown and regular
206+
// aggregation paths.
213207
return ImmutableMap.of(
214208
PatternUtils.PATTERN,
215-
showToken
216-
? parseResult.toTokenOrderString(PatternUtils.WILDCARD_PREFIX)
217-
: pattern,
209+
pattern, // Always return original wildcard format
218210
PatternUtils.PATTERN_COUNT,
219211
count,
220212
PatternUtils.TOKENS,
221-
showToken ? tokensMap : Collections.EMPTY_MAP,
213+
Collections.EMPTY_MAP, // Tokens computed downstream by evalAggSamples
222214
PatternUtils.SAMPLE_LOGS,
223215
sampleLogs);
224216
})

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.*;
1212

1313
import com.google.common.collect.ImmutableSet;
14+
import java.lang.reflect.Type;
1415
import java.time.Instant;
1516
import java.time.ZoneId;
1617
import java.time.ZoneOffset;
@@ -209,7 +210,7 @@ public static List<Expression> convertToExprValues(
209210
* @return an adapted ImplementorUDF with the expr method, which is a UserDefinedFunctionBuilder
210211
*/
211212
public static ImplementorUDF adaptExprMethodToUDF(
212-
java.lang.reflect.Type type,
213+
Type type,
213214
String methodName,
214215
SqlReturnTypeInference returnTypeInference,
215216
NullPolicy nullPolicy,
@@ -240,7 +241,7 @@ public UDFOperandMetadata getOperandMetadata() {
240241
* FunctionProperties} at the beginning to a Calcite-compatible UserDefinedFunctionBuilder.
241242
*/
242243
public static ImplementorUDF adaptExprMethodWithPropertiesToUDF(
243-
java.lang.reflect.Type type,
244+
Type type,
244245
String methodName,
245246
SqlReturnTypeInference returnTypeInference,
246247
NullPolicy nullPolicy,
@@ -317,4 +318,44 @@ public static List<Expression> prependFunctionProperties(
317318
operandsWithProperties.addFirst(properties);
318319
return Collections.unmodifiableList(operandsWithProperties);
319320
}
321+
322+
/**
323+
* Adapt a static method from any class to a UserDefinedFunctionBuilder. This is a general-purpose
324+
* adapter that can wrap static helper methods (e.g., PatternAggregationHelpers methods) as UDFs
325+
* for use in scripted metrics.
326+
*
327+
* @param type the class containing the static method
328+
* @param methodName the name of the static method to be invoked
329+
* @param returnTypeInference the return type inference of the UDF
330+
* @param nullPolicy the null policy of the UDF
331+
* @param operandMetadata type checker for operands
332+
* @return an adapted ImplementorUDF wrapping the static method
333+
*/
334+
public static ImplementorUDF adaptStaticMethodToUDF(
335+
Type type,
336+
String methodName,
337+
SqlReturnTypeInference returnTypeInference,
338+
NullPolicy nullPolicy,
339+
@Nullable UDFOperandMetadata operandMetadata) {
340+
341+
NotNullImplementor implementor =
342+
(translator, call, translatedOperands) -> {
343+
// For static methods that work with generic objects (Map, List, etc.),
344+
// we don't need type conversion like adaptMathFunctionToUDF
345+
// Just pass the operands directly to the static method
346+
return Expressions.call(type, methodName, translatedOperands);
347+
};
348+
349+
return new ImplementorUDF(implementor, nullPolicy) {
350+
@Override
351+
public SqlReturnTypeInference getReturnTypeInference() {
352+
return returnTypeInference;
353+
}
354+
355+
@Override
356+
public UDFOperandMetadata getOperandMetadata() {
357+
return operandMetadata;
358+
}
359+
};
360+
}
320361
}

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,11 @@ public enum BuiltinFunctionName {
346346
INTERNAL_PATTERN_PARSER(FunctionName.of("pattern_parser")),
347347
INTERNAL_PATTERN(FunctionName.of("pattern")),
348348
INTERNAL_UNCOLLECT_PATTERNS(FunctionName.of("uncollect_patterns")),
349+
// Pattern aggregation UDFs for scripted metric pushdown
350+
PATTERN_INIT_UDF(FunctionName.of("pattern_init_udf"), true),
351+
PATTERN_ADD_UDF(FunctionName.of("pattern_add_udf"), true),
352+
PATTERN_COMBINE_UDF(FunctionName.of("pattern_combine_udf"), true),
353+
PATTERN_RESULT_UDF(FunctionName.of("pattern_result_udf"), true),
349354
INTERNAL_GROK(FunctionName.of("grok"), true),
350355
INTERNAL_PARSE(FunctionName.of("parse"), true),
351356
INTERNAL_REGEXP_REPLACE_PG_4(FunctionName.of("regexp_replace_pg_4"), true),

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@
2121
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
2222
import org.apache.calcite.avatica.util.TimeUnit;
2323
import org.apache.calcite.linq4j.tree.Expression;
24+
import org.apache.calcite.rel.type.RelDataType;
25+
import org.apache.calcite.rel.type.RelDataTypeFactory;
2426
import org.apache.calcite.rex.RexCall;
2527
import org.apache.calcite.sql.SqlAggFunction;
2628
import org.apache.calcite.sql.SqlKind;
2729
import org.apache.calcite.sql.SqlOperator;
2830
import org.apache.calcite.sql.type.ReturnTypes;
31+
import org.apache.calcite.sql.type.SqlTypeName;
2932
import org.apache.calcite.sql.type.SqlTypeTransforms;
33+
import org.apache.calcite.sql.type.SqlTypeUtil;
3034
import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable;
3135
import org.apache.calcite.util.BuiltInMethod;
3236
import org.opensearch.sql.calcite.udf.udaf.FirstAggFunction;
@@ -40,6 +44,7 @@
4044
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
4145
import org.opensearch.sql.calcite.utils.PPLReturnTypes;
4246
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
47+
import org.opensearch.sql.common.patterns.PatternAggregationHelpers;
4348
import org.opensearch.sql.data.type.ExprCoreType;
4449
import org.opensearch.sql.expression.datetime.DateTimeFunctions;
4550
import org.opensearch.sql.expression.function.CollectionUDF.AppendFunctionImpl;
@@ -482,6 +487,49 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
482487
PPLReturnTypes.STRING_ARRAY,
483488
PPLOperandTypes.ANY_SCALAR_OPTIONAL_INTEGER);
484489

490+
// Pattern aggregation helper UDFs for scripted metric pushdown
491+
// This UDF takes state as a parameter and modifies it in place (for OpenSearch scripted metric)
492+
public static final SqlOperator PATTERN_INIT_UDF =
493+
UserDefinedFunctionUtils.adaptStaticMethodToUDF(
494+
PatternAggregationHelpers.class,
495+
"initPatternState",
496+
ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map<String, Object>
497+
NullPolicy.ANY,
498+
null) // No operand type checking; takes state as parameter
499+
.toUDF("PATTERN_INIT_UDF");
500+
501+
public static final SqlOperator PATTERN_ADD_UDF =
502+
UserDefinedFunctionUtils.adaptStaticMethodToUDF(
503+
PatternAggregationHelpers.class,
504+
"addLogToPattern",
505+
ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map<String, Object>
506+
NullPolicy.ANY,
507+
null) // TODO: Add proper operand type checking
508+
.toUDF("PATTERN_ADD_UDF");
509+
510+
public static final SqlOperator PATTERN_COMBINE_UDF =
511+
UserDefinedFunctionUtils.adaptStaticMethodToUDF(
512+
PatternAggregationHelpers.class,
513+
"combinePatternAccumulators",
514+
ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map<String, Object>
515+
NullPolicy.ANY,
516+
null) // TODO: Add proper operand type checking
517+
.toUDF("PATTERN_COMBINE_UDF");
518+
519+
public static final SqlOperator PATTERN_RESULT_UDF =
520+
UserDefinedFunctionUtils.adaptStaticMethodToUDF(
521+
PatternAggregationHelpers.class,
522+
"producePatternResultFromStates",
523+
opBinding -> {
524+
// Returns List<Map<String, Object>> - represented as ARRAY<ANY>
525+
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
526+
RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY);
527+
return SqlTypeUtil.createArrayType(typeFactory, anyType, true);
528+
},
529+
NullPolicy.ANY,
530+
null) // TODO: Add proper operand type checking
531+
.toUDF("PATTERN_RESULT_UDF");
532+
485533
public static final SqlOperator ENHANCED_COALESCE =
486534
new EnhancedCoalesceFunction().toUDF("COALESCE");
487535

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@
163163
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOW;
164164
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NULLIF;
165165
import static org.opensearch.sql.expression.function.BuiltinFunctionName.OR;
166+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_ADD_UDF;
167+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_COMBINE_UDF;
168+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_INIT_UDF;
169+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_RESULT_UDF;
166170
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERCENTILE_APPROX;
167171
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_ADD;
168172
import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_DIFF;
@@ -981,6 +985,11 @@ void populate() {
981985
registerOperator(WEEKOFYEAR, PPLBuiltinOperators.WEEK);
982986

983987
registerOperator(INTERNAL_PATTERN_PARSER, PPLBuiltinOperators.PATTERN_PARSER);
988+
// Register pattern aggregation helper UDFs for scripted metric pushdown
989+
registerOperator(PATTERN_INIT_UDF, PPLBuiltinOperators.PATTERN_INIT_UDF);
990+
registerOperator(PATTERN_ADD_UDF, PPLBuiltinOperators.PATTERN_ADD_UDF);
991+
registerOperator(PATTERN_COMBINE_UDF, PPLBuiltinOperators.PATTERN_COMBINE_UDF);
992+
registerOperator(PATTERN_RESULT_UDF, PPLBuiltinOperators.PATTERN_RESULT_UDF);
984993
registerOperator(TONUMBER, PPLBuiltinOperators.TONUMBER);
985994
registerOperator(TOSTRING, PPLBuiltinOperators.TOSTRING);
986995
register(

0 commit comments

Comments
 (0)