Skip to content

Commit ad7c57a

Browse files
authored
perf: Optimize NULL handling in lcm, gcd (#21468)
## Which issue does this PR close? - Closes #21467. ## Rationale for this change This PR implements three distinct optimizations: 1. `lcm` was computing the result NULL buffer iteratively. This is relatively slow. Switching to Arrow's `try_binary` kernel makes the implementation more concise and also improves performance by computing the result NULL buffer via the bitwise union of the input NULL buffers. 2. The `gcd` scalar arg path was doing similarly; switching to Arrow's `try_unary` yields a similar speedup. 3. For the `gcd` scalar path, computing the GCD can only fail in a few edge cases (e.g., `gcd(i64::MIN, i64::MIN)`). It is cheap to check for these edge-cases; for most `gcd` inputs, we can use Arrow's `unary` kernel instead of `try_unary`. The former is more efficient because it allows LLVM to vectorize the code more effectively. Benchmarks (ARM64): ``` - gcd array and scalar: 2.9ms → 2.2ms, -25% faster - lcm both array: 2.7ms → 2.0ms, -26% faster ``` ## What changes are included in this PR? * Add benchmark for `lcm` * Improve SLT test coverage * Move Rust unit test for `lcm` to SLT * Optimize `lcm` and `gcm` NULL handling * Optimize `gcm` to avoid overhead for edge cases ## Are these changes tested? Yes. Benchmark results above. I inspected the generated code for the `gcd` case to confirm that LLVM is able to generate better code for the `unary` case than for the `try_unary` case. ## Are there any user-facing changes? No.
1 parent d61be49 commit ad7c57a

File tree

5 files changed

+185
-57
lines changed

5 files changed

+185
-57
lines changed

datafusion/functions/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ harness = false
132132
name = "gcd"
133133
required-features = ["math_expressions"]
134134

135+
[[bench]]
136+
harness = false
137+
name = "lcm"
138+
required-features = ["math_expressions"]
139+
135140
[[bench]]
136141
harness = false
137142
name = "nanvl"
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::datatypes::Field;
19+
use arrow::{
20+
array::{ArrayRef, Int64Array},
21+
datatypes::DataType,
22+
};
23+
use criterion::{Criterion, criterion_group, criterion_main};
24+
use datafusion_common::config::ConfigOptions;
25+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
26+
use datafusion_functions::math::lcm;
27+
use rand::Rng;
28+
use std::hint::black_box;
29+
use std::sync::Arc;
30+
31+
fn generate_i64_array(n_rows: usize) -> ArrayRef {
32+
let mut rng = rand::rng();
33+
let values = (0..n_rows)
34+
.map(|_| rng.random_range(0..1000))
35+
.collect::<Vec<_>>();
36+
Arc::new(Int64Array::from(values)) as ArrayRef
37+
}
38+
39+
fn criterion_benchmark(c: &mut Criterion) {
40+
let n_rows = 100000;
41+
let array_a = ColumnarValue::Array(generate_i64_array(n_rows));
42+
let array_b = ColumnarValue::Array(generate_i64_array(n_rows));
43+
let udf = lcm();
44+
let config_options = Arc::new(ConfigOptions::default());
45+
46+
c.bench_function("lcm both array", |b| {
47+
b.iter(|| {
48+
black_box(
49+
udf.invoke_with_args(ScalarFunctionArgs {
50+
args: vec![array_a.clone(), array_b.clone()],
51+
arg_fields: vec![
52+
Field::new("a", array_a.data_type(), true).into(),
53+
Field::new("b", array_b.data_type(), true).into(),
54+
],
55+
number_rows: n_rows,
56+
return_field: Field::new("f", DataType::Int64, true).into(),
57+
config_options: Arc::clone(&config_options),
58+
})
59+
.expect("lcm should work on valid values"),
60+
)
61+
})
62+
});
63+
}
64+
65+
criterion_group!(benches, criterion_benchmark);
66+
criterion_main!(benches);

datafusion/functions/src/math/gcd.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray};
18+
use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
1919
use arrow::compute::try_binary;
2020
use arrow::datatypes::{DataType, Int64Type};
2121
use arrow::error::ArrowError;
@@ -126,18 +126,23 @@ fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
126126
}
127127

128128
fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
129+
let prim = arr.as_primitive::<Int64Type>();
129130
match scalar {
131+
Some(scalar_value) if scalar_value != 0 && scalar_value != i64::MIN => {
132+
// The gcd result divides both inputs' absolute values. When the
133+
// scalar is neither 0 nor i64::MIN, the gcd's absolute value fits
134+
// in i64, so the cast to i64 below cannot overflow. This allows us
135+
// to use `unary` instead of `try_unary`, which allows LLVM to
136+
// vectorize more effectively.
137+
let sv = scalar_value.unsigned_abs();
138+
let result: PrimitiveArray<Int64Type> =
139+
prim.unary(|val| unsigned_gcd(val.unsigned_abs(), sv) as i64);
140+
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
141+
}
130142
Some(scalar_value) => {
131-
let result: Result<Int64Array> = arr
132-
.as_primitive::<Int64Type>()
133-
.iter()
134-
.map(|val| match val {
135-
Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)),
136-
_ => Ok(None),
137-
})
138-
.collect();
139-
140-
result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
143+
let result: PrimitiveArray<Int64Type> =
144+
prim.try_unary(|val| compute_gcd(val, scalar_value))?;
145+
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
141146
}
142147
None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
143148
}
@@ -171,7 +176,8 @@ pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
171176
let a = x.unsigned_abs();
172177
let b = y.unsigned_abs();
173178
let r = unsigned_gcd(a, b);
174-
// gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64
179+
// The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or
180+
// gcd(i64::MIN, i64::MIN)), which does not fit into i64.
175181
r.try_into().map_err(|_| {
176182
ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
177183
})

datafusion/functions/src/math/lcm.rs

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::array::{ArrayRef, Int64Array};
20+
use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
21+
use arrow::compute::try_binary;
2122
use arrow::datatypes::DataType;
2223
use arrow::datatypes::DataType::Int64;
24+
use arrow::datatypes::Int64Type;
2325

2426
use arrow::error::ArrowError;
25-
use datafusion_common::{Result, arrow_datafusion_err, exec_err};
27+
use datafusion_common::{Result, exec_err};
2628
use datafusion_expr::{
2729
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
2830
Volatility,
@@ -91,7 +93,7 @@ impl ScalarUDFImpl for LcmFunc {
9193

9294
/// Lcm SQL function
9395
fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> {
94-
let compute_lcm = |x: i64, y: i64| {
96+
let compute_lcm = |x: i64, y: i64| -> Result<i64, ArrowError> {
9597
if x == 0 || y == 0 {
9698
return Ok(0);
9799
}
@@ -105,55 +107,20 @@ fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> {
105107
.checked_mul(b)
106108
.and_then(|v| i64::try_from(v).ok())
107109
.ok_or_else(|| {
108-
arrow_datafusion_err!(ArrowError::ComputeError(format!(
110+
ArrowError::ComputeError(format!(
109111
"Signed integer overflow in LCM({x}, {y})"
110-
)))
112+
))
111113
})
112114
};
113115

114116
match args[0].data_type() {
115117
Int64 => {
116-
let arg1 = downcast_named_arg!(&args[0], "x", Int64Array);
117-
let arg2 = downcast_named_arg!(&args[1], "y", Int64Array);
118-
119-
Ok(arg1
120-
.iter()
121-
.zip(arg2.iter())
122-
.map(|(a1, a2)| match (a1, a2) {
123-
(Some(a1), Some(a2)) => Ok(Some(compute_lcm(a1, a2)?)),
124-
_ => Ok(None),
125-
})
126-
.collect::<Result<Int64Array>>()
127-
.map(Arc::new)? as ArrayRef)
118+
let arg1 = args[0].as_primitive::<Int64Type>();
119+
let arg2 = args[1].as_primitive::<Int64Type>();
120+
121+
let result: PrimitiveArray<Int64Type> = try_binary(arg1, arg2, compute_lcm)?;
122+
Ok(Arc::new(result) as ArrayRef)
128123
}
129124
other => exec_err!("Unsupported data type {other:?} for function lcm"),
130125
}
131126
}
132-
133-
#[cfg(test)]
134-
mod test {
135-
use std::sync::Arc;
136-
137-
use arrow::array::{ArrayRef, Int64Array};
138-
139-
use datafusion_common::cast::as_int64_array;
140-
141-
use crate::math::lcm::lcm;
142-
143-
#[test]
144-
fn test_lcm_i64() {
145-
let args: Vec<ArrayRef> = vec![
146-
Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x
147-
Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y
148-
];
149-
150-
let result = lcm(&args).expect("failed to initialize function lcm");
151-
let ints = as_int64_array(&result).expect("failed to initialize function lcm");
152-
153-
assert_eq!(ints.len(), 4);
154-
assert_eq!(ints.value(0), 0);
155-
assert_eq!(ints.value(1), 6);
156-
assert_eq!(ints.value(2), 75);
157-
assert_eq!(ints.value(3), 16);
158-
}
159-
}

datafusion/sqllogictest/test_files/math.slt

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,87 @@ query RT
857857
SELECT log(2.5, 10.9::double), arrow_typeof(log(2.5, 10.9::double));
858858
----
859859
2.606992198152 Float64
860+
861+
# lcm with array and scalar
862+
863+
query I
864+
SELECT lcm(column1, 5) FROM (VALUES (0), (3), (25), (-16));
865+
----
866+
0
867+
15
868+
25
869+
80
870+
871+
query I
872+
SELECT lcm(6, column1) FROM (VALUES (4), (9), (0));
873+
----
874+
12
875+
18
876+
0
877+
878+
# lcm array and scalar with nulls in the array
879+
query I
880+
SELECT lcm(column1, 5) FROM (VALUES (0), (NULL), (25));
881+
----
882+
0
883+
NULL
884+
25
885+
886+
query I
887+
SELECT lcm(6, column1) FROM (VALUES (4), (NULL), (0));
888+
----
889+
12
890+
NULL
891+
0
892+
893+
# lcm scalar edge values
894+
query I
895+
SELECT lcm(9223372036854775807, 1);
896+
----
897+
9223372036854775807
898+
899+
query I
900+
SELECT lcm(9223372036854775807, 9223372036854775807);
901+
----
902+
9223372036854775807
903+
904+
# gcd with array and scalar
905+
906+
query I
907+
SELECT gcd(column1, 12) FROM (VALUES (8), (18), (0), (-36));
908+
----
909+
4
910+
6
911+
12
912+
12
913+
914+
query I
915+
SELECT gcd(15, column1) FROM (VALUES (10), (25), (0));
916+
----
917+
5
918+
5
919+
15
920+
921+
# gcd array and scalar with nulls in the array
922+
query I
923+
SELECT gcd(column1, 12) FROM (VALUES (8), (NULL), (0), (-36));
924+
----
925+
4
926+
NULL
927+
12
928+
12
929+
930+
query I
931+
SELECT gcd(15, column1) FROM (VALUES (NULL), (25), (0));
932+
----
933+
NULL
934+
5
935+
15
936+
937+
# gcd array and scalar=0
938+
query I
939+
SELECT gcd(column1, 0) FROM (VALUES (7), (-3), (0));
940+
----
941+
7
942+
3
943+
0

0 commit comments

Comments
 (0)