Skip to content

Commit 2bbd937

Browse files
committed
[Enhancement](udf) support volatility for scalar UDFs
1 parent 3f3c79c commit 2bbd937

41 files changed

Lines changed: 532 additions & 65 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

fe/fe-catalog/src/main/java/org/apache/doris/catalog/Function.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ public enum BinaryType {
114114
protected String runtimeVersion;
115115
@SerializedName("fc")
116116
protected String functionCode;
117+
@SerializedName("vol")
118+
protected FunctionVolatility volatility = FunctionVolatility.IMMUTABLE;
117119

118120
// Only used for serialization
119121
protected Function() {
@@ -174,6 +176,7 @@ public Function(Function other) {
174176
this.expirationTime = other.expirationTime;
175177
this.runtimeVersion = other.runtimeVersion;
176178
this.functionCode = other.functionCode;
179+
this.volatility = other.getVolatility();
177180
}
178181

179182
public Function clone() {
@@ -301,6 +304,14 @@ public void setFunctionCode(String functionCode) {
301304
this.functionCode = functionCode;
302305
}
303306

307+
/**
 * Returns this function's volatility, never {@code null}: a null field is reported
 * as {@code FunctionVolatility.IMMUTABLE}.
 * NOTE(review): the field has a non-null initializer, so a null value presumably only
 * arises from metadata persisted without the "vol" key — confirm against the
 * deserialization path.
 */
public FunctionVolatility getVolatility() {
308+
return volatility == null ? FunctionVolatility.IMMUTABLE : volatility;
309+
}
310+
311+
/**
 * Sets this function's volatility. A {@code null} argument is stored as
 * {@code FunctionVolatility.IMMUTABLE}, so the field is never left null by this setter.
 */
public void setVolatility(FunctionVolatility volatility) {
312+
this.volatility = volatility == null ? FunctionVolatility.IMMUTABLE : volatility;
313+
}
314+
304315
// TODO(cmy): Currently we judge whether it is a UDF by whether the 'location' is set.
305316
// Maybe we should use a separate variable to identify,
306317
// but additional variables need to modify the persistence information.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.catalog;
19+
20+
import java.util.Locale;
21+
22+
/**
 * Function volatility controls which optimizer rewrites are safe for a function call.
 *
 * <p>Three levels are defined: {@link #IMMUTABLE}, {@link #STABLE} and {@link #VOLATILE}.
 * The textual form accepted by {@link #fromString(String)} and produced by
 * {@link #toSql()} is the lower-case constant name.
 */
public enum FunctionVolatility {
    IMMUTABLE,
    STABLE,
    VOLATILE;

    /**
     * Parses a volatility name, ignoring case and leading/trailing whitespace.
     *
     * @param value the textual volatility; {@code null} is treated as {@link #IMMUTABLE}
     * @return the constant whose name matches {@code value}
     * @throws IllegalArgumentException if {@code value} is non-null and does not name a
     *         constant (this includes the empty or blank string); the original
     *         {@code valueOf} failure is preserved as the cause
     */
    public static FunctionVolatility fromString(String value) {
        if (value == null) {
            return IMMUTABLE;
        }
        try {
            // Locale.ROOT keeps the upper-casing independent of the JVM default locale.
            return FunctionVolatility.valueOf(value.trim().toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Re-throw with the offending input and the accepted spellings in the message.
            throw new IllegalArgumentException("Invalid volatility: '" + value
                    + "'. Expected one of: immutable, stable, volatile", e);
        }
    }

    /**
     * Returns the lower-case constant name, e.g. {@code "volatile"} for {@link #VOLATILE}.
     *
     * @return the SQL spelling of this volatility level
     */
    public String toSql() {
        return name().toLowerCase(Locale.ROOT);
    }
}

fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionToSqlConverter.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ public static String toSql(ScalarFunction fn, boolean ifNotExists) {
7575
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\"");
7676
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
7777
sb.append(",\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\"");
78+
sb.append(",\n \"VOLATILITY\"=").append("\"" + fn.getVolatility().toSql() + "\"");
79+
} else if (fn.getBinaryType() == Function.BinaryType.PYTHON_UDF) {
80+
sb.append(",\n \"FILE\"=")
81+
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\"");
82+
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
83+
sb.append(",\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\"");
84+
sb.append(",\n \"VOLATILITY\"=").append("\"" + fn.getVolatility().toSql() + "\"");
7885
} else {
7986
sb.append(",\n \"OBJECT_FILE\"=")
8087
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\"");
@@ -125,6 +132,11 @@ public static String toSql(AggregateFunction fn, boolean ifNotExists) {
125132
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\",");
126133
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
127134
sb.append("\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\",");
135+
} else if (fn.getBinaryType() == Function.BinaryType.PYTHON_UDF) {
136+
sb.append("\n \"FILE\"=")
137+
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\",");
138+
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
139+
sb.append("\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\",");
128140
} else {
129141
sb.append("\n \"OBJECT_FILE\"=")
130142
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\",");

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushDownFilterThroughProject.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC
4545
PhysicalProject<? extends Plan> project = (PhysicalProject<? extends Plan>) child;
4646
Map<Slot, Expression> childAlias = project.getAliasToProducer();
4747
if (filter.getInputSlots().stream().map(childAlias::get).filter(Objects::nonNull)
48-
.anyMatch(Expression::containsUniqueFunction)) {
48+
.anyMatch(Expression::containsVolatileExpression)) {
4949
return filter;
5050
}
5151
PhysicalFilter<? extends Plan> newFilter = filter.withConjunctsAndChild(

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,7 +1419,7 @@ private List<Expression> bindGroupByUniqueId(List<Expression> groupByExpressions
14191419
// 2. for 'group by a + random(), a + random() + 1', the two 'random()' will be different.
14201420
int containsUniqueGroupByCount = 0;
14211421
for (Expression groupByExpr : groupByExpressions) {
1422-
if (groupByExpr.containsUniqueFunction()) {
1422+
if (groupByExpr.containsVolatileExpression()) {
14231423
containsUniqueGroupByCount++;
14241424
}
14251425
}
@@ -1433,7 +1433,7 @@ private List<Expression> bindGroupByUniqueId(List<Expression> groupByExpressions
14331433
groupByExpressions.size());
14341434
for (Expression groupByExpr : groupByExpressions) {
14351435
Expression newGroupByExpr = groupByExpr;
1436-
if (groupByExpr.containsUniqueFunction()) {
1436+
if (groupByExpr.containsVolatileExpression()) {
14371437
Expression ignoreUniqueIdExpr = ExpressionUtils.setIgnoreUniqueIdForUniqueFunc(groupByExpr, true);
14381438
Expression previousGroupByExpr = ignoreUniqueIdGroupByExprs.get(ignoreUniqueIdExpr);
14391439
if (previousGroupByExpr == null) {
@@ -1476,7 +1476,7 @@ private <T extends Expression> List<T> bindExprsUniqueIdWithGroupBy(List<T> expr
14761476
// c) let E3 = rewrite E2 with enable unique ids. then E3 is the bind unique id expression for E.
14771477
private <T extends Expression> T bindExprUniqueIdWithGroupBy(T expression,
14781478
Map<Expression, Expression> bindUniqueIdReplaceMap) {
1479-
if (!expression.containsUniqueFunction() || bindUniqueIdReplaceMap.isEmpty()) {
1479+
if (!expression.containsVolatileExpression() || bindUniqueIdReplaceMap.isEmpty()) {
14801480
return expression;
14811481
}
14821482

@@ -1522,7 +1522,7 @@ private Map<Expression, Expression> getBelowAggregateGroupByUniqueFuncReplaceMap
15221522
private Map<Expression, Expression> getGroupByUniqueFuncReplaceMap(List<Expression> groupByByExpressions) {
15231523
Map<Expression, Expression> replaceMap = Maps.newHashMap();
15241524
for (Expression expression : groupByByExpressions) {
1525-
if (expression.containsUniqueFunction()) {
1525+
if (expression.containsVolatileExpression()) {
15261526
Expression ignoreUniqueIdExpr = ExpressionUtils.setIgnoreUniqueIdForUniqueFunc(expression, true);
15271527
// for sql:
15281528
// select distinct a + random(), a + random()
@@ -1554,7 +1554,7 @@ private Plan bindRepeat(MatchingContext<LogicalRepeat<Plan>> ctx) {
15541554
= ImmutableList.builderWithExpectedSize(boundGroupingSet.size());
15551555
for (Expression groupBy : boundGroupingSet) {
15561556
Expression newGroupBy = groupBy;
1557-
if (groupBy.containsUniqueFunction()) {
1557+
if (groupBy.containsVolatileExpression()) {
15581558
Expression ignoreUniqueIdGroupBy = ExpressionUtils.setIgnoreUniqueIdForUniqueFunc(groupBy, true);
15591559
Expression previousGroupBy = ignoreUniqueIdGroupByExpressions.get(ignoreUniqueIdGroupBy);
15601560
if (previousGroupBy == null) {

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
5050
matchesType(InPredicate.class)
5151
.when(inPredicate ->
5252
inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
53-
&& !inPredicate.getCompareExpr().containsUniqueFunction())
53+
&& !inPredicate.getCompareExpr().containsVolatileExpression())
5454
.then(this::rewrite)
5555
.toRule(ExpressionRuleType.IN_PREDICATE_EXTRACT_NON_CONSTANT)
5656
);

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranch.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private Optional<Expression> tryPushIntoNvl(Expression parent, int childIndex, N
166166
// so there will exist twice 'first' in the rewritten IF expression, which may increase the computation cost.
167167
// if the plan is not filter and not join, then push down action may not have positive effect,
168168
// considering this, we give up the rewrite if the plan is not a condition plan or 'first' contains a volatile expression.
169-
if (first.containsUniqueFunction() || !isConditionPlan) {
169+
if (first.containsVolatileExpression() || !isConditionPlan) {
170170
return Optional.empty();
171171
}
172172
If ifExpr = new If(new IsNull(first), second, first);
@@ -182,7 +182,7 @@ private Optional<Expression> tryPushIntoNullIf(Expression parent, int childIndex
182182
// so there will exist twice 'first' in the rewritten IF expression, which may increase the computation cost.
183183
// if the plan is not filter and not join, then push down action may not have positive effect,
184184
// considering this, we give up the rewrite if the plan is not a condition plan or 'first' contains a volatile expression.
185-
if (first.containsUniqueFunction() || !isConditionPlan) {
185+
if (first.containsVolatileExpression() || !isConditionPlan) {
186186
return Optional.empty();
187187
}
188188
If ifExpr = new If(new EqualTo(first, second), new NullLiteral(nullIf.getDataType()), first);

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import org.apache.doris.nereids.trees.expressions.NamedExpression;
2727
import org.apache.doris.nereids.trees.expressions.Slot;
2828
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
29+
import org.apache.doris.nereids.trees.expressions.VolatileExpression;
2930
import org.apache.doris.nereids.trees.expressions.functions.Function;
30-
import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction;
3131
import org.apache.doris.nereids.trees.plans.JoinType;
3232
import org.apache.doris.nereids.trees.plans.Plan;
3333
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -274,22 +274,24 @@ public <T extends Expression> Optional<Pair<List<T>, LogicalProject<Plan>>> rewr
274274
*/
275275
@VisibleForTesting
276276
public List<NamedExpression> tryGenUniqueFunctionAlias(Collection<? extends Expression> targets) {
277-
Map<UniqueFunction, Integer> unqiueFunctionCounter = Maps.newLinkedHashMap();
277+
Map<Expression, Integer> unqiueFunctionCounter = Maps.newLinkedHashMap();
278278
for (Expression target : targets) {
279279
target.foreach(e -> {
280280
Expression expr = (Expression) e;
281-
if (expr instanceof UniqueFunction) {
282-
unqiueFunctionCounter.merge((UniqueFunction) expr, 1, Integer::sum);
281+
if (expr instanceof VolatileExpression && ((VolatileExpression) expr).isVolatile()) {
282+
unqiueFunctionCounter.merge(expr, 1, Integer::sum);
283283
}
284284
});
285285
}
286286

287287
ImmutableList.Builder<NamedExpression> builder
288288
= ImmutableList.builderWithExpectedSize(unqiueFunctionCounter.size());
289-
for (Entry<UniqueFunction, Integer> entry : unqiueFunctionCounter.entrySet()) {
289+
for (Entry<Expression, Integer> entry : unqiueFunctionCounter.entrySet()) {
290290
if (entry.getValue() > 1) {
291291
ExprId exprId = StatementScopeIdGenerator.newExprId();
292-
String name = "$_" + entry.getKey().getName() + "_" + exprId.asInt() + "_$";
292+
String functionName = entry.getKey() instanceof Function
293+
? ((Function) entry.getKey()).getName() : "volatile";
294+
String name = "$_" + functionName + "_" + exprId.asInt() + "_$";
293295
builder.add(new Alias(exprId, entry.getKey(), name));
294296
}
295297
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CollectFilterAboveConsumer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public Rule build() {
3838
LogicalCTEConsumer cteConsumer = filter.child();
3939
Set<Expression> exprs = filter.getConjuncts();
4040
for (Expression expr : exprs) {
41-
if (expr.containsUniqueFunction()) {
41+
if (expr.containsVolatileExpression()) {
4242
continue;
4343
}
4444
Expression rewrittenExpr = expr.rewriteUp(e -> {

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private boolean needRewrite(LogicalJoin<Plan, Plan> join) {
108108

109109
// 1. expr contains slots from both sides;
110110
private boolean isConditionNeedRewrite(Expression expr, Set<Slot> leftSlots, Set<Slot> rightSlots) {
111-
if (expr.containsUniqueFunction()) {
111+
if (expr.containsVolatileExpression()) {
112112
return false;
113113
}
114114
Set<Slot> exprSlots = expr.getInputSlots();

0 commit comments

Comments
 (0)