Skip to content

Commit 34d9ea8

Browse files
committed
int_to_binary_boolean_to_decimal
1 parent 61d7db3 commit 34d9ea8

3 files changed

Lines changed: 45 additions & 13 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::utils::array_with_timezone;
19+
use crate::EvalMode::Legacy;
1920
use crate::{timezone, BinaryOutputStyle};
2021
use crate::{EvalMode, SparkError, SparkResult};
2122
use arrow::array::builder::StringBuilder;
@@ -25,8 +26,8 @@ use arrow::array::{
2526
};
2627
use arrow::compute::can_cast_types;
2728
use arrow::datatypes::{
28-
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type,
29-
GenericBinaryType, Schema,
29+
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
30+
Schema,
3031
};
3132
use arrow::{
3233
array::{
@@ -66,7 +67,6 @@ use std::{
6667
num::Wrapping,
6768
sync::Arc,
6869
};
69-
use crate::EvalMode::Legacy;
7070

7171
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
7272

@@ -305,7 +305,10 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
305305

306306
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
307307
use DataType::*;
308-
matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
308+
matches!(
309+
to_type,
310+
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
311+
)
309312
}
310313

311314
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
@@ -1125,9 +1128,18 @@ fn cast_array(
11251128
}
11261129
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
11271130
(Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
1128-
(Int16, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int16Array, 2),
1129-
(Int32, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int32Array, 4),
1130-
(Int64, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int64Array, 8),
1131+
(Int16, Binary) if (eval_mode == Legacy) => {
1132+
cast_whole_num_to_binary!(&array, Int16Array, 2)
1133+
}
1134+
(Int32, Binary) if (eval_mode == Legacy) => {
1135+
cast_whole_num_to_binary!(&array, Int32Array, 4)
1136+
}
1137+
(Int64, Binary) if (eval_mode == Legacy) => {
1138+
cast_whole_num_to_binary!(&array, Int64Array, 8)
1139+
}
1140+
(Boolean, Decimal128(precision, scale)) => {
1141+
cast_boolean_to_decimal(&array, *precision, *scale)
1142+
}
11311143
_ if cast_options.is_adapting_schema
11321144
|| is_datafusion_spark_compatible(from_type, to_type) =>
11331145
{
@@ -1146,6 +1158,16 @@ fn cast_array(
11461158
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
11471159
}
11481160

1161+
fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
1162+
let bool_array = array.as_boolean();
1163+
let scale_factor = 10_i128.pow(scale as u32);
1164+
let result: Decimal128Array = bool_array
1165+
.iter()
1166+
.map(|v| v.map(|b| if b { scale_factor } else { 0 }))
1167+
.collect();
1168+
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
1169+
}
1170+
11491171
fn cast_string_to_float(
11501172
array: &ArrayRef,
11511173
to_type: &DataType,

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.expressions
2121

2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Literal}
2323
import org.apache.spark.sql.internal.SQLConf
24-
import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, DecimalType, NullType, StructType}
24+
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
2525

2626
import org.apache.comet.CometConf
2727
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -263,7 +263,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
263263

264264
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
265265
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
266-
DataTypes.FloatType | DataTypes.DoubleType =>
266+
DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
267267
Compatible()
268268
case _ => unsupported(DataTypes.BooleanType, toType)
269269
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
134134
castTest(generateBools(), DataTypes.DoubleType)
135135
}
136136

137-
ignore("cast BooleanType to DecimalType(10,2)") {
138-
// Arrow error: Cast error: Casting from Boolean to Decimal128(10, 2) not supported
137+
test("cast BooleanType to DecimalType(10,2)") {
139138
castTest(generateBools(), DataTypes.createDecimalType(10, 2))
140139
}
141140

141+
test("cast BooleanType to DecimalType(14,4)") {
142+
castTest(generateBools(), DataTypes.createDecimalType(14, 4))
143+
}
144+
145+
test("cast BooleanType to DecimalType(30,0)") {
146+
castTest(generateBools(), DataTypes.createDecimalType(30, 0))
147+
}
148+
142149
test("cast BooleanType to StringType") {
143150
castTest(generateBools(), DataTypes.StringType)
144151
}
@@ -1366,9 +1373,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13661373
}
13671374

13681375
if (testTry) {
1376+
data.createOrReplaceTempView("t")
13691377
// try_cast() should always return null for invalid inputs
1378+
// not using spark DSL since it `try_cast` is only available from Spark 4x
13701379
val df2 =
1371-
data.select(col("a"), col("a").try_cast(toType)).orderBy(col("a"))
1380+
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
13721381
if (hasIncompatibleType) {
13731382
checkSparkAnswer(df2)
13741383
} else {
@@ -1432,8 +1441,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14321441

14331442
// try_cast() should always return null for invalid inputs
14341443
if (testTry) {
1444+
data.createOrReplaceTempView("t")
14351445
val df2 =
1436-
data.select(col("a"), col("a").try_cast(toType)).orderBy(col("a"))
1446+
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
14371447
if (hasIncompatibleType) {
14381448
checkSparkAnswer(df2)
14391449
} else {

0 commit comments

Comments
 (0)