Skip to content

Commit a871070

Browse files
Apply refactoring to fix array.slt test failures
1 parent a3e7e95 commit a871070

File tree

3 files changed

+117
-33
lines changed

3 files changed

+117
-33
lines changed

datafusion/functions-nested/src/range.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ impl Range {
289289
///
290290
/// # Arguments
291291
///
292-
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero) values.
292+
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step (the step value cannot be zero) values.
293293
///
294294
/// # Examples
295295
///
@@ -330,23 +330,10 @@ impl Range {
330330
usize::try_from(step.unsigned_abs()).map_err(|_| {
331331
not_impl_datafusion_err!("step {} can't fit into usize", step)
332332
})?;
333-
if start < stop {
334-
values.extend(
335-
gen_range_iter(
336-
start,
337-
stop,
338-
step < 0,
339-
self.include_upper_bound,
340-
)
333+
values.extend(
334+
gen_range_iter(start, stop, step < 0, self.include_upper_bound)
341335
.step_by(step_abs),
342-
)
343-
} else {
344-
values.extend(
345-
gen_range_iter(start, stop, true, self.include_upper_bound)
346-
.step_by(step_abs),
347-
)
348-
};
349-
336+
);
350337
offsets.push(values.len() as i32);
351338
valid.append_non_null();
352339
}

datafusion/spark/src/function/array/sequence.rs

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
// under the License.
1717

1818
use crate::function::functions_nested_utils::make_scalar_function;
19+
use arrow::array::{Array, Int64Builder};
1920
use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano};
21+
use datafusion_common::cast::as_int64_array;
2022
use datafusion_common::internal_err;
2123
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
2224
use datafusion_expr::{
2325
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2426
};
2527
use datafusion_functions_nested::range::Range;
26-
use std::any::Any;
2728
use std::sync::Arc;
2829

2930
/// Spark-compatible `sequence` expression.
@@ -48,10 +49,6 @@ impl SparkSequence {
4849
}
4950

5051
impl ScalarUDFImpl for SparkSequence {
51-
fn as_any(&self) -> &dyn Any {
52-
self
53-
}
54-
5552
fn name(&self) -> &str {
5653
"sequence"
5754
}
@@ -146,12 +143,21 @@ impl ScalarUDFImpl for SparkSequence {
146143
if args.iter().any(|arg| arg.data_type().is_null()) {
147144
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
148145
}
146+
149147
match args[0].data_type() {
150-
DataType::Int64 => make_scalar_function(|args| {
151-
Range::generate_series().gen_range_inner(args)
152-
})(args),
148+
DataType::Int64 => {
149+
validate_int64_sequence_step(args)?;
150+
let optional_new_args = add_step_argument_if_not_exists(args)?;
151+
let new_args = match optional_new_args {
152+
Some(new_args) => &new_args.to_owned(),
153+
None => args,
154+
};
155+
make_scalar_function(|args| {
156+
Range::generate_series().gen_range_inner(args)
157+
})(new_args)
158+
}
153159
DataType::Date32 | DataType::Date64 => {
154-
let optional_new_args = add_interval_if_not_exists(args);
160+
let optional_new_args = add_interval_argument_if_not_exists(args);
155161
let new_args = match optional_new_args {
156162
Some(new_args) => &new_args.to_owned(),
157163
None => args,
@@ -161,7 +167,7 @@ impl ScalarUDFImpl for SparkSequence {
161167
)
162168
}
163169
DataType::Timestamp(_, _) => {
164-
let optional_new_args = add_interval_if_not_exists(args);
170+
let optional_new_args = add_interval_argument_if_not_exists(args);
165171
let new_args = match optional_new_args {
166172
Some(new_args) => &new_args.to_owned(),
167173
None => args,
@@ -180,6 +186,64 @@ impl ScalarUDFImpl for SparkSequence {
180186
}
181187
}
182188

189+
/// Validates explicit `step` for 3-argument integer `sequence` (Spark semantics).
190+
fn validate_int64_sequence_step(args: &[ColumnarValue]) -> Result<()> {
191+
if args.len() != 3 {
192+
return Ok(());
193+
}
194+
let arrays = ColumnarValue::values_to_arrays(args)?;
195+
let start = as_int64_array(&arrays[0])?;
196+
let stop = as_int64_array(&arrays[1])?;
197+
let step = as_int64_array(&arrays[2])?;
198+
for i in 0..start.len() {
199+
if start.is_null(i) || stop.is_null(i) || step.is_null(i) {
200+
continue;
201+
}
202+
let s = start.value(i);
203+
let e = stop.value(i);
204+
let st = step.value(i);
205+
if st == 0 {
206+
return exec_err!("Step cannot be 0 for sequence");
207+
}
208+
if s < e && st <= 0 {
209+
return exec_err!("When start < stop, step must be positive");
210+
}
211+
if s > e && st >= 0 {
212+
return exec_err!("When start > stop, step must be negative");
213+
}
214+
}
215+
Ok(())
216+
}
217+
218+
/// When only start and stop are given, Spark picks step `1` if start ≤ stop and `-1` if start > stop.
219+
fn add_step_argument_if_not_exists(
220+
args: &[ColumnarValue],
221+
) -> Result<Option<Vec<ColumnarValue>>> {
222+
if args.len() != 2 {
223+
return Ok(None);
224+
}
225+
let arrays = ColumnarValue::values_to_arrays(args)?;
226+
let start = as_int64_array(&arrays[0])?;
227+
let stop = as_int64_array(&arrays[1])?;
228+
let len = start.len();
229+
let mut step = Int64Builder::with_capacity(len);
230+
for i in 0..len {
231+
if start.is_null(i) || stop.is_null(i) {
232+
step.append_null();
233+
} else if start.value(i) > stop.value(i) {
234+
step.append_value(-1);
235+
} else {
236+
step.append_value(1);
237+
}
238+
}
239+
let step = step.finish();
240+
Ok(Some(vec![
241+
args[0].clone(),
242+
args[1].clone(),
243+
ColumnarValue::Array(Arc::new(step)),
244+
]))
245+
}
246+
183247
fn check_type(
184248
data_type: DataType,
185249
param_name: &str,
@@ -242,7 +306,9 @@ fn check_interval_type_by_first_type(
242306
}
243307
}
244308

245-
fn add_interval_if_not_exists(args: &[ColumnarValue]) -> Option<Vec<ColumnarValue>> {
309+
fn add_interval_argument_if_not_exists(
310+
args: &[ColumnarValue],
311+
) -> Option<Vec<ColumnarValue>> {
246312
if args.len() == 2 {
247313
let mut new_args = args.to_owned();
248314
new_args.push(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(

datafusion/sqllogictest/test_files/spark/array/sequence.slt

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ SELECT sequence(0, 5, 1);
3939
----
4040
[0, 1, 2, 3, 4, 5]
4141

42-
query ?
43-
SELECT sequence(5, 0, 1);
44-
----
45-
[5, 4, 3, 2, 1, 0]
46-
4742
query ?
4843
SELECT sequence(-3::int, 3::int);
4944
----
@@ -290,3 +285,39 @@ SELECT sequence(1);
290285
DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: num of input parameters should be 2 or 3. No function matches the given name and argument types 'sequence(Int64)'. You might need to add explicit type casts.
291286
Candidate functions:
292287
sequence(UserDefined)
288+
289+
290+
query ?
291+
SELECT sequence(2, 2);
292+
----
293+
[2]
294+
295+
query ?
296+
SELECT sequence(-2, -2);
297+
----
298+
[-2]
299+
300+
query error
301+
SELECT sequence(5, 0, 1);
302+
----
303+
DataFusion error: Execution error: When start > stop, step must be negative
304+
305+
306+
query error
307+
SELECT sequence(0, 5, -1);
308+
----
309+
DataFusion error: Execution error: When start < stop, step must be positive
310+
311+
312+
query error
313+
SELECT sequence(1, 5, 0);
314+
----
315+
DataFusion error: Execution error: Step cannot be 0 for sequence
316+
317+
318+
query error
319+
SELECT sequence(5, 1, 0);
320+
----
321+
DataFusion error: Execution error: Step cannot be 0 for sequence
322+
323+

0 commit comments

Comments
 (0)