Skip to content

Commit d9ffaee

Browse files
Apply refactoring to fix array.slt test failures
1 parent 1048074 commit d9ffaee

File tree

3 files changed

+113
-28
lines changed

3 files changed

+113
-28
lines changed

datafusion/functions-nested/src/range.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ impl Range {
297297
///
298298
/// # Arguments
299299
///
300-
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero) values.
300+
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values.
301301
///
302302
/// # Examples
303303
///
@@ -338,23 +338,10 @@ impl Range {
338338
usize::try_from(step.unsigned_abs()).map_err(|_| {
339339
not_impl_datafusion_err!("step {} can't fit into usize", step)
340340
})?;
341-
if start < stop {
342-
values.extend(
343-
gen_range_iter(
344-
start,
345-
stop,
346-
step < 0,
347-
self.include_upper_bound,
348-
)
341+
values.extend(
342+
gen_range_iter(start, stop, step < 0, self.include_upper_bound)
349343
.step_by(step_abs),
350-
)
351-
} else {
352-
values.extend(
353-
gen_range_iter(start, stop, true, self.include_upper_bound)
354-
.step_by(step_abs),
355-
)
356-
};
357-
344+
);
358345
offsets.push(values.len() as i32);
359346
valid.append_non_null();
360347
}

datafusion/spark/src/function/array/sequence.rs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use crate::function::functions_nested_utils::make_scalar_function;
19+
use arrow::array::{Array, Int64Builder};
1920
use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano};
21+
use datafusion_common::cast::as_int64_array;
2022
use datafusion_common::internal_err;
2123
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
2224
use datafusion_expr::{
@@ -146,12 +148,21 @@ impl ScalarUDFImpl for SparkSequence {
146148
if args.iter().any(|arg| arg.data_type().is_null()) {
147149
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
148150
}
151+
149152
match args[0].data_type() {
150-
DataType::Int64 => make_scalar_function(|args| {
151-
Range::generate_series().gen_range_inner(args)
152-
})(args),
153+
DataType::Int64 => {
154+
validate_int64_sequence_step(args)?;
155+
let optional_new_args = add_step_argument_if_not_exists(args)?;
156+
let new_args = match optional_new_args {
157+
Some(new_args) => &new_args.to_owned(),
158+
None => args,
159+
};
160+
make_scalar_function(|args| {
161+
Range::generate_series().gen_range_inner(args)
162+
})(new_args)
163+
}
153164
DataType::Date32 | DataType::Date64 => {
154-
let optional_new_args = add_interval_if_not_exists(args);
165+
let optional_new_args = add_interval_argument_if_not_exists(args);
155166
let new_args = match optional_new_args {
156167
Some(new_args) => &new_args.to_owned(),
157168
None => args,
@@ -161,7 +172,7 @@ impl ScalarUDFImpl for SparkSequence {
161172
)
162173
}
163174
DataType::Timestamp(_, _) => {
164-
let optional_new_args = add_interval_if_not_exists(args);
175+
let optional_new_args = add_interval_argument_if_not_exists(args);
165176
let new_args = match optional_new_args {
166177
Some(new_args) => &new_args.to_owned(),
167178
None => args,
@@ -180,6 +191,62 @@ impl ScalarUDFImpl for SparkSequence {
180191
}
181192
}
182193

194+
/// Validates explicit `step` for 3-argument integer `sequence` (Spark semantics).
195+
fn validate_int64_sequence_step(args: &[ColumnarValue]) -> Result<()> {
196+
if args.len() != 3 {
197+
return Ok(());
198+
}
199+
let arrays = ColumnarValue::values_to_arrays(args)?;
200+
let start = as_int64_array(&arrays[0])?;
201+
let stop = as_int64_array(&arrays[1])?;
202+
let step = as_int64_array(&arrays[2])?;
203+
for i in 0..start.len() {
204+
if start.is_null(i) || stop.is_null(i) || step.is_null(i) {
205+
continue;
206+
}
207+
let s = start.value(i);
208+
let e = stop.value(i);
209+
let st = step.value(i);
210+
if st == 0 {
211+
return exec_err!("Step cannot be 0 for sequence");
212+
}
213+
if s < e && st <= 0 {
214+
return exec_err!("When start < stop, step must be positive");
215+
}
216+
if s > e && st >= 0 {
217+
return exec_err!("When start > stop, step must be negative");
218+
}
219+
}
220+
Ok(())
221+
}
222+
223+
/// When only start and stop are given, Spark picks step `1` if start ≤ stop and `-1` if start > stop.
224+
fn add_step_argument_if_not_exists(args: &[ColumnarValue]) -> Result<Option<Vec<ColumnarValue>>> {
225+
if args.len() != 2 {
226+
return Ok(None);
227+
}
228+
let arrays = ColumnarValue::values_to_arrays(args)?;
229+
let start = as_int64_array(&arrays[0])?;
230+
let stop = as_int64_array(&arrays[1])?;
231+
let len = start.len();
232+
let mut step = Int64Builder::with_capacity(len);
233+
for i in 0..len {
234+
if start.is_null(i) || stop.is_null(i) {
235+
step.append_null();
236+
} else if start.value(i) > stop.value(i) {
237+
step.append_value(-1);
238+
} else {
239+
step.append_value(1);
240+
}
241+
}
242+
let step = step.finish();
243+
Ok(Some(vec![
244+
args[0].clone(),
245+
args[1].clone(),
246+
ColumnarValue::Array(Arc::new(step)),
247+
]))
248+
}
249+
183250
fn check_type(
184251
data_type: DataType,
185252
param_name: &str,
@@ -242,7 +309,7 @@ fn check_interval_type_by_first_type(
242309
}
243310
}
244311

245-
fn add_interval_if_not_exists(args: &[ColumnarValue]) -> Option<Vec<ColumnarValue>> {
312+
fn add_interval_argument_if_not_exists(args: &[ColumnarValue]) -> Option<Vec<ColumnarValue>> {
246313
if args.len() == 2 {
247314
let mut new_args = args.to_owned();
248315
new_args.push(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(

datafusion/sqllogictest/test_files/spark/array/sequence.slt

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@ SELECT sequence(0, 5, 1);
3939
----
4040
[0, 1, 2, 3, 4, 5]
4141

42-
query ?
43-
SELECT sequence(5, 0, 1);
44-
----
45-
[5, 4, 3, 2, 1, 0]
46-
4742
query ?
4843
SELECT sequence(-3::int, 3::int);
4944
----
@@ -290,3 +285,39 @@ SELECT sequence(1);
290285
DataFusion error: Error during planning: Execution error: Function 'sequence' user-defined coercion failed with: Execution error: num of input parameters should be 2 or 3. No function matches the given name and argument types 'sequence(Int64)'. You might need to add explicit type casts.
291286
Candidate functions:
292287
sequence(UserDefined)
288+
289+
290+
query ?
291+
SELECT sequence(2, 2);
292+
----
293+
[2]
294+
295+
query ?
296+
SELECT sequence(-2, -2);
297+
----
298+
[-2]
299+
300+
query error
301+
SELECT sequence(5, 0, 1);
302+
----
303+
DataFusion error: Execution error: When start > stop, step must be negative
304+
305+
306+
query error
307+
SELECT sequence(0, 5, -1);
308+
----
309+
DataFusion error: Execution error: When start < stop, step must be positive
310+
311+
312+
query error
313+
SELECT sequence(1, 5, 0);
314+
----
315+
DataFusion error: Execution error: Step cannot be 0 for sequence
316+
317+
318+
query error
319+
SELECT sequence(5, 1, 0);
320+
----
321+
DataFusion error: Execution error: Step cannot be 0 for sequence
322+
323+

0 commit comments

Comments
 (0)