Skip to content

Commit 5c2b123

Browse files
authored
feat(spark): implement array_repeat function (#19702)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Part of #15914
- Closes #19701

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

## What changes are included in this PR?

Implementation of the Spark `array_repeat` function.

## Are these changes tested?

Yes.

## Are there any user-facing changes?

Yes.
1 parent 20870da commit 5c2b123

6 files changed

Lines changed: 333 additions & 114 deletions

File tree

datafusion/spark/src/function/array/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod repeat;
1819
pub mod shuffle;
1920
pub mod spark_array;
2021

@@ -24,6 +25,7 @@ use std::sync::Arc;
2425

2526
make_udf_function!(spark_array::SparkArray, array);
2627
make_udf_function!(shuffle::SparkShuffle, shuffle);
28+
make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
2729

2830
pub mod expr_fn {
2931
use datafusion_functions::export_functions;
@@ -34,8 +36,13 @@ pub mod expr_fn {
3436
"Returns a random permutation of the given array.",
3537
args
3638
));
39+
export_functions!((
40+
array_repeat,
41+
"returns an array containing element count times.",
42+
element count
43+
));
3744
}
3845

3946
pub fn functions() -> Vec<Arc<ScalarUDF>> {
40-
vec![array(), shuffle()]
47+
vec![array(), shuffle(), array_repeat()]
4148
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::datatypes::{DataType, Field};
19+
use datafusion_common::utils::take_function_args;
20+
use datafusion_common::{Result, ScalarValue, exec_err};
21+
use datafusion_expr::{
22+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23+
};
24+
use datafusion_functions_nested::repeat::ArrayRepeat;
25+
use std::any::Any;
26+
use std::sync::Arc;
27+
28+
use crate::function::null_utils::{
29+
NullMaskResolution, apply_null_mask, compute_null_mask,
30+
};
31+
32+
/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL.
33+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
34+
#[derive(Debug, PartialEq, Eq, Hash)]
35+
pub struct SparkArrayRepeat {
36+
signature: Signature,
37+
}
38+
39+
impl Default for SparkArrayRepeat {
40+
fn default() -> Self {
41+
Self::new()
42+
}
43+
}
44+
45+
impl SparkArrayRepeat {
46+
pub fn new() -> Self {
47+
Self {
48+
signature: Signature::user_defined(Volatility::Immutable),
49+
}
50+
}
51+
}
52+
53+
impl ScalarUDFImpl for SparkArrayRepeat {
54+
fn as_any(&self) -> &dyn Any {
55+
self
56+
}
57+
58+
fn name(&self) -> &str {
59+
"array_repeat"
60+
}
61+
62+
fn signature(&self) -> &Signature {
63+
&self.signature
64+
}
65+
66+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67+
Ok(DataType::List(Arc::new(Field::new_list_field(
68+
arg_types[0].clone(),
69+
true,
70+
))))
71+
}
72+
73+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
74+
spark_array_repeat(args)
75+
}
76+
77+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
78+
let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
79+
80+
// Coerce the second argument to Int64/UInt64 if it's a numeric type
81+
let second = match second_type {
82+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
83+
DataType::Int64
84+
}
85+
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
86+
DataType::UInt64
87+
}
88+
_ => return exec_err!("count must be an integer type"),
89+
};
90+
91+
Ok(vec![first_type.clone(), second])
92+
}
93+
}
94+
95+
/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL
96+
/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
97+
fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98+
let ScalarFunctionArgs {
99+
args: arg_values,
100+
arg_fields,
101+
number_rows,
102+
return_field,
103+
config_options,
104+
} = args;
105+
let return_type = return_field.data_type().clone();
106+
107+
// Step 1: Check for NULL mask in incoming args
108+
let null_mask = compute_null_mask(&arg_values, number_rows)?;
109+
110+
// If any argument is null then return NULL immediately
111+
if matches!(null_mask, NullMaskResolution::ReturnNull) {
112+
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
113+
}
114+
115+
// Step 2: Delegate to DataFusion's array_repeat
116+
let array_repeat_func = ArrayRepeat::new();
117+
let func_args = ScalarFunctionArgs {
118+
args: arg_values,
119+
arg_fields,
120+
number_rows,
121+
return_field,
122+
config_options,
123+
};
124+
let result = array_repeat_func.invoke_with_args(func_args)?;
125+
126+
// Step 3: Apply NULL mask to result
127+
apply_null_mask(result, null_mask, &return_type)
128+
}

datafusion/spark/src/function/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub mod lambda;
3333
pub mod map;
3434
pub mod math;
3535
pub mod misc;
36+
mod null_utils;
3637
pub mod predicate;
3738
pub mod string;
3839
pub mod r#struct;
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::Array;
19+
use arrow::buffer::NullBuffer;
20+
use arrow::datatypes::DataType;
21+
use datafusion_common::{Result, ScalarValue};
22+
use datafusion_expr::ColumnarValue;
23+
use std::sync::Arc;
24+
25+
pub(crate) enum NullMaskResolution {
26+
/// Return NULL as the result (e.g., scalar inputs with at least one NULL)
27+
ReturnNull,
28+
/// No null mask needed (e.g., all scalar inputs are non-NULL)
29+
NoMask,
30+
/// Null mask to apply for arrays
31+
Apply(NullBuffer),
32+
}
33+
34+
/// Compute NULL mask for the arguments using NullBuffer::union
35+
pub(crate) fn compute_null_mask(
36+
args: &[ColumnarValue],
37+
number_rows: usize,
38+
) -> Result<NullMaskResolution> {
39+
// Check if all arguments are scalars
40+
let all_scalars = args
41+
.iter()
42+
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
43+
44+
if all_scalars {
45+
// For scalars, check if any is NULL
46+
for arg in args {
47+
if let ColumnarValue::Scalar(scalar) = arg
48+
&& scalar.is_null()
49+
{
50+
return Ok(NullMaskResolution::ReturnNull);
51+
}
52+
}
53+
// No NULLs in scalars
54+
Ok(NullMaskResolution::NoMask)
55+
} else {
56+
// For arrays, compute NULL mask for each row using NullBuffer::union
57+
let array_len = args
58+
.iter()
59+
.find_map(|arg| match arg {
60+
ColumnarValue::Array(array) => Some(array.len()),
61+
_ => None,
62+
})
63+
.unwrap_or(number_rows);
64+
65+
// Convert all scalars to arrays for uniform processing
66+
let arrays: Result<Vec<_>> = args
67+
.iter()
68+
.map(|arg| match arg {
69+
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
70+
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
71+
})
72+
.collect();
73+
let arrays = arrays?;
74+
75+
// Use NullBuffer::union to combine all null buffers
76+
let combined_nulls = arrays
77+
.iter()
78+
.map(|arr| arr.nulls())
79+
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
80+
81+
match combined_nulls {
82+
Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
83+
None => Ok(NullMaskResolution::NoMask),
84+
}
85+
}
86+
}
87+
88+
/// Apply NULL mask to the result using NullBuffer::union
89+
pub(crate) fn apply_null_mask(
90+
result: ColumnarValue,
91+
null_mask: NullMaskResolution,
92+
return_type: &DataType,
93+
) -> Result<ColumnarValue> {
94+
match (result, null_mask) {
95+
// Scalar with ReturnNull mask means return NULL of the correct type
96+
(ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
97+
Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?))
98+
}
99+
// Scalar without mask, return as-is
100+
(scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
101+
// Array with NULL mask - use NullBuffer::union to combine nulls
102+
(ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
103+
// Combine the result's existing nulls with our computed null mask
104+
let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
105+
106+
// Create new array with combined nulls
107+
let new_array = array
108+
.into_data()
109+
.into_builder()
110+
.nulls(combined_nulls)
111+
.build()?;
112+
113+
Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
114+
new_array,
115+
))))
116+
}
117+
// Array without NULL mask, return as-is
118+
(array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
119+
// Edge cases that shouldn't happen in practice
120+
(scalar, _) => Ok(scalar),
121+
}
122+
}

0 commit comments

Comments
 (0)