Skip to content

Commit e593089

Browse files
committed
rebase_main
1 parent 9bb2ea4 commit e593089

4 files changed

Lines changed: 265 additions & 10 deletions

File tree

native/spark-expr/src/conversion_funcs/boolean.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::SparkResult;
19-
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
19+
use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array, TimestampMicrosecondBuilder};
2020
use arrow::datatypes::DataType;
2121
use std::sync::Arc;
2222

@@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
2828
)
2929
}
3030

31-
// only DF incompatible boolean cast
3231
pub fn cast_boolean_to_decimal(
3332
array: &ArrayRef,
3433
precision: u8,
@@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal(
4342
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
4443
}
4544

45+
/// Converts a boolean array into a microsecond-precision timestamp array.
///
/// `true` maps to 1 (one microsecond past the epoch), `false` maps to 0 (the epoch itself) and
/// null entries stay null. The resulting array is tagged with `target_tz`.
///
/// # Panics
///
/// `as_boolean` panics when `array_ref` does not hold a boolean array.
pub(crate) fn cast_boolean_to_timestamp(
    array_ref: &ArrayRef,
    target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
    let flags = array_ref.as_boolean();
    let mut timestamps = TimestampMicrosecondBuilder::with_capacity(flags.len());

    for flag in flags.iter() {
        match flag {
            Some(true) => timestamps.append_value(1),
            Some(false) => timestamps.append_value(0),
            None => timestamps.append_null(),
        }
    }

    let tagged = timestamps.finish().with_timezone_opt(target_tz.clone());
    Ok(Arc::new(tagged) as ArrayRef)
}
63+
4664
#[cfg(test)]
4765
mod tests {
4866
use super::*;
@@ -53,6 +71,7 @@ mod tests {
5371
Int64Array, Int8Array, StringArray,
5472
};
5573
use arrow::datatypes::DataType::Decimal128;
74+
use arrow::datatypes::TimestampMicrosecondType;
5675
use std::sync::Arc;
5776

5877
fn test_input_bool_array() -> ArrayRef {
@@ -193,4 +212,26 @@ mod tests {
193212
assert_eq!(arr.value(1), expected_arr.value(1));
194213
assert!(arr.is_null(2));
195214
}
215+
216+
#[test]
fn test_cast_boolean_to_timestamp() {
    // Exercise an explicit UTC zone, a non-UTC zone, and no zone at all.
    let zones: [Option<Arc<str>>; 3] = [
        Some(Arc::from("UTC")),
        Some(Arc::from("America/Los_Angeles")),
        None,
    ];

    for zone in zones.iter() {
        let input: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));

        let casted = cast_boolean_to_timestamp(&input, zone).unwrap();
        let timestamps = casted.as_primitive::<TimestampMicrosecondType>();

        // true -> 1 microsecond, false -> 0 (epoch), null -> null
        assert_eq!(timestamps.value(0), 1);
        assert_eq!(timestamps.value(1), 0);
        assert!(timestamps.is_null(2));
        // The requested zone must be carried over onto the result.
        assert_eq!(timestamps.timezone(), zone.as_deref());
    }
}
196237
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
// under the License.
1717

1818
use crate::conversion_funcs::boolean::{
19-
cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
19+
cast_boolean_to_decimal, cast_boolean_to_timestamp, is_df_cast_from_bool_spark_compatible,
2020
};
2121
use crate::conversion_funcs::numeric::{
22-
cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128,
23-
cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
24-
is_df_cast_from_float_spark_compatible, is_df_cast_from_int_spark_compatible,
25-
spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
26-
spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
22+
cast_decimal_to_timestamp, cast_float32_to_decimal128, cast_float64_to_decimal128,
23+
cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
24+
is_df_cast_from_decimal_spark_compatible, is_df_cast_from_float_spark_compatible,
25+
is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
26+
spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, spark_cast_int_to_int,
27+
spark_cast_nonintegral_numeric_to_integral,
2728
};
2829
use crate::conversion_funcs::string::{
2930
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
@@ -778,7 +779,7 @@ fn cast_binary_formatter(value: &[u8]) -> String {
778779
#[cfg(test)]
779780
mod tests {
780781
use super::*;
781-
use arrow::array::{BooleanArray, StringArray};
782+
use arrow::array::StringArray;
782783
use arrow::datatypes::TimestampMicrosecondType;
783784
use arrow::datatypes::{Field, Fields};
784785
#[test]

native/spark-expr/src/conversion_funcs/numeric.rs

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::array::{
2424
OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder,
2525
};
2626
use arrow::datatypes::{
27-
is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type,
27+
i256, is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type,
2828
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
2929
};
3030
use num::{cast::AsPrimitive, ToPrimitive, Zero};
@@ -75,6 +75,56 @@ pub(crate) fn is_df_cast_from_decimal_spark_compatible(to_type: &DataType) -> bo
7575
)
7676
}
7777

78+
/// Appends one timestamp (microseconds since the epoch) to `$builder` per element of a float
/// array.
///
/// * `$array` - array handle that is downcast with `as_primitive::<$primitive_type>()`.
/// * `$builder` - builder receiving `append_value(i64)` / `append_null()` calls.
/// * `$primitive_type` - Arrow primitive type of `$array`; every value is widened with `as f64`.
/// * `$eval_mode` - evaluated once; `EvalMode::Ansi` turns unrepresentable values into errors,
///   every other mode turns them into nulls.
///
/// Null inputs stay null. NaN and +/-Infinity produce null, or `SparkError::CastInvalidValue`
/// under ANSI. Finite inputs are multiplied by `MICROS_PER_SECOND`; a product outside the `i64`
/// range produces null, or `SparkError::CastOverFlow` under ANSI.
///
/// The error arms `return` from the *enclosing* function, so this macro can only be expanded
/// inside a function whose error type is `SparkError`.
macro_rules! cast_float_to_timestamp_impl {
    ($array:expr, $builder:expr, $primitive_type:ty, $eval_mode:expr) => {{
        let arr = $array.as_primitive::<$primitive_type>();
        // Evaluate the caller's expression a single time instead of at every error site.
        let ansi = $eval_mode == EvalMode::Ansi;
        for i in 0..arr.len() {
            if arr.is_null(i) {
                $builder.append_null();
                continue;
            }
            let val = arr.value(i) as f64;

            // Path 1: NaN/Infinity input - the error names TIMESTAMP as the target type.
            if val.is_nan() || val.is_infinite() {
                if ansi {
                    // NOTE(review): `to_string()` renders infinities as `inf` / `-inf`, while
                    // path 2 below spells them `Infinity` / `-Infinity` - confirm which form
                    // the user-visible message is expected to carry.
                    return Err(SparkError::CastInvalidValue {
                        value: val.to_string(),
                        from_type: "DOUBLE".to_string(),
                        to_type: "TIMESTAMP".to_string(),
                    });
                }
                $builder.append_null();
                continue;
            }

            // Path 2: multiply, then check for overflow - the error names BIGINT as the target.
            let micros = val * MICROS_PER_SECOND as f64;
            if micros.floor() <= i64::MAX as f64 && micros.ceil() >= i64::MIN as f64 {
                $builder.append_value(micros as i64);
                continue;
            }
            if ansi {
                // `val` is finite here, so the product is never NaN; it is either a finite
                // value outside the i64 range or it overflowed to +/-Infinity.
                let value_str = if micros.is_infinite() {
                    if micros.is_sign_positive() {
                        "Infinity".to_string()
                    } else {
                        "-Infinity".to_string()
                    }
                } else {
                    format!("{:e}", micros).to_uppercase() + "D"
                };
                return Err(SparkError::CastOverFlow {
                    value: value_str,
                    from_type: "DOUBLE".to_string(),
                    to_type: "BIGINT".to_string(),
                });
            }
            $builder.append_null();
        }
    }};
}
127+
78128
macro_rules! cast_float_to_string {
79129
($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
80130

@@ -913,6 +963,56 @@ pub(crate) fn cast_int_to_timestamp(
913963
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
914964
}
915965

966+
pub(crate) fn cast_decimal_to_timestamp(
967+
array_ref: &ArrayRef,
968+
target_tz: &Option<Arc<str>>,
969+
scale: i8,
970+
) -> SparkResult<ArrayRef> {
971+
let arr = array_ref.as_primitive::<Decimal128Type>();
972+
let scale_factor = 10_i128.pow(scale as u32);
973+
let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len());
974+
975+
for i in 0..arr.len() {
976+
if arr.is_null(i) {
977+
builder.append_null();
978+
} else {
979+
let value = arr.value(i);
980+
// Note: spark's big decimal
981+
let value_256 = i256::from_i128(value);
982+
let micros_256 = value_256 * i256::from_i128(MICROS_PER_SECOND as i128);
983+
let ts = micros_256 / i256::from_i128(scale_factor);
984+
builder.append_value(ts.as_i128() as i64);
985+
}
986+
}
987+
988+
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
989+
}
990+
991+
pub(crate) fn cast_float_to_timestamp(
992+
array_ref: &ArrayRef,
993+
target_tz: &Option<Arc<str>>,
994+
eval_mode: EvalMode,
995+
) -> SparkResult<ArrayRef> {
996+
let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
997+
998+
match array_ref.data_type() {
999+
DataType::Float32 => {
1000+
cast_float_to_timestamp_impl!(array_ref, builder, Float32Type, eval_mode)
1001+
}
1002+
DataType::Float64 => {
1003+
cast_float_to_timestamp_impl!(array_ref, builder, Float64Type, eval_mode)
1004+
}
1005+
dt => {
1006+
return Err(SparkError::Internal(format!(
1007+
"Unsupported type for cast_float_to_timestamp: {:?}",
1008+
dt
1009+
)))
1010+
}
1011+
}
1012+
1013+
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
1014+
}
1015+
9161016
#[cfg(test)]
9171017
mod tests {
9181018
use super::*;
@@ -1100,4 +1200,113 @@ mod tests {
11001200
assert!(casted.is_null(8));
11011201
assert!(casted.is_null(9));
11021202
}
1203+
1204+
#[test]
fn test_cast_decimal_to_timestamp() {
    // Exercise an explicit UTC zone, a non-UTC zone, and no zone at all.
    let zones: [Option<Arc<str>>; 3] = [
        Some(Arc::from("UTC")),
        Some(Arc::from("America/Los_Angeles")),
        None,
    ];

    for zone in zones.iter() {
        // Scale 6: the unscaled integer already counts microseconds.
        let micro_scaled: ArrayRef = Arc::new(
            Decimal128Array::from(vec![
                Some(0_i128),
                Some(1_000_000_i128),
                Some(-1_000_000_i128),
                Some(1_500_000_i128),
                Some(123_456_789_i128),
                None,
            ])
            .with_precision_and_scale(18, 6)
            .unwrap(),
        );

        let casted = cast_decimal_to_timestamp(&micro_scaled, zone, 6).unwrap();
        let timestamps = casted.as_primitive::<TimestampMicrosecondType>();

        let expected_scale6: [i64; 5] = [0, 1_000_000, -1_000_000, 1_500_000, 123_456_789];
        for (idx, want) in expected_scale6.iter().enumerate() {
            assert_eq!(timestamps.value(idx), *want);
        }
        assert!(timestamps.is_null(5));
        assert_eq!(timestamps.timezone(), zone.as_deref());

        // Scale 2: 1.00, 1.50 and -2.50 seconds.
        let centi_scaled: ArrayRef = Arc::new(
            Decimal128Array::from(vec![Some(100_i128), Some(150_i128), Some(-250_i128)])
                .with_precision_and_scale(10, 2)
                .unwrap(),
        );

        let casted = cast_decimal_to_timestamp(&centi_scaled, zone, 2).unwrap();
        let timestamps = casted.as_primitive::<TimestampMicrosecondType>();

        let expected_scale2: [i64; 3] = [1_000_000, 1_500_000, -2_500_000];
        for (idx, want) in expected_scale2.iter().enumerate() {
            assert_eq!(timestamps.value(idx), *want);
        }
    }
}
1253+
1254+
#[test]
fn test_cast_float_to_timestamp() {
    let timezones: [Option<Arc<str>>; 3] = [
        Some(Arc::from("UTC")),
        Some(Arc::from("America/Los_Angeles")),
        None,
    ];
    let eval_modes = [EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try];

    // Representable values convert identically in every eval mode and time zone.
    for tz in &timezones {
        for eval_mode in &eval_modes {
            // Float64 tests
            let f64_array: ArrayRef = Arc::new(Float64Array::from(vec![
                Some(0.0),
                Some(1.0),
                Some(-1.0),
                Some(1.5),
                Some(0.000001),
                None,
            ]));

            let result = cast_float_to_timestamp(&f64_array, tz, *eval_mode).unwrap();
            let ts_array = result.as_primitive::<TimestampMicrosecondType>();

            assert_eq!(ts_array.value(0), 0);
            assert_eq!(ts_array.value(1), 1_000_000);
            assert_eq!(ts_array.value(2), -1_000_000);
            assert_eq!(ts_array.value(3), 1_500_000);
            assert_eq!(ts_array.value(4), 1);
            assert!(ts_array.is_null(5));
            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));

            // Float32 tests
            let f32_array: ArrayRef = Arc::new(Float32Array::from(vec![
                Some(0.0_f32),
                Some(1.0_f32),
                Some(-1.0_f32),
                None,
            ]));

            let result = cast_float_to_timestamp(&f32_array, tz, *eval_mode).unwrap();
            let ts_array = result.as_primitive::<TimestampMicrosecondType>();

            assert_eq!(ts_array.value(0), 0);
            assert_eq!(ts_array.value(1), 1_000_000);
            assert_eq!(ts_array.value(2), -1_000_000);
            assert!(ts_array.is_null(3));
        }
    }

    // Values with no timestamp representation: NaN, both infinities, and a finite double
    // whose microsecond count exceeds the i64 range.
    let tz: Option<Arc<str>> = Some(Arc::from("UTC"));
    let unrepresentable = [f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 1e300];

    // ANSI mode reports each of them as an error.
    for value in unrepresentable.iter() {
        let single: ArrayRef = Arc::new(Float64Array::from(vec![Some(*value)]));
        assert!(cast_float_to_timestamp(&single, &tz, EvalMode::Ansi).is_err());
    }

    // Every other mode maps them to null instead of failing.
    let all: ArrayRef = Arc::new(Float64Array::from(
        unrepresentable.iter().map(|v| Some(*v)).collect::<Vec<_>>(),
    ));
    for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
        let result = cast_float_to_timestamp(&all, &tz, *eval_mode).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();
        for idx in 0..unrepresentable.len() {
            assert!(ts_array.is_null(idx));
        }
    }
}
11031312
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
639639
castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
640640
}
641641

642+
// The test name follows the generator used below (precision 38, scale 18 per its name).
test("cast DecimalType(38,18) to TimestampType") {
  castTest(generateDecimalsPrecision38Scale18(), DataTypes.TimestampType)
}
645+
642646
// CAST from StringType
643647

644648
test("cast StringType to BooleanType") {

0 commit comments

Comments
 (0)