Skip to content

Commit 744dcc7

Browse files
feat: Implement datafusion-spark sequence function
1 parent 538a201 commit 744dcc7

File tree

4 files changed

+554
-14
lines changed

4 files changed

+554
-14
lines changed

datafusion/functions-nested/src/range.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ impl Range {
203203
}
204204

205205
/// Generate `generate_series()` function which includes upper bound.
206-
fn generate_series() -> Self {
206+
pub fn generate_series() -> Self {
207207
Self {
208208
signature: Self::defined_signature(),
209209
include_upper_bound: true,
@@ -297,14 +297,14 @@ impl Range {
297297
///
298298
/// # Arguments
299299
///
300-
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values.
300+
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero) values.
301301
///
302302
/// # Examples
303303
///
304304
/// gen_range(3) => [0, 1, 2]
305305
/// gen_range(1, 4) => [1, 2, 3]
306306
/// gen_range(1, 7, 2) => [1, 3, 5]
307-
fn gen_range_inner(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
307+
pub fn gen_range_inner(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
308308
let (start_array, stop_array, step_array) = match args {
309309
[stop_array] => (None, as_int64_array(stop_array)?, None),
310310
[start_array, stop_array] => (
@@ -338,10 +338,23 @@ impl Range {
338338
usize::try_from(step.unsigned_abs()).map_err(|_| {
339339
not_impl_datafusion_err!("step {} can't fit into usize", step)
340340
})?;
341-
values.extend(
342-
gen_range_iter(start, stop, step < 0, self.include_upper_bound)
341+
if start < stop {
342+
values.extend(
343+
gen_range_iter(
344+
start,
345+
stop,
346+
step < 0,
347+
self.include_upper_bound,
348+
)
343349
.step_by(step_abs),
344-
);
350+
)
351+
} else {
352+
values.extend(
353+
gen_range_iter(start, stop, true, self.include_upper_bound)
354+
.step_by(step_abs),
355+
)
356+
};
357+
345358
offsets.push(values.len() as i32);
346359
valid.append_non_null();
347360
}
@@ -361,7 +374,7 @@ impl Range {
361374
Ok(arr)
362375
}
363376

364-
fn gen_range_date(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
377+
pub fn gen_range_date(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
365378
let [start, stop, step] = take_function_args(self.name(), args)?;
366379
let step = as_interval_mdn_array(step)?;
367380

@@ -425,7 +438,7 @@ impl Range {
425438
Ok(arr)
426439
}
427440

428-
fn gen_range_timestamp(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
441+
pub fn gen_range_timestamp(&self, args: &[ArrayRef]) -> Result<ArrayRef> {
429442
let [start, stop, step] = take_function_args(self.name(), args)?;
430443
let step = as_interval_mdn_array(step)?;
431444

datafusion/spark/src/function/array/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod array_contains;
1919
pub mod repeat;
20+
pub mod sequence;
2021
pub mod shuffle;
2122
pub mod slice;
2223
pub mod spark_array;
@@ -30,6 +31,7 @@ make_udf_function!(spark_array::SparkArray, array);
3031
make_udf_function!(shuffle::SparkShuffle, shuffle);
3132
make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
3233
make_udf_function!(slice::SparkSlice, slice);
34+
make_udf_function!(sequence::SparkSequence, sequence);
3335

3436
pub mod expr_fn {
3537
use datafusion_functions::export_functions;
@@ -55,6 +57,11 @@ pub mod expr_fn {
5557
"Returns a slice of the array from the start index with the given length.",
5658
array start length
5759
));
60+
export_functions!((
61+
sequence,
62+
"Returns a sequence of the array from the start index and end index.",
63+
start stop
64+
));
5865
}
5966

6067
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -63,6 +70,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
6370
array(),
6471
shuffle(),
6572
array_repeat(),
73+
sequence(),
6674
slice(),
6775
]
6876
}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::function::functions_nested_utils::make_scalar_function;
19+
use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano};
20+
use datafusion_common::internal_err;
21+
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
22+
use datafusion_expr::{
23+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24+
};
25+
use datafusion_functions_nested::range::Range;
26+
use std::any::Any;
27+
use std::sync::Arc;
28+
29+
/// Spark-compatible `sequence` expression.
/// <https://spark.apache.org/docs/latest/api/sql/index.html#sequence>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkSequence {
    // User-defined signature: argument validation and coercion happen in
    // `ScalarUDFImpl::coerce_types` rather than via a fixed type list.
    signature: Signature,
}
35+
36+
impl Default for SparkSequence {
37+
fn default() -> Self {
38+
Self::new()
39+
}
40+
}
41+
42+
impl SparkSequence {
43+
pub fn new() -> Self {
44+
Self {
45+
signature: Signature::user_defined(Volatility::Immutable),
46+
}
47+
}
48+
}
49+
50+
impl ScalarUDFImpl for SparkSequence {
51+
fn as_any(&self) -> &dyn Any {
52+
self
53+
}
54+
55+
fn name(&self) -> &str {
56+
"sequence"
57+
}
58+
59+
fn signature(&self) -> &Signature {
60+
&self.signature
61+
}
62+
63+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
64+
internal_err!("return_field_from_args should be used instead")
65+
}
66+
67+
fn return_field_from_args(
68+
&self,
69+
args: datafusion_expr::ReturnFieldArgs,
70+
) -> Result<FieldRef> {
71+
let return_type = if args.arg_fields[0].data_type().is_null()
72+
|| args.arg_fields[1].data_type().is_null()
73+
{
74+
DataType::Null
75+
} else {
76+
DataType::List(Arc::new(Field::new_list_field(
77+
args.arg_fields[0].data_type().clone(),
78+
true,
79+
)))
80+
};
81+
82+
Ok(Arc::new(Field::new(
83+
"this_field_name_is_irrelevant",
84+
return_type,
85+
true,
86+
)))
87+
}
88+
89+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
90+
match arg_types.len() {
91+
2 => {
92+
let first_data_type =
93+
check_type(arg_types[0].clone(), "first".to_string().as_str())?;
94+
let second_data_type =
95+
check_type(arg_types[1].clone(), "second".to_string().as_str())?;
96+
97+
if !first_data_type.is_null()
98+
&& !second_data_type.is_null()
99+
&& (first_data_type != second_data_type)
100+
{
101+
return exec_err!(
102+
"first({first_data_type}) and second({second_data_type}) input types should be same"
103+
);
104+
}
105+
106+
Ok(vec![first_data_type, second_data_type])
107+
}
108+
3 => {
109+
let first_data_type =
110+
check_type(arg_types[0].clone(), "first".to_string().as_str())?;
111+
let second_data_type =
112+
check_type(arg_types[1].clone(), "second".to_string().as_str())?;
113+
let third_data_type = check_interval_type(
114+
arg_types[2].clone(),
115+
"third".to_string().as_str(),
116+
)?;
117+
118+
if !first_data_type.is_null() && !second_data_type.is_null() {
119+
if first_data_type != second_data_type {
120+
return exec_err!(
121+
"first({first_data_type}) and second({second_data_type}) input types should be same"
122+
);
123+
}
124+
125+
if !check_interval_type_by_first_type(
126+
&first_data_type,
127+
&third_data_type,
128+
) {
129+
return exec_err!(
130+
"interval type should be integer for integer input or time based"
131+
);
132+
}
133+
}
134+
135+
Ok(vec![first_data_type, second_data_type, third_data_type])
136+
}
137+
_ => {
138+
exec_err!("num of input parameters should be 2 or 3")
139+
}
140+
}
141+
}
142+
143+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
144+
let args = &args.args;
145+
146+
if args.iter().any(|arg| arg.data_type().is_null()) {
147+
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
148+
}
149+
match args[0].data_type() {
150+
DataType::Int64 => make_scalar_function(|args| {
151+
Range::generate_series().gen_range_inner(args)
152+
})(args),
153+
DataType::Date32 | DataType::Date64 => {
154+
let optional_new_args = add_interval_if_not_exists(args);
155+
let new_args = match optional_new_args {
156+
Some(new_args) => &new_args.to_owned(),
157+
None => args,
158+
};
159+
make_scalar_function(|args| Range::generate_series().gen_range_date(args))(
160+
new_args,
161+
)
162+
}
163+
DataType::Timestamp(_, _) => {
164+
let optional_new_args = add_interval_if_not_exists(args);
165+
let new_args = match optional_new_args {
166+
Some(new_args) => &new_args.to_owned(),
167+
None => args,
168+
};
169+
make_scalar_function(|args| {
170+
Range::generate_series().gen_range_timestamp(args)
171+
})(new_args)
172+
}
173+
dt => {
174+
internal_err!(
175+
"Signature failed to guard unknown input type for {}: {dt}",
176+
self.name()
177+
)
178+
}
179+
}
180+
}
181+
}
182+
183+
fn check_type(
184+
data_type: DataType,
185+
param_name: &str,
186+
) -> Result<DataType, DataFusionError> {
187+
let result_type = match data_type {
188+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
189+
DataType::Int64
190+
}
191+
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
192+
DataType::UInt64
193+
}
194+
DataType::Date32
195+
| DataType::Date64
196+
| DataType::Timestamp(_, _)
197+
| DataType::Null => data_type,
198+
_ => {
199+
return exec_err!(
200+
"{} parameter type must be one of integer, date or timestamp type but found: {}",
201+
param_name,
202+
data_type
203+
);
204+
}
205+
};
206+
Ok(result_type)
207+
}
208+
209+
fn check_interval_type(
210+
data_type: DataType,
211+
param_name: &str,
212+
) -> Result<DataType, DataFusionError> {
213+
let result_type = match data_type {
214+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
215+
DataType::Int64
216+
}
217+
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
218+
DataType::UInt64
219+
}
220+
DataType::Interval(_) => data_type,
221+
_ => {
222+
return exec_err!(
223+
"{} parameter type must be one of integer or interval type but found: {}",
224+
param_name,
225+
data_type
226+
);
227+
}
228+
};
229+
Ok(result_type)
230+
}
231+
232+
fn check_interval_type_by_first_type(
233+
first_data_type: &DataType,
234+
third_data_type: &DataType,
235+
) -> bool {
236+
match first_data_type {
237+
DataType::Int64 | DataType::UInt64 => first_data_type == third_data_type,
238+
DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) => {
239+
matches!(third_data_type, DataType::Interval(_))
240+
}
241+
_ => false,
242+
}
243+
}
244+
245+
fn add_interval_if_not_exists(args: &[ColumnarValue]) -> Option<Vec<ColumnarValue>> {
246+
if args.len() == 2 {
247+
let mut new_args = args.to_owned();
248+
new_args.push(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
249+
Some(IntervalMonthDayNano {
250+
months: 0,
251+
days: 1,
252+
nanoseconds: 0,
253+
}),
254+
)));
255+
Some(new_args)
256+
} else {
257+
None
258+
}
259+
}

0 commit comments

Comments
 (0)