Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions native/spark-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
[[bench]]
name = "cast_from_boolean"
harness = false

[[bench]]
name = "cast_non_int_numeric_timestamp"
harness = false
143 changes: 143 additions & 0 deletions native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::builder::{BooleanBuilder, Decimal128Builder, Float32Builder, Float64Builder};
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
use std::sync::Arc;

const BATCH_SIZE: usize = 8192;

fn criterion_benchmark(c: &mut Criterion) {
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));

let mut group = c.benchmark_group("cast_non_int_numeric_to_timestamp");

// Float32 -> Timestamp
let batch_f32 = create_float32_batch();
let expr_f32 = Arc::new(Column::new("a", 0));
let cast_f32_to_ts = Cast::new(expr_f32, timestamp_type.clone(), spark_cast_options.clone());
group.bench_function("cast_f32_to_timestamp", |b| {
b.iter(|| cast_f32_to_ts.evaluate(&batch_f32).unwrap());
});

// Float64 -> Timestamp
let batch_f64 = create_float64_batch();
let expr_f64 = Arc::new(Column::new("a", 0));
let cast_f64_to_ts = Cast::new(expr_f64, timestamp_type.clone(), spark_cast_options.clone());
group.bench_function("cast_f64_to_timestamp", |b| {
b.iter(|| cast_f64_to_ts.evaluate(&batch_f64).unwrap());
});

// Boolean -> Timestamp
let batch_bool = create_boolean_batch();
let expr_bool = Arc::new(Column::new("a", 0));
let cast_bool_to_ts = Cast::new(
expr_bool,
timestamp_type.clone(),
spark_cast_options.clone(),
);
group.bench_function("cast_bool_to_timestamp", |b| {
b.iter(|| cast_bool_to_ts.evaluate(&batch_bool).unwrap());
});

// Decimal128 -> Timestamp
let batch_decimal = create_decimal128_batch();
let expr_decimal = Arc::new(Column::new("a", 0));
let cast_decimal_to_ts = Cast::new(
expr_decimal,
timestamp_type.clone(),
spark_cast_options.clone(),
);
group.bench_function("cast_decimal_to_timestamp", |b| {
b.iter(|| cast_decimal_to_ts.evaluate(&batch_decimal).unwrap());
});

group.finish();
}

fn create_float32_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
let mut b = Float32Builder::with_capacity(BATCH_SIZE);
for i in 0..BATCH_SIZE {
if i % 10 == 0 {
b.append_null();
} else {
b.append_value(rand::random::<f32>());
}
}
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
}

fn create_float64_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
let mut b = Float64Builder::with_capacity(BATCH_SIZE);
for i in 0..BATCH_SIZE {
if i % 10 == 0 {
b.append_null();
} else {
b.append_value(rand::random::<f64>());
}
}
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
}

fn create_boolean_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]));
let mut b = BooleanBuilder::with_capacity(BATCH_SIZE);
for i in 0..BATCH_SIZE {
if i % 10 == 0 {
b.append_null();
} else {
b.append_value(rand::random::<bool>());
}
}
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
}

fn create_decimal128_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"a",
DataType::Decimal128(18, 6),
true,
)]));
let mut b = Decimal128Builder::with_capacity(BATCH_SIZE);
for i in 0..BATCH_SIZE {
if i % 10 == 0 {
b.append_null();
} else {
b.append_value(i as i128 * 1_000_000);
}
}
let array = b.finish().with_precision_and_scale(18, 6).unwrap();
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

fn config() -> Criterion {
Criterion::default()
}

criterion_group! {
name = benches;
config = config();
targets = criterion_benchmark
}
criterion_main!(benches);
45 changes: 43 additions & 2 deletions native/spark-expr/src/conversion_funcs/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::SparkResult;
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array, TimestampMicrosecondBuilder};
use arrow::datatypes::DataType;
use std::sync::Arc;

Expand All @@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
)
}

// only DF incompatible boolean cast
pub fn cast_boolean_to_decimal(
array: &ArrayRef,
precision: u8,
Expand All @@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal(
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
}

pub(crate) fn cast_boolean_to_timestamp(
array_ref: &ArrayRef,
target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
let bool_array = array_ref.as_boolean();
let mut builder = TimestampMicrosecondBuilder::with_capacity(bool_array.len());

for i in 0..bool_array.len() {
if bool_array.is_null(i) {
builder.append_null();
} else {
let micros = if bool_array.value(i) { 1 } else { 0 };
builder.append_value(micros);
}
}

Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -53,6 +71,7 @@ mod tests {
Int64Array, Int8Array, StringArray,
};
use arrow::datatypes::DataType::Decimal128;
use arrow::datatypes::TimestampMicrosecondType;
use std::sync::Arc;

fn test_input_bool_array() -> ArrayRef {
Expand Down Expand Up @@ -193,4 +212,26 @@ mod tests {
assert_eq!(arr.value(1), expected_arr.value(1));
assert!(arr.is_null(2));
}

#[test]
fn test_cast_boolean_to_timestamp() {
let timezones: [Option<Arc<str>>; 3] = [
Some(Arc::from("UTC")),
Some(Arc::from("America/Los_Angeles")),
None,
];

for tz in &timezones {
let bool_array: ArrayRef =
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));

let result = cast_boolean_to_timestamp(&bool_array, tz).unwrap();
let ts_array = result.as_primitive::<TimestampMicrosecondType>();

assert_eq!(ts_array.value(0), 1); // true -> 1 microsecond
assert_eq!(ts_array.value(1), 0); // false -> 0 (epoch)
assert!(ts_array.is_null(2));
assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
}
}
}
16 changes: 10 additions & 6 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
// under the License.

use crate::conversion_funcs::boolean::{
cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
cast_boolean_to_decimal, cast_boolean_to_timestamp, is_df_cast_from_bool_spark_compatible,
};
use crate::conversion_funcs::numeric::{
cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128,
cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
is_df_cast_from_float_spark_compatible, is_df_cast_from_int_spark_compatible,
spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
cast_decimal_to_timestamp, cast_float32_to_decimal128, cast_float64_to_decimal128,
cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
is_df_cast_from_decimal_spark_compatible, is_df_cast_from_float_spark_compatible,
is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, spark_cast_int_to_int,
spark_cast_nonintegral_numeric_to_integral,
};
use crate::conversion_funcs::string::{
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
Expand Down Expand Up @@ -384,6 +385,9 @@ pub(crate) fn cast_array(
cast_boolean_to_decimal(&array, *precision, *scale)
}
(Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
(Float32 | Float64, Timestamp(_, tz)) => cast_float_to_timestamp(&array, tz, eval_mode),
(Boolean, Timestamp(_, tz)) => cast_boolean_to_timestamp(&array, tz),
(Decimal128(_, scale), Timestamp(_, tz)) => cast_decimal_to_timestamp(&array, tz, *scale),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(&from_type, to_type) =>
{
Expand Down
Loading