Skip to content

Commit 6568862

Browse files
committed
[GLUTEN-12157][VL] Register sin, tan, tanh, radians, ln in Velox sparksql function registry
1 parent 6f508d2 commit 6568862

3 files changed

Lines changed: 60 additions & 3 deletions

File tree

backends-velox/src/test/scala/org/apache/gluten/functions/MathFunctionsValidateSuite.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,15 @@ class MathFunctionsValidateSuiteAnsiOn extends FunctionsValidateSuite {
6666
}
6767
}
6868

69-
abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite {
69+
class MathFunctionsValidateSuite extends FunctionsValidateSuite {
70+
71+
// Disable ANSI mode: Spark 4 enables it by default, which wraps math functions
72+
// in ANSI check nodes and prevents ProjectExecTransformer from being the top-level
73+
// plan node. ANSI-specific behaviour is tested in MathFunctionsValidateSuiteAnsiOn.
74+
override protected def sparkConf: SparkConf = {
75+
super.sparkConf
76+
.set(SQLConf.ANSI_ENABLED.key, "false")
77+
}
7078

7179
disableFallbackCheck
7280
import testImplicits._
@@ -248,6 +256,12 @@ abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite {
248256
}
249257
}
250258

259+
test("ln") {
260+
runQueryAndCompare("SELECT ln(l_orderkey) from lineitem limit 1") {
261+
checkGlutenPlan[ProjectExecTransformer]
262+
}
263+
}
264+
251265
test("log") {
252266
runQueryAndCompare("SELECT log(10, l_orderkey) from lineitem limit 1") {
253267
checkGlutenPlan[ProjectExecTransformer]
@@ -295,6 +309,12 @@ abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite {
295309
}
296310
}
297311

312+
test("radians") {
313+
runQueryAndCompare("SELECT radians(l_orderkey) from lineitem limit 1") {
314+
checkGlutenPlan[ProjectExecTransformer]
315+
}
316+
}
317+
298318
test("rint") {
299319
withTempPath {
300320
path =>
@@ -332,6 +352,24 @@ abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite {
332352
}
333353
}
334354

355+
test("sin") {
356+
runQueryAndCompare("SELECT sin(l_orderkey) from lineitem limit 1") {
357+
checkGlutenPlan[ProjectExecTransformer]
358+
}
359+
}
360+
361+
test("tan") {
362+
runQueryAndCompare("SELECT tan(l_orderkey) from lineitem limit 1") {
363+
checkGlutenPlan[ProjectExecTransformer]
364+
}
365+
}
366+
367+
test("tanh") {
368+
runQueryAndCompare("SELECT tanh(l_orderkey) from lineitem limit 1") {
369+
checkGlutenPlan[ProjectExecTransformer]
370+
}
371+
}
372+
335373
test("try_add") {
336374
runQueryAndCompare(
337375
"select try_add(cast(l_orderkey as int), 1), try_add(cast(l_orderkey as int), 2147483647)" +

backends-velox/src/test/scala/org/apache/gluten/functions/ScalarFunctionsValidateSuite.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,23 @@ package org.apache.gluten.functions
1919
import org.apache.gluten.config.GlutenConfig
2020
import org.apache.gluten.execution.{BatchScanExecTransformer, FilterExecTransformer, ProjectExecTransformer}
2121

22-
import org.apache.spark.SparkException
22+
import org.apache.spark.{SparkConf, SparkException}
2323
import org.apache.spark.sql.Row
2424
import org.apache.spark.sql.catalyst.optimizer.NullPropagation
2525
import org.apache.spark.sql.execution.ProjectExec
2626
import org.apache.spark.sql.internal.SQLConf
2727
import org.apache.spark.sql.types._
2828

29-
abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
29+
class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
30+
31+
// Disable ANSI mode: Spark 4 enables it by default, which wraps scalar functions
32+
// in ANSI check nodes and prevents ProjectExecTransformer from being the top-level
33+
// plan node.
34+
override protected def sparkConf: SparkConf = {
35+
super.sparkConf
36+
.set(SQLConf.ANSI_ENABLED.key, "false")
37+
}
38+
3039
disableFallbackCheck
3140

3241
import testImplicits._

cpp/velox/operators/functions/RegistrationAllFunctions.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "velox/functions/iceberg/Register.h"
2525
#include "velox/functions/lib/CheckedArithmetic.h"
2626
#include "velox/functions/lib/RegistrationHelpers.h"
27+
#include "velox/functions/prestosql/Arithmetic.h"
2728
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
2829
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
2930
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -76,6 +77,15 @@ void registerFunctionOverwrite() {
7677
kRowConstructorWithAllNull,
7778
std::make_unique<RowConstructorWithNullCallToSpecialForm>(kRowConstructorWithAllNull));
7879

80+
// Register math functions that are present in the prestosql implementation
81+
// but not yet in the sparksql registry. These are semantically identical
82+
// to Spark's behavior for the same names.
83+
velox::registerFunction<velox::functions::LnFunction, double, double>({"ln"});
84+
velox::registerFunction<velox::functions::RadiansFunction, double, double>({"radians"});
85+
velox::registerFunction<velox::functions::SinFunction, double, double>({"sin"});
86+
velox::registerFunction<velox::functions::TanFunction, double, double>({"tan"});
87+
velox::registerFunction<velox::functions::TanhFunction, double, double>({"tanh"});
88+
7989
velox::functions::registerPrestoVectorFunctions();
8090
}
8191

0 commit comments

Comments
 (0)