Skip to content

Commit 2987c8e

Browse files
committed
fix test
1 parent 2bbd937 commit 2987c8e

8 files changed

Lines changed: 302 additions & 9 deletions

File tree

fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionToSqlConverter.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.apache.doris.analysis.ToSqlParams;
2424
import org.apache.doris.catalog.Function.NullableMode;
2525

26+
import com.google.common.base.Strings;
27+
2628
import java.util.List;
2729
import java.util.stream.Collectors;
2830

@@ -81,13 +83,19 @@ public static String toSql(ScalarFunction fn, boolean ifNotExists) {
8183
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\"");
8284
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
8385
sb.append(",\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\"");
86+
sb.append(",\n \"RUNTIME_VERSION\"=").append("\"" + Strings.nullToEmpty(fn.getRuntimeVersion()) + "\"");
8487
sb.append(",\n \"VOLATILITY\"=").append("\"" + fn.getVolatility().toSql() + "\"");
8588
} else {
8689
sb.append(",\n \"OBJECT_FILE\"=")
8790
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\"");
8891
}
8992
sb.append(",\n \"TYPE\"=").append("\"" + fn.getBinaryType() + "\"");
90-
sb.append("\n);");
93+
if (fn.getBinaryType() == Function.BinaryType.PYTHON_UDF && !Strings.isNullOrEmpty(fn.getFunctionCode())) {
94+
// Preserve inline Python UDF bodies so SHOW CREATE FUNCTION output can be replayed directly.
95+
sb.append("\n)\nAS $$\n").append(fn.getFunctionCode()).append("\n$$;");
96+
} else {
97+
sb.append("\n);");
98+
}
9199
return sb.toString();
92100
}
93101

@@ -137,12 +145,19 @@ public static String toSql(AggregateFunction fn, boolean ifNotExists) {
137145
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\",");
138146
boolean isReturnNull = fn.getNullableMode() == NullableMode.ALWAYS_NULLABLE;
139147
sb.append("\n \"ALWAYS_NULLABLE\"=").append("\"" + isReturnNull + "\",");
148+
sb.append("\n \"RUNTIME_VERSION\"=")
149+
.append("\"" + Strings.nullToEmpty(fn.getRuntimeVersion()) + "\",");
140150
} else {
141151
sb.append("\n \"OBJECT_FILE\"=")
142152
.append("\"" + (fn.getLocation() == null ? "" : fn.getLocation().toString()) + "\",");
143153
}
144154
sb.append("\n \"TYPE\"=").append("\"" + fn.getBinaryType() + "\"");
145-
sb.append("\n);");
155+
if (fn.getBinaryType() == Function.BinaryType.PYTHON_UDF && !Strings.isNullOrEmpty(fn.getFunctionCode())) {
156+
// Preserve inline Python UDAF bodies so SHOW CREATE FUNCTION output can be replayed directly.
157+
sb.append("\n)\nAS $$\n").append(fn.getFunctionCode()).append("\n$$;");
158+
} else {
159+
sb.append("\n);");
160+
}
146161
return sb.toString();
147162
}
148163

fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ public void test() throws Exception {
120120
+ "'runtime_version'='3.10.2', 'volatility'='stable');";
121121
createFunction(pythonUdfSql, ctx);
122122
Assert.assertEquals(2, db.getFunctions().size());
123-
Function pythonFn = db.getFunctions().get(1);
123+
Function pythonFn = findFunction(db, "py_stable");
124124
Assert.assertEquals(FunctionVolatility.STABLE, pythonFn.getVolatility());
125125
Assert.assertTrue(FunctionToSqlConverter.toSql(pythonFn, false).contains("\"VOLATILITY\"=\"stable\""));
126126

127127
String defaultVolatileSql = "create function db1.py_default(int) returns int "
128128
+ "properties('type'='PYTHON_UDF', 'symbol'='evaluate', 'runtime_version'='3.10.2');";
129129
createFunction(defaultVolatileSql, ctx);
130-
Assert.assertEquals(FunctionVolatility.VOLATILE, db.getFunctions().get(2).getVolatility());
130+
Assert.assertEquals(FunctionVolatility.VOLATILE, findFunction(db, "py_default").getVolatility());
131131
}
132132

133133
@Test
@@ -218,4 +218,13 @@ private void createFunction(String sql, ConnectContext connectContext) throws Ex
218218
private boolean containsIgnoreCase(String str, String sub) {
219219
return str.toLowerCase().contains(sub.toLowerCase());
220220
}
221+
222+
private Function findFunction(Database db, String functionName) {
223+
for (Function function : db.getFunctions()) {
224+
if (functionName.equals(function.functionName())) {
225+
return function;
226+
}
227+
}
228+
throw new AssertionError("function not found: " + functionName);
229+
}
221230
}

fe/fe-core/src/test/java/org/apache/doris/catalog/FunctionToSqlConverterTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import org.apache.doris.catalog.Function.BinaryType;
2121
import org.apache.doris.catalog.Function.NullableMode;
22+
import org.apache.doris.common.AnalysisException;
23+
import org.apache.doris.common.util.URI;
2224

2325
import org.junit.jupiter.api.Assertions;
2426
import org.junit.jupiter.api.Test;
@@ -104,6 +106,43 @@ void testScalarFunction_javaUdf_withoutPrepareFnAndCloseFn() {
104106
Assertions.assertFalse(sql.contains("CLOSE_FN"));
105107
}
106108

109+
@Test
110+
void testScalarFunction_pythonUdf_inlineReplaySql() {
111+
FunctionName name = new FunctionName("testDb", "py_inline");
112+
Type[] argTypes = {Type.INT};
113+
ScalarFunction fn = ScalarFunction.createUdf(BinaryType.PYTHON_UDF, name, argTypes,
114+
Type.INT, false, null, "evaluate", null, null);
115+
fn.setRuntimeVersion("3.10.2");
116+
fn.setFunctionCode("def evaluate(x):\n return x + 1");
117+
fn.setVolatility(FunctionVolatility.IMMUTABLE);
118+
119+
String sql = FunctionToSqlConverter.toSql(fn, false);
120+
121+
Assertions.assertTrue(sql.contains("\"RUNTIME_VERSION\"=\"3.10.2\""));
122+
Assertions.assertTrue(sql.contains("\"VOLATILITY\"=\"immutable\""));
123+
Assertions.assertTrue(sql.contains("\"TYPE\"=\"PYTHON_UDF\""));
124+
Assertions.assertTrue(sql.contains("AS $$\ndef evaluate(x):\n return x + 1\n$$;"));
125+
Assertions.assertFalse(sql.endsWith(");"));
126+
}
127+
128+
@Test
129+
void testScalarFunction_pythonUdf_moduleReplaySql() throws AnalysisException {
130+
FunctionName name = new FunctionName("testDb", "py_module");
131+
Type[] argTypes = {Type.INT};
132+
ScalarFunction fn = ScalarFunction.createUdf(BinaryType.PYTHON_UDF, name, argTypes,
133+
Type.INT, false, URI.create("file:///tmp/pyudf.zip"), "pkg.mod.evaluate", null, null);
134+
fn.setRuntimeVersion("3.10.2");
135+
fn.setVolatility(FunctionVolatility.STABLE);
136+
137+
String sql = FunctionToSqlConverter.toSql(fn, false);
138+
139+
Assertions.assertTrue(sql.contains("\"FILE\"=\"file:///tmp/pyudf.zip\""));
140+
Assertions.assertTrue(sql.contains("\"RUNTIME_VERSION\"=\"3.10.2\""));
141+
Assertions.assertTrue(sql.contains("\"VOLATILITY\"=\"stable\""));
142+
Assertions.assertTrue(sql.endsWith(");"));
143+
Assertions.assertFalse(sql.contains("AS $$"));
144+
}
145+
107146
// ======================== ScalarFunction — IF NOT EXISTS ========================
108147

109148
@Test
@@ -212,6 +251,30 @@ void testAggregateFunction_javaUdf_ifNotExists() {
212251
Assertions.assertTrue(sql.contains("CREATE AGGREGATE FUNCTION IF NOT EXISTS "));
213252
}
214253

254+
@Test
255+
void testAggregateFunction_pythonUdf_inlineReplaySql() {
256+
FunctionName name = new FunctionName("testDb", "py_agg");
257+
Type[] argTypes = {Type.INT};
258+
AggregateFunction fn = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder()
259+
.binaryType(BinaryType.PYTHON_UDF)
260+
.name(name)
261+
.argsType(argTypes)
262+
.retType(Type.INT)
263+
.intermediateType(Type.INT)
264+
.hasVarArgs(false)
265+
.symbolName("SumState")
266+
.build();
267+
fn.setRuntimeVersion("3.10.2");
268+
fn.setFunctionCode("class SumState:\n pass");
269+
270+
String sql = FunctionToSqlConverter.toSql(fn, false);
271+
272+
Assertions.assertTrue(sql.contains("\"RUNTIME_VERSION\"=\"3.10.2\""));
273+
Assertions.assertTrue(sql.contains("\"TYPE\"=\"PYTHON_UDF\""));
274+
Assertions.assertTrue(sql.contains("AS $$\nclass SumState:\n pass\n$$;"));
275+
Assertions.assertFalse(sql.endsWith(");"));
276+
}
277+
215278
// ======================== AggregateFunction — NATIVE ========================
216279

217280
@Test

regression-test/suites/javaudf_p0/test_javaudf_float.groovy

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ suite("test_javaudf_float") {
5656
sql """ CREATE FUNCTION java_udf_float_test(FLOAT,FLOAT) RETURNS FLOAT PROPERTIES (
5757
"file"="file://${jarPath}",
5858
"symbol"="org.apache.doris.udf.FloatTest",
59-
"type"="JAVA_UDF"
59+
"type"="JAVA_UDF",
60+
"volatility"="immutable"
6061
); """
6162

6263
qt_select """ SELECT java_udf_float_test(cast(2.83645 as float),cast(111.1111111 as float)) as result; """

regression-test/suites/mtmv_p0/test_expand_star_mtmv.groovy

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ suite("test_expand_star_mtmv","mtmv") {
6262
sql """ CREATE FUNCTION ${functionName}(date, date) RETURNS boolean PROPERTIES (
6363
"file"="file://${jarPath}",
6464
"symbol"="org.apache.doris.udf.DateTest1",
65-
"type"="JAVA_UDF"
65+
"type"="JAVA_UDF",
66+
"volatility"="immutable"
6667
); """
6768

6869
sql """

regression-test/suites/pythonudf_p0/test_pythonudf_aggregate.groovy

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ suite("test_pythonudf_aggregate") {
2828
PROPERTIES (
2929
"type" = "PYTHON_UDF",
3030
"symbol" = "evaluate",
31-
"runtime_version" = "${runtime_version}"
31+
"runtime_version" = "${runtime_version}",
32+
"volatility" = "immutable"
3233
)
3334
AS \$\$
3435
def evaluate(score):
@@ -120,7 +121,8 @@ def evaluate(score):
120121
PROPERTIES (
121122
"type" = "PYTHON_UDF",
122123
"symbol" = "evaluate",
123-
"runtime_version" = "${runtime_version}"
124+
"runtime_version" = "${runtime_version}",
125+
"volatility" = "immutable"
124126
)
125127
AS \$\$
126128
def evaluate(age):

regression-test/suites/pythonudf_p0/test_pythonudf_float.groovy

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ suite("test_pythonudf_float") {
5353
"symbol"="float_test.evaluate",
5454
"type"="PYTHON_UDF",
5555
"runtime_version" = "${runtime_version}",
56-
"always_nullable" = "true"
56+
"always_nullable" = "true",
57+
"volatility" = "immutable"
5758
); """
5859

5960
qt_select """ SELECT python_udf_float_test(cast(2.83645 as float),cast(111.1111111 as float)) as result; """

0 commit comments

Comments
 (0)