Skip to content

Commit 609a605

Browse files
Merge branch 'apache:main' into main
2 parents 7c2f082 + 394eb5d commit 609a605

5 files changed

Lines changed: 101 additions & 32 deletions

File tree

.github/workflows/pr_benchmark_check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
cargo check --benches
7272
7373
- name: Cache Maven dependencies
74-
uses: actions/cache@v4
74+
uses: actions/cache@v5
7575
with:
7676
path: |
7777
~/.m2/repository

native/Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -685,11 +685,18 @@ macro_rules! cast_decimal_to_int16_down {
685685
.map(|value| match value {
686686
Some(value) => {
687687
let divisor = 10_i128.pow($scale as u32);
688-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
688+
let truncated = value / divisor;
689689
let is_overflow = truncated.abs() > i32::MAX.into();
690690
if is_overflow {
691691
return Err(cast_overflow(
692-
&format!("{}.{}BD", truncated, decimal),
692+
&format!(
693+
"{}BD",
694+
format_decimal_str(
695+
&value.to_string(),
696+
$precision as usize,
697+
$scale
698+
)
699+
),
693700
&format!("DECIMAL({},{})", $precision, $scale),
694701
$dest_type_str,
695702
));
@@ -698,7 +705,14 @@ macro_rules! cast_decimal_to_int16_down {
698705
<$rust_dest_type>::try_from(i32_value)
699706
.map_err(|_| {
700707
cast_overflow(
701-
&format!("{}.{}BD", truncated, decimal),
708+
&format!(
709+
"{}BD",
710+
format_decimal_str(
711+
&value.to_string(),
712+
$precision as usize,
713+
$scale
714+
)
715+
),
702716
&format!("DECIMAL({},{})", $precision, $scale),
703717
$dest_type_str,
704718
)
@@ -748,11 +762,18 @@ macro_rules! cast_decimal_to_int32_up {
748762
.map(|value| match value {
749763
Some(value) => {
750764
let divisor = 10_i128.pow($scale as u32);
751-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
765+
let truncated = value / divisor;
752766
let is_overflow = truncated.abs() > $max_dest_val.into();
753767
if is_overflow {
754768
return Err(cast_overflow(
755-
&format!("{}.{}BD", truncated, decimal),
769+
&format!(
770+
"{}BD",
771+
format_decimal_str(
772+
&value.to_string(),
773+
$precision as usize,
774+
$scale
775+
)
776+
),
756777
&format!("DECIMAL({},{})", $precision, $scale),
757778
$dest_type_str,
758779
));
@@ -780,6 +801,30 @@ macro_rules! cast_decimal_to_int32_up {
780801
}};
781802
}
782803

804+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
/// Formats the unscaled decimal digits in `value_str` as a plain decimal string.
///
/// * `value_str` - base-10 rendering of the unscaled integer, optionally prefixed with `-`.
/// * `precision` - maximum number of digits kept from `value_str` (the sign is not counted).
/// * `scale` - number of fractional digits; a negative scale appends that many trailing zeros.
///
/// Returns the formatted number, e.g. `("-9683355007", 10, 2)` yields `"-96833550.07"`.
///
/// This function is used while building overflow error messages, so it must never panic:
/// inputs whose digit count exceeds `precision` (or whose `precision` is smaller than
/// `scale`) are rendered with a leading `0.` instead of underflowing the split position.
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
    let (sign, rest) = match value_str.strip_prefix('-') {
        Some(stripped) => ("-", stripped),
        None => ("", value_str),
    };
    // Keep at most `precision` digits, plus the sign character when present.
    let bound = precision.min(rest.len()) + sign.len();
    let value_str = &value_str[0..bound];

    if scale == 0 {
        value_str.to_string()
    } else if scale < 0 {
        // Negative scale: right-pad the digits with zeros.
        let padding = value_str.len() + scale.unsigned_abs() as usize;
        format!("{value_str:0<padding$}")
    } else {
        let scale = scale as usize;
        // Digits that survived the precision truncation, without the sign.
        let digits = &value_str[sign.len()..];
        if rest.len() <= scale {
            // Purely fractional value: left-pad the digits with zeros up to `scale`.
            format!("{sign}0.{rest:0>scale$}")
        } else if digits.len() > scale {
            // Decimal separator is in the middle of the string.
            let (whole, decimal) = value_str.split_at(value_str.len() - scale);
            format!("{whole}.{decimal}")
        } else {
            // Truncation left no integer digits; computing the split position here would
            // underflow (or yield an empty integer part), so emit an explicit `0.` instead.
            format!("{sign}0.{digits:0>scale$}")
        }
    }
}
827+
783828
impl Cast {
784829
pub fn new(
785830
child: Arc<dyn PhysicalExpr>,
@@ -1866,12 +1911,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
18661911
),
18671912
(DataType::Decimal128(precision, scale), DataType::Int8) => {
18681913
cast_decimal_to_int16_down!(
1869-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1914+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
18701915
)
18711916
}
18721917
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18731918
cast_decimal_to_int16_down!(
1874-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1919+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18751920
)
18761921
}
18771922
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
529529

530530
test("cast DecimalType(10,2) to ShortType") {
531531
castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
532+
castTest(
533+
generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
534+
DataTypes.ShortType)
532535
}
533536

534537
test("cast DecimalType(10,2) to IntegerType") {
@@ -553,14 +556,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
553556

554557
test("cast DecimalType(38,18) to ShortType") {
555558
castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
559+
castTest(
560+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
561+
DataTypes.ShortType)
556562
}
557563

558564
test("cast DecimalType(38,18) to IntegerType") {
559565
castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
566+
castTest(
567+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
568+
DataTypes.IntegerType)
560569
}
561570

562571
test("cast DecimalType(38,18) to LongType") {
563572
castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
573+
castTest(
574+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
575+
DataTypes.LongType)
564576
}
565577

566578
test("cast DecimalType(10,2) to StringType") {
@@ -1205,6 +1217,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12051217
BigDecimal("32768.678"),
12061218
BigDecimal("123456.789"),
12071219
BigDecimal("99999999.999"))
1220+
generateDecimalsPrecision10Scale2(values)
1221+
}
1222+
1223+
// Builds a DataFrame from the caller-supplied decimals: the values are passed through
// `withNulls`, loaded as column "b", cast to DecimalType(10, 2) as column "a", and "b" is
// then dropped so only "a" remains.
// NOTE(review): `withNulls` presumably mixes null entries into the input - confirm against its definition.
private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
12081224
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
12091225
}
12101226

@@ -1227,6 +1243,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12271243
// Long Max
12281244
BigDecimal("9223372036854775808.234567"),
12291245
BigDecimal("99999999999999999999.999999999999"))
1246+
generateDecimalsPrecision38Scale18(values)
1247+
}
1248+
1249+
// Builds a single-column DataFrame (column "a") from the caller-supplied decimals after
// passing them through `withNulls`. No explicit cast is applied here.
// NOTE(review): the method name suggests the column ends up as DecimalType(38, 18) via the
// default BigDecimal encoding - confirm; `withNulls` presumably mixes in null entries.
private def generateDecimalsPrecision38Scale18(values: Seq[BigDecimal]): DataFrame = {
12301250
withNulls(values).toDF("a")
12311251
}
12321252

spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,6 @@ case class StringExprConfig(
4444
// spotless:on
4545
object CometStringExpressionBenchmark extends CometBenchmarkBase {
4646

47-
/**
48-
* Generic method to run a string expression benchmark with the given configuration.
49-
*/
50-
def runStringExprBenchmark(config: StringExprConfig, values: Int): Unit = {
51-
withTempPath { dir =>
52-
withTempTable("parquetV1Table") {
53-
prepareTable(dir, spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 100) AS c1 FROM $tbl"))
54-
55-
val extraConfigs =
56-
Map(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") ++ config.extraCometConfigs
57-
58-
runExpressionBenchmark(config.name, values, config.query, extraConfigs)
59-
}
60-
}
61-
}
62-
6347
// Configuration for all string expression benchmarks
6448
private val stringExpressions = List(
6549
StringExprConfig("Substring", "select substring(c1, 1, 100) from parquetV1Table"),
@@ -71,7 +55,16 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase {
7155
StringExprConfig("chr", "select chr(c1) from parquetV1Table"),
7256
StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"),
7357
StringExprConfig("trim", "select trim(c1) from parquetV1Table"),
58+
StringExprConfig("btrim", "select btrim(c1) from parquetV1Table"),
59+
StringExprConfig("ltrim", "select ltrim(c1) from parquetV1Table"),
60+
StringExprConfig("rtrim", "select rtrim(c1) from parquetV1Table"),
61+
StringExprConfig("lpad", "select lpad(c1, 120, 'x') from parquetV1Table"),
62+
StringExprConfig("rpad", "select rpad(c1, 120, 'x') from parquetV1Table"),
63+
StringExprConfig("concat", "select concat(c1, c1) from parquetV1Table"),
7464
StringExprConfig("concatws", "select concat_ws(' ', c1, c1) from parquetV1Table"),
65+
StringExprConfig("contains", "select contains(c1, '123') from parquetV1Table"),
66+
StringExprConfig("startsWith", "select startswith(c1, '123') from parquetV1Table"),
67+
StringExprConfig("endsWith", "select endswith(c1, '123') from parquetV1Table"),
7568
StringExprConfig("length", "select length(c1) from parquetV1Table"),
7669
StringExprConfig("repeat", "select repeat(c1, 3) from parquetV1Table"),
7770
StringExprConfig("reverse", "select reverse(c1) from parquetV1Table"),
@@ -81,11 +74,22 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase {
8174
StringExprConfig("translate", "select translate(c1, '123456', 'aBcDeF') from parquetV1Table"))
8275

8376
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
84-
val values = 1024 * 1024;
77+
runBenchmarkWithTable("String expressions", 1024) { v =>
78+
withTempPath { dir =>
79+
withTempTable("parquetV1Table") {
80+
prepareTable(
81+
dir,
82+
spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl"))
83+
84+
val extraConfigs = Map(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true")
8585

86-
stringExpressions.foreach { config =>
87-
runBenchmarkWithTable(config.name, values) { v =>
88-
runStringExprBenchmark(config, v)
86+
stringExpressions.foreach { config =>
87+
val allConfigs = extraConfigs ++ config.extraCometConfigs
88+
runBenchmark(config.name) {
89+
runExpressionBenchmark(config.name, v, config.query, allConfigs)
90+
}
91+
}
92+
}
8993
}
9094
}
9195
}

0 commit comments

Comments
 (0)