Skip to content

Commit b10e40b

Browse files
author
B Vadlamani
committed
refactor_boolean_cast_ops
1 parent 1d01b7d commit b10e40b

4 files changed

Lines changed: 144 additions & 105 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::datatypes::DataType;
19+
20+
pub fn can_cast_from_boolean(to_type: &DataType) -> bool {
21+
use DataType::*;
22+
matches!(
23+
to_type,
24+
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
25+
)
26+
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 7 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::conversion_funcs::boolean::can_cast_from_boolean;
19+
use crate::conversion_funcs::utils::spark_cast_postprocess;
1820
use crate::utils::array_with_timezone;
1921
use crate::{timezone, BinaryOutputStyle};
2022
use crate::{EvalMode, SparkError, SparkResult};
@@ -36,7 +38,7 @@ use arrow::{
3638
GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
3739
PrimitiveArray,
3840
},
39-
compute::{cast_with_options, take, unary, CastOptions},
41+
compute::{cast_with_options, take, CastOptions},
4042
datatypes::{
4143
is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
4244
Float64Type, Int64Type, TimestampMicrosecondType,
@@ -47,16 +49,10 @@ use arrow::{
4749
};
4850
use base64::prelude::*;
4951
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
50-
use datafusion::common::{
51-
cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
52-
ScalarValue,
53-
};
52+
use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
5453
use datafusion::physical_expr::PhysicalExpr;
5554
use datafusion::physical_plan::ColumnarValue;
56-
use num::{
57-
cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive,
58-
Zero,
59-
};
55+
use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, Zero};
6056
use regex::Regex;
6157
use std::str::FromStr;
6258
use std::{
@@ -69,8 +65,6 @@ use std::{
6965

7066
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
7167

72-
const MICROS_PER_SECOND: i64 = 1000000;
73-
7468
static CAST_OPTIONS: CastOptions = CastOptions {
7569
safe: true,
7670
format_options: FormatOptions::new()
@@ -187,7 +181,7 @@ pub fn cast_supported(
187181
}
188182

189183
match (from_type, to_type) {
190-
(Boolean, _) => can_cast_from_boolean(to_type, options),
184+
(Boolean, _) => can_cast_from_boolean(to_type),
191185
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
192186
if options.allow_cast_unsigned_ints =>
193187
{
@@ -302,11 +296,6 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
302296
}
303297
}
304298

305-
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
306-
use DataType::*;
307-
matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
308-
}
309-
310299
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
311300
use DataType::*;
312301
matches!(
@@ -1321,16 +1310,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
13211310
DataType::Null => {
13221311
matches!(to_type, DataType::List(_))
13231312
}
1324-
DataType::Boolean => matches!(
1325-
to_type,
1326-
DataType::Int8
1327-
| DataType::Int16
1328-
| DataType::Int32
1329-
| DataType::Int64
1330-
| DataType::Float32
1331-
| DataType::Float64
1332-
| DataType::Utf8
1333-
),
1313+
DataType::Boolean => can_cast_from_boolean(to_type),
13341314
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
13351315
matches!(
13361316
to_type,
@@ -2987,84 +2967,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
29872967
}
29882968
}
29892969

2990-
/// This takes for special casting cases of Spark. E.g., Timestamp to Long.
2991-
/// This function runs as a post process of the DataFusion cast(). By the time it arrives here,
2992-
/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify
2993-
/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in
2994-
/// expressions/cast.rs, so it can be still Dictionary.
2995-
fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
2996-
match (from_type, to_type) {
2997-
(DataType::Timestamp(_, _), DataType::Int64) => {
2998-
// See Spark's `Cast` expression
2999-
unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
3000-
}
3001-
(DataType::Dictionary(_, value_type), DataType::Int64)
3002-
if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
3003-
{
3004-
// See Spark's `Cast` expression
3005-
unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
3006-
}
3007-
(DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
3008-
(DataType::Dictionary(_, value_type), DataType::Utf8)
3009-
if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
3010-
{
3011-
remove_trailing_zeroes(array)
3012-
}
3013-
_ => array,
3014-
}
3015-
}
3016-
3017-
/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
3018-
fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
3019-
where
3020-
T: ArrowPrimitiveType,
3021-
F: Fn(T::Native) -> T::Native,
3022-
{
3023-
if let Some(d) = array.as_any_dictionary_opt() {
3024-
let new_values = unary_dyn::<F, T>(d.values(), op)?;
3025-
return Ok(Arc::new(d.with_values(Arc::new(new_values))));
3026-
}
3027-
3028-
match array.as_primitive_opt::<T>() {
3029-
Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
3030-
Ok(Arc::new(unary::<T, F, T>(
3031-
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
3032-
op,
3033-
)))
3034-
}
3035-
_ => Err(ArrowError::NotYetImplemented(format!(
3036-
"Cannot perform unary operation of type {} on array of type {}",
3037-
T::DATA_TYPE,
3038-
array.data_type()
3039-
))),
3040-
}
3041-
}
3042-
3043-
/// Remove any trailing zeroes in the string if they occur after in the fractional seconds,
3044-
/// to match Spark behavior
3045-
/// example:
3046-
/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
3047-
/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
3048-
/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
3049-
/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
3050-
/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
3051-
fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
3052-
let string_array = as_generic_string_array::<i32>(&array).unwrap();
3053-
let result = string_array
3054-
.iter()
3055-
.map(|s| s.map(trim_end))
3056-
.collect::<GenericStringArray<i32>>();
3057-
Arc::new(result) as ArrayRef
3058-
}
3059-
3060-
fn trim_end(s: &str) -> &str {
3061-
if s.rfind('.').is_some() {
3062-
s.trim_end_matches('0')
3063-
} else {
3064-
s
3065-
}
3066-
}
3067-
30682970
#[cfg(test)]
30692971
mod tests {
30702972
use arrow::array::StringArray;

native/spark-expr/src/conversion_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod boolean;
1819
pub mod cast;
20+
mod utils;
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray,
20+
};
21+
use arrow::compute::unary;
22+
use arrow::datatypes::{DataType, Int64Type};
23+
use arrow::error::ArrowError;
24+
use datafusion::common::cast::as_generic_string_array;
25+
use num::integer::div_floor;
26+
use std::sync::Arc;
27+
28+
/// Number of microseconds in one second.
const MICROS_PER_SECOND: i64 = 1_000_000;
29+
/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
30+
pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
31+
where
32+
T: ArrowPrimitiveType,
33+
F: Fn(T::Native) -> T::Native,
34+
{
35+
if let Some(d) = array.as_any_dictionary_opt() {
36+
let new_values = unary_dyn::<F, T>(d.values(), op)?;
37+
return Ok(Arc::new(d.with_values(Arc::new(new_values))));
38+
}
39+
40+
match array.as_primitive_opt::<T>() {
41+
Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
42+
Ok(Arc::new(unary::<T, F, T>(
43+
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
44+
op,
45+
)))
46+
}
47+
_ => Err(ArrowError::NotYetImplemented(format!(
48+
"Cannot perform unary operation of type {} on array of type {}",
49+
T::DATA_TYPE,
50+
array.data_type()
51+
))),
52+
}
53+
}
54+
55+
/// This takes for special casting cases of Spark. E.g., Timestamp to Long.
56+
/// This function runs as a post process of the DataFusion cast(). By the time it arrives here,
57+
/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify
58+
/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in
59+
/// expressions/cast.rs, so it can be still Dictionary.
60+
pub fn spark_cast_postprocess(
61+
array: ArrayRef,
62+
from_type: &DataType,
63+
to_type: &DataType,
64+
) -> ArrayRef {
65+
match (from_type, to_type) {
66+
(DataType::Timestamp(_, _), DataType::Int64) => {
67+
// See Spark's `Cast` expression
68+
unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
69+
}
70+
(DataType::Dictionary(_, value_type), DataType::Int64)
71+
if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
72+
{
73+
// See Spark's `Cast` expression
74+
unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
75+
}
76+
(DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
77+
(DataType::Dictionary(_, value_type), DataType::Utf8)
78+
if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
79+
{
80+
remove_trailing_zeroes(array)
81+
}
82+
_ => array,
83+
}
84+
}
85+
86+
/// Remove any trailing zeroes in the string if they occur after in the fractional seconds,
87+
/// to match Spark behavior
88+
/// example:
89+
/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
90+
/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
91+
/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
92+
/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
93+
/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
94+
fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
95+
let string_array = as_generic_string_array::<i32>(&array).unwrap();
96+
let result = string_array
97+
.iter()
98+
.map(|s| s.map(trim_end))
99+
.collect::<GenericStringArray<i32>>();
100+
Arc::new(result) as ArrayRef
101+
}
102+
103+
/// Strips trailing `0` characters from `s` when `s` contains a `.`.
///
/// If the trimmed text then ends with the `.` itself, that trailing `.` is
/// removed as well, so `"05:30:00.000"` becomes `"05:30:00"` rather than
/// `"05:30:00."`.
///
/// Strings without a `.` are returned untouched, so whole-second values such
/// as `"05:30:00"` keep their zeroes.
fn trim_end(s: &str) -> &str {
    if !s.contains('.') {
        return s;
    }
    let trimmed = s.trim_end_matches('0');
    // Trimming stops at the first non-zero character from the end, which may
    // be the separator; drop it so no dangling '.' is left behind.
    trimmed.strip_suffix('.').unwrap_or(trimmed)
}

0 commit comments

Comments
 (0)