diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs
index 346fb8d0f3f95..b902e543fbdf9 100644
--- a/datafusion/functions-nested/src/range.rs
+++ b/datafusion/functions-nested/src/range.rs
@@ -202,7 +202,7 @@ impl Range {
     }
 
     /// Generate `generate_series()` function which includes upper bound.
-    fn generate_series() -> Self {
+    pub fn generate_series() -> Self {
         Self {
             signature: Self::defined_signature(),
             include_upper_bound: true,
@@ -296,7 +296,7 @@ impl Range {
     /// gen_range(3) => [0, 1, 2]
     /// gen_range(1, 4) => [1, 2, 3]
    /// gen_range(1, 7, 2) => [1, 3, 5]
-    fn gen_range_inner(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
+    pub fn gen_range_inner(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
         let (start_array, stop_array, step_array) = match args {
             [stop_array] => (None, as_int64_array(stop_array)?, None),
             [start_array, stop_array] => (
@@ -353,7 +353,7 @@ impl Range {
         Ok(arr)
     }
 
-    fn gen_range_date(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
+    pub fn gen_range_date(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
         let [start, stop, step] = take_function_args(self.name(), args)?;
 
         let step = as_interval_mdn_array(step)?;
@@ -417,7 +417,7 @@ impl Range {
         Ok(arr)
     }
 
-    fn gen_range_timestamp(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
+    pub fn gen_range_timestamp(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
         let [start, stop, step] = take_function_args(self.name(), args)?;
 
         let step = as_interval_mdn_array(step)?;
diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs
index 6c16e05361641..020d1ecce1aa7 100644
--- a/datafusion/spark/src/function/array/mod.rs
+++ b/datafusion/spark/src/function/array/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod array_contains;
 pub mod repeat;
+pub mod sequence;
 pub mod shuffle;
 pub mod slice;
 pub mod spark_array;
@@ -30,6 +31,7 @@ make_udf_function!(spark_array::SparkArray, array);
 make_udf_function!(shuffle::SparkShuffle, shuffle);
 make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
 make_udf_function!(slice::SparkSlice, slice);
+make_udf_function!(sequence::SparkSequence, sequence);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -55,6 +57,11 @@ pub mod expr_fn {
         "Returns a slice of the array from the start index with the given length.",
         array start length
     ));
+    export_functions!((
+        sequence,
+        "Returns an array of elements from start to stop (inclusive), incrementing by step.",
+        start stop step
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -63,6 +70,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         array(),
         shuffle(),
         array_repeat(),
+        sequence(),
         slice(),
     ]
 }
diff --git a/datafusion/spark/src/function/array/sequence.rs b/datafusion/spark/src/function/array/sequence.rs
new file mode 100644
index 0000000000000..e0c338640276e
--- /dev/null
+++ b/datafusion/spark/src/function/array/sequence.rs
@@ -0,0 +1,325 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use crate::function::functions_nested_utils::make_scalar_function; +use arrow::array::{Array, Int64Builder}; +use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::internal_err; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::range::Range; +use std::sync::Arc; + +/// Spark-compatible `sequence` expression. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSequence { + signature: Signature, +} + +impl Default for SparkSequence { + fn default() -> Self { + Self::new() + } +} + +impl SparkSequence { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkSequence { + fn name(&self) -> &str { + "sequence" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let return_type = if args.arg_fields[0].data_type().is_null() + || args.arg_fields[1].data_type().is_null() + { + DataType::Null + } else { + DataType::List(Arc::new(Field::new_list_field( + args.arg_fields[0].data_type().clone(), + true, + ))) + }; + + Ok(Arc::new(Field::new( + "this_field_name_is_irrelevant", + return_type, + true, + ))) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + match arg_types.len() { + 2 => { + let first_data_type = + check_type(arg_types[0].clone(), "first".to_string().as_str())?; + let second_data_type = + check_type(arg_types[1].clone(), "second".to_string().as_str())?; + + if !first_data_type.is_null() + && !second_data_type.is_null() + && (first_data_type != second_data_type) + { + 
return exec_err!( + "first({first_data_type}) and second({second_data_type}) input types should be same" + ); + } + + Ok(vec![first_data_type, second_data_type]) + } + 3 => { + let first_data_type = + check_type(arg_types[0].clone(), "first".to_string().as_str())?; + let second_data_type = + check_type(arg_types[1].clone(), "second".to_string().as_str())?; + let third_data_type = check_interval_type( + arg_types[2].clone(), + "third".to_string().as_str(), + )?; + + if !first_data_type.is_null() && !second_data_type.is_null() { + if first_data_type != second_data_type { + return exec_err!( + "first({first_data_type}) and second({second_data_type}) input types should be same" + ); + } + + if !check_interval_type_by_first_type( + &first_data_type, + &third_data_type, + ) { + return exec_err!( + "interval type should be integer for integer input or time based" + ); + } + } + + Ok(vec![first_data_type, second_data_type, third_data_type]) + } + _ => { + exec_err!("num of input parameters should be 2 or 3") + } + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = &args.args; + + if args.iter().any(|arg| arg.data_type().is_null()) { + return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + } + + match args[0].data_type() { + DataType::Int64 => { + validate_int64_sequence_step(args)?; + let optional_new_args = add_step_argument_if_not_exists(args)?; + let new_args = match optional_new_args { + Some(new_args) => &new_args.to_owned(), + None => args, + }; + make_scalar_function(|args| { + Range::generate_series().gen_range_inner(args) + })(new_args) + } + DataType::Date32 | DataType::Date64 => { + let optional_new_args = add_interval_argument_if_not_exists(args); + let new_args = match optional_new_args { + Some(new_args) => &new_args.to_owned(), + None => args, + }; + make_scalar_function(|args| Range::generate_series().gen_range_date(args))( + new_args, + ) + } + DataType::Timestamp(_, _) => { + let optional_new_args = 
add_interval_argument_if_not_exists(args); + let new_args = match optional_new_args { + Some(new_args) => &new_args.to_owned(), + None => args, + }; + make_scalar_function(|args| { + Range::generate_series().gen_range_timestamp(args) + })(new_args) + } + dt => { + internal_err!( + "Signature failed to guard unknown input type for {}: {dt}", + self.name() + ) + } + } + } +} + +/// Validates explicit `step` for 3-argument integer `sequence` (Spark semantics). +fn validate_int64_sequence_step(args: &[ColumnarValue]) -> Result<()> { + if args.len() != 3 { + return Ok(()); + } + let arrays = ColumnarValue::values_to_arrays(args)?; + let start = as_int64_array(&arrays[0])?; + let stop = as_int64_array(&arrays[1])?; + let step = as_int64_array(&arrays[2])?; + for i in 0..start.len() { + if start.is_null(i) || stop.is_null(i) || step.is_null(i) { + continue; + } + let s = start.value(i); + let e = stop.value(i); + let st = step.value(i); + if st == 0 { + return exec_err!("Step cannot be 0 for sequence"); + } + if s < e && st <= 0 { + return exec_err!("When start < stop, step must be positive"); + } + if s > e && st >= 0 { + return exec_err!("When start > stop, step must be negative"); + } + } + Ok(()) +} + +/// When only start and stop are given, Spark picks step `1` if start ≤ stop and `-1` if start > stop. 
+fn add_step_argument_if_not_exists( + args: &[ColumnarValue], +) -> Result>> { + if args.len() != 2 { + return Ok(None); + } + let arrays = ColumnarValue::values_to_arrays(args)?; + let start = as_int64_array(&arrays[0])?; + let stop = as_int64_array(&arrays[1])?; + let len = start.len(); + let mut step = Int64Builder::with_capacity(len); + for i in 0..len { + if start.is_null(i) || stop.is_null(i) { + step.append_null(); + } else if start.value(i) > stop.value(i) { + step.append_value(-1); + } else { + step.append_value(1); + } + } + let step = step.finish(); + Ok(Some(vec![ + args[0].clone(), + args[1].clone(), + ColumnarValue::Array(Arc::new(step)), + ])) +} + +fn check_type( + data_type: DataType, + param_name: &str, +) -> Result { + let result_type = match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + | DataType::Null => data_type, + _ => { + return exec_err!( + "{} parameter type must be one of integer, date or timestamp type but found: {}", + param_name, + data_type + ); + } + }; + Ok(result_type) +} + +fn check_interval_type( + data_type: DataType, + param_name: &str, +) -> Result { + let result_type = match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + DataType::Interval(_) => data_type, + _ => { + return exec_err!( + "{} parameter type must be one of integer or interval type but found: {}", + param_name, + data_type + ); + } + }; + Ok(result_type) +} + +fn check_interval_type_by_first_type( + first_data_type: &DataType, + third_data_type: &DataType, +) -> bool { + match first_data_type { + DataType::Int64 | DataType::UInt64 => first_data_type == 
third_data_type, + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) => { + matches!(third_data_type, DataType::Interval(_)) + } + _ => false, + } +} + +fn add_interval_argument_if_not_exists( + args: &[ColumnarValue], +) -> Option> { + if args.len() == 2 { + let mut new_args = args.to_owned(); + new_args.push(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano( + Some(IntervalMonthDayNano { + months: 0, + days: 1, + nanoseconds: 0, + }), + ))); + Some(new_args) + } else { + None + } +} diff --git a/datafusion/sqllogictest/test_files/spark/array/sequence.slt b/datafusion/sqllogictest/test_files/spark/array/sequence.slt index bb4aa06bfd257..7bd6420e29311 100644 --- a/datafusion/sqllogictest/test_files/spark/array/sequence.slt +++ b/datafusion/sqllogictest/test_files/spark/array/sequence.slt @@ -23,10 +23,201 @@ ## Original Query: SELECT sequence(1, 5); ## PySpark 3.5.5 Result: {'sequence(1, 5)': [1, 2, 3, 4, 5], 'typeof(sequence(1, 5))': 'array', 'typeof(1)': 'int', 'typeof(5)': 'int'} -#query -#SELECT sequence(1::int, 5::int); -## Original Query: SELECT sequence(5, 1); -## PySpark 3.5.5 Result: {'sequence(5, 1)': [5, 4, 3, 2, 1], 'typeof(sequence(5, 1))': 'array', 'typeof(5)': 'int', 'typeof(1)': 'int'} -#query -#SELECT sequence(5::int, 1::int); +query ? +SELECT sequence(1::int, 3::int); +---- +[1, 2, 3] + +query ? +SELECT sequence(1, 6, 2); +---- +[1, 3, 5] + +query ? +SELECT sequence(0, 5, 1); +---- +[0, 1, 2, 3, 4, 5] + +query ? +SELECT sequence(-3::int, 3::int); +---- +[-3, -2, -1, 0, 1, 2, 3] + +query ? +SELECT sequence(5::int, 1::int); +---- +[5, 4, 3, 2, 1] + +query ? +SELECT sequence(5::int, -3::int); +---- +[5, 4, 3, 2, 1, 0, -1, -2, -3] + +query ? +SELECT sequence(5, -3); +---- +[5, 4, 3, 2, 1, 0, -1, -2, -3] + +query ? +SELECT sequence(5, -3, -2); +---- +[5, 3, 1, -1, -3] + +query ? +SELECT sequence(5, 1, -1); +---- +[5, 4, 3, 2, 1] + +query ? +SELECT sequence(2, 2); +---- +[2] + +query ? +SELECT sequence(-2, -2); +---- +[-2] + +query ? 
+SELECT sequence(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '-1' DAY) +---- +[2023-01-03T00:00:00, 2023-01-02T00:00:00, 2023-01-01T00:00:00] + +query ? +SELECT sequence(DATE '2018-01-01', DATE '2018-03-01', INTERVAL 1 MONTH); +---- +[2018-01-01, 2018-02-01, 2018-03-01] + +# Basic timestamp sequence with 1 day interval +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-04T00:00:00', INTERVAL '1' DAY) +---- +[2023-01-01T00:00:00, 2023-01-02T00:00:00, 2023-01-03T00:00:00, 2023-01-04T00:00:00] + +# Timestamp sequence with hour interval +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T03:00:00', INTERVAL '1' HOUR) +---- +[2023-01-01T00:00:00, 2023-01-01T01:00:00, 2023-01-01T02:00:00, 2023-01-01T03:00:00] + +# Timestamp sequence with month interval +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-04-01T00:00:00', INTERVAL '1' MONTH) +---- +[2023-01-01T00:00:00, 2023-02-01T00:00:00, 2023-03-01T00:00:00, 2023-04-01T00:00:00] + +# Timestamp sequence (includes end) +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- +[2023-01-01T00:00:00, 2023-01-02T00:00:00, 2023-01-03T00:00:00] + +# Timestamp sequence with timezone +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00+00:00', TIMESTAMP '2023-01-03T00:00:00+00:00', INTERVAL '1' DAY) +---- +[2023-01-01T00:00:00, 2023-01-02T00:00:00, 2023-01-03T00:00:00] + +# Negative timestamp sequence (going backwards) +query ? +SELECT sequence(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '-1' DAY) +---- +[2023-01-03T00:00:00, 2023-01-02T00:00:00, 2023-01-01T00:00:00] + +query ? +SELECT sequence(DATE '2018-01-01', DATE '2018-01-04'); +---- +[2018-01-01, 2018-01-02, 2018-01-03, 2018-01-04] + +query ? 
+SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00'); +---- +[2023-01-01T00:00:00, 2023-01-02T00:00:00, 2023-01-03T00:00:00] + +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- +[2023-01-01T00:00:00] + +query ? +SELECT sequence(DATE '1992-01-01', DATE '1992-01-03') +---- +[1992-01-01, 1992-01-02, 1992-01-03] + +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T00:00:00') +---- +[2023-01-01T00:00:00] + +query ? +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- +[2023-01-01T00:00:00] + +query error +SELECT sequence(DATE '1992-01-01', DATE '1992-01-03', INTERVAL '6' HOUR) +---- +DataFusion error: Execution error: Cannot generate date range less than 1 day. + + +# NULL VALUES TESTS +query ? +SELECT sequence(NULL, NULL); +---- +NULL + +query ? +SELECT sequence(1, NULL); +---- +NULL + +query ? +SELECT sequence(NULL, 2); +---- +NULL + +query ? +SELECT sequence(NULL, TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- +NULL + +query ? +SELECT sequence(NULL::TIMESTAMP, TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- +NULL + +query ? 
+SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', NULL::TIMESTAMP, INTERVAL '1' DAY) +---- +NULL + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: interval type should be integer for integer input or time based +SELECT sequence(1, 6, INTERVAL 1 MONTH); + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: interval type should be integer for integer input or time based +SELECT sequence(DATE '2023-01-01', DATE '2023-03-01', 1); + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: interval type should be integer for integer input or time based +SELECT sequence(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-03-01T00:00:00', 1); + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: second parameter type must be one of integer, date or timestamp type but found: Utf8 +SELECT sequence(1, 'abc'); + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: first parameter type must be one of integer, date or timestamp type but found: Utf8 +SELECT sequence('abc', 2); + +query error DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: num of input parameters should be 2 or 3 +SELECT sequence(1); + +query error DataFusion error: Execution error: When start > stop, step must be negative +SELECT sequence(5, 0, 1); + +query error DataFusion error: Execution error: When start < stop, step must be positive +SELECT sequence(0, 5, -1); + +query error DataFusion error: Execution error: Step cannot be 0 for sequence +SELECT sequence(1, 5, 0); + +query error 
DataFusion error: Execution error: Step cannot be 0 for sequence +SELECT sequence(5, 1, 0);