Skip to content

Commit a1e88e2

Browse files
authored
feat: decimal support for gcd and lcm (#22655)
## Which issue does this PR close?

- Closes #19057.

## Rationale for this change

The binary `gcd` and `lcm` UDFs in the datafusion-functions crate support only Int64, not decimals. This PR adds the missing decimal support.

## What changes are included in this PR?

1. Updated the gcd and lcm functions to add decimal support. The integer path is more performant and stays intact. For decimals, the Euclidean algorithm is used for GCD.
2. Added coercion rules: cast to decimals if any argument is decimal; otherwise, stay with ints as before.
3. Extracted common functionality to `common.rs` to avoid an inter-UDF dependency.
4. In order to use `calculate_binary_math` for decimals, updated it to accept a target type instead of the raw `Decimal128Type::DATA_TYPE`, which causes scaling issues for these UDFs (see #19621).

A bit more on (4). The driving force is this failing example:

```sql
query R
select gcd(2::decimal(38, 0), 3::decimal(38, 0));
----
1
```

Previously, in #19874, I suggested a more complicated solution that extended `calculate_binary_math`. However, it only affected gcd/lcm and could be considered overkill. This PR instead adds an extra parameter, `cast_target`, to `calculate_binary_decimal_math` so that it casts to the actual type in use rather than to the default `Decimal128Type::DATA_TYPE`, which is a much lighter change.

## Are these changes tested?

- Added unit tests for the UDFs with decimals, covering the array and scalar paths
- Added unit tests for the gcd/lcm math itself
- Added new SLT tests for decimals

## Are there any user-facing changes?

No
1 parent a66c898 commit a1e88e2

7 files changed

Lines changed: 758 additions & 128 deletions

File tree

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::ArrowNativeTypeOp;
19+
use arrow::error::ArrowError;
20+
use num_traits::{CheckedMul, CheckedNeg, Signed};
21+
use std::fmt::Display;
22+
use std::mem::swap;
23+
use std::ops::RemAssign;
24+
25+
/// A gcd helper to compute GCD using Euclidean GCD algorithm
26+
/// on non-negative numbers (scalars and decimals)
27+
fn gcd_helper<T>(a: T, b: T) -> Result<T, ArrowError>
28+
where
29+
T: ArrowNativeTypeOp + RemAssign + CheckedNeg,
30+
{
31+
debug_assert!(a >= T::ZERO);
32+
debug_assert!(b >= T::ZERO);
33+
let (mut a, mut b) = if a > b { (a, b) } else { (b, a) };
34+
35+
while b != T::ZERO {
36+
swap(&mut a, &mut b);
37+
b %= a;
38+
}
39+
40+
Ok(a)
41+
}
42+
43+
/// Greatest common divisor of two unsigned integers via the binary
/// (Stein's) GCD algorithm: only shifts, comparisons and subtractions.
///
/// Returns the other operand when one of them is zero, so
/// `unsigned_gcd(0, 0)` is `0`.
pub(crate) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
    // With at least one zero operand, `a | b` is exactly the other operand.
    if a == 0 || b == 0 {
        return a | b;
    }

    // Number of factors of two shared by both operands; restored at the end.
    let common_twos = (a | b).trailing_zeros();

    // From here on `a` stays odd.
    a >>= a.trailing_zeros();
    while b != 0 {
        // Remaining factors of two in `b` cannot be part of the GCD.
        b >>= b.trailing_zeros();
        if a > b {
            swap(&mut a, &mut b);
        }
        // Both are odd and a <= b, so the difference is even and non-negative.
        b -= a;
    }

    a << common_twos
}
66+
67+
/// Computes gcd of two signed numbers (integers or decimals),
68+
/// checking for output integer overflow
69+
pub(crate) fn gcd_signed<T>(x: T, y: T) -> Result<T, ArrowError>
70+
where
71+
T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg,
72+
{
73+
// Make absolute values, keeping type
74+
let a = if x.is_positive() {
75+
x
76+
} else {
77+
x.checked_neg()
78+
.ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))?
79+
};
80+
let b = if y.is_positive() {
81+
y
82+
} else {
83+
y.checked_neg()
84+
.ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))?
85+
};
86+
// Call with signed numbers
87+
gcd_helper(a, b)
88+
}
89+
90+
/// Computes gcd of two signed integers
91+
pub(crate) fn gcd_signed_int(x: i64, y: i64) -> Result<i64, ArrowError> {
92+
let a = x.unsigned_abs();
93+
let b = y.unsigned_abs();
94+
95+
// Call with unsigned numbers
96+
let r = unsigned_gcd(a, b);
97+
// gcd(i64::MIN, i64::MIN) = u64::MIN.unsigned_abs() cannot fit into i64
98+
r.try_into().map_err(|_| {
99+
ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
100+
})
101+
}
102+
103+
/// Computes lcm of two signed numbers (integers or decimals)
104+
pub(crate) fn lcm_signed<T>(x: T, y: T) -> Result<T, ArrowError>
105+
where
106+
T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg + CheckedMul + Display,
107+
{
108+
if x == T::ZERO || y == T::ZERO {
109+
return Ok(T::ZERO);
110+
}
111+
112+
// Make absolute values, keeping type
113+
let a = if x.is_positive() {
114+
x
115+
} else {
116+
x.checked_neg()
117+
.ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))?
118+
};
119+
let b = if y.is_positive() {
120+
y
121+
} else {
122+
y.checked_neg()
123+
.ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))?
124+
};
125+
// Call with signed numbers
126+
let gcd = gcd_helper(a, b)?;
127+
// gcd is not zero since both a and b are not zero, so the division is safe.
128+
(a / gcd).checked_mul(&b).ok_or_else(|| {
129+
ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})"))
130+
})
131+
}
132+
133+
/// Computes lcm of two signed integers,
134+
/// checking for output integer overflow
135+
pub(crate) fn lcm_signed_int(x: i64, y: i64) -> Result<i64, ArrowError> {
136+
if x == 0 || y == 0 {
137+
return Ok(0);
138+
}
139+
140+
let a = x.unsigned_abs();
141+
let b = y.unsigned_abs();
142+
143+
let gcd = gcd_helper::<u64>(a, b)?;
144+
// gcd is not zero since both a and b are not zero, so the division is safe.
145+
(a / gcd)
146+
.checked_mul(b)
147+
.and_then(|v| i64::try_from(v).ok())
148+
.ok_or_else(|| {
149+
ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})"))
150+
})
151+
}
152+
153+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_buffer::i256;

    /// Non-negative `(a, b, gcd(a, b))` triples shared by the integer and
    /// decimal GCD tests.
    const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [
        // Basic cases
        (48, 18, 6),
        (54, 24, 6),
        (100, 50, 50),
        (17, 19, 1),
        (21, 14, 7),
        // Edge cases with 0
        (0, 0, 0),
        (0, 5, 5),
        (10, 0, 10),
        // Same numbers
        (7, 7, 7),
        (100, 100, 100),
        // One is 1
        (1, 1, 1),
        (1, 100, 1),
        (999, 1, 1),
        // Large numbers
        (1000000, 500000, 500000),
        (123456, 789012, 12),
        (999999, 111111, 111111),
        // Powers of 2
        (64, 128, 64),
        (1024, 2048, 1024),
    ];

    /// Non-negative `(a, b, lcm(a, b))` triples shared by the integer and
    /// decimal LCM tests.
    const LCM_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [
        // Basic cases
        (48, 18, 144),
        (54, 24, 216),
        (100, 50, 100),
        (17, 19, 323),
        (21, 14, 42),
        // Edge cases with 0
        (0, 0, 0),
        (0, 5, 0),
        (10, 0, 0),
        // Same numbers
        (7, 7, 7),
        (100, 100, 100),
        // One is 1
        (1, 1, 1),
        (1, 100, 100),
        (999, 1, 999),
        // Large numbers
        (1_000_000, 500_000, 1_000_000),
        (123_456, 789_012, 8_117_355_456),
        (999_999, 111_111, 999_999),
        // Powers of 2
        (64, 128, 128),
        (1024, 2048, 2048),
    ];

    #[test]
    fn test_gcd_i64() {
        let test_cases: Vec<(i64, i64, i64)> = [
            GCD_COMMON_TEST_CASES.into(),
            vec![
                // Max value cases
                (1, i64::MAX, 1),
                (i64::MAX, 1, 1),
                (i64::MAX, i64::MAX, i64::MAX),
            ],
        ]
        .concat();

        // Success cases: all three implementations must agree.
        for (a, b, expected) in test_cases {
            let actual_euclidean = gcd_signed(a, b).expect("should succeed");
            assert_eq!(
                actual_euclidean, expected,
                "gcd_signed({a}, {b}) expected {expected}, actual {actual_euclidean}"
            );
            let actual_binary: i64 =
                unsigned_gcd(a.try_into().unwrap(), b.try_into().unwrap())
                    .try_into()
                    .expect("overflow");
            assert_eq!(
                actual_binary, expected,
                "unsigned_gcd({a}, {b}) expected {expected}, actual {actual_binary}"
            );
            let actual_int = gcd_signed_int(a, b).expect("should succeed");
            assert_eq!(
                actual_int, expected,
                "gcd_signed_int({a}, {b}) expected {expected}, actual {actual_int}"
            );
        }
    }

    #[test]
    fn test_gcd_decimal() {
        let test_cases: Vec<(i256, i256, i256)> = [
            GCD_COMMON_TEST_CASES
                .iter()
                .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c)))
                .collect(),
            vec![
                (i256::from(1), i256::MAX, i256::from(1)),
                (i256::MAX, i256::from(1), i256::from(1)),
                (i256::MAX, i256::MAX, i256::MAX),
            ],
        ]
        .concat();

        // Success cases
        for (a, b, expected) in test_cases {
            let actual = gcd_signed(a, b).expect("should succeed");
            assert_eq!(
                actual, expected,
                "gcd_signed({a}, {b}) expected {expected}, actual {actual}"
            );
        }
    }

    #[test]
    fn test_gcd_negative_inputs() {
        // GCD is computed on magnitudes, so the sign of either input is irrelevant.
        let test_cases: [(i64, i64, i64); 5] = [
            (-48, 18, 6),
            (48, -18, 6),
            (-48, -18, 6),
            (-5, 0, 5),
            (0, -5, 5),
        ];

        for (a, b, expected) in test_cases {
            let actual_generic = gcd_signed(a, b).expect("should succeed");
            assert_eq!(
                actual_generic, expected,
                "gcd_signed({a}, {b}) expected {expected}, actual {actual_generic}"
            );
            let actual_int = gcd_signed_int(a, b).expect("should succeed");
            assert_eq!(
                actual_int, expected,
                "gcd_signed_int({a}, {b}) expected {expected}, actual {actual_int}"
            );
            let actual_decimal =
                gcd_signed(i256::from(a), i256::from(b)).expect("should succeed");
            assert_eq!(
                actual_decimal,
                i256::from(expected),
                "gcd_signed({a}, {b}) on i256 expected {expected}, actual {actual_decimal}"
            );
        }
    }

    #[test]
    fn test_gcd_overflow() {
        // |i64::MIN| is one past i64::MAX, so a GCD equal to it cannot be returned.
        assert!(gcd_signed_int(i64::MIN, i64::MIN).is_err());
        assert!(gcd_signed_int(i64::MIN, 0).is_err());
        assert!(gcd_signed_int(0, i64::MIN).is_err());
        // The integer path works on unsigned magnitudes, so i64::MIN is fine
        // as long as the result itself fits.
        assert_eq!(gcd_signed_int(i64::MIN, 2).expect("fits in i64"), 2);

        // The generic path negates its inputs first, so MIN always overflows.
        assert!(gcd_signed(i64::MIN, 1_i64).is_err());
        assert!(gcd_signed(i256::MIN, i256::from(1_i64)).is_err());
    }

    #[test]
    fn test_lcm_i64() {
        let test_cases: Vec<(i64, i64, i64)> = [
            LCM_COMMON_TEST_CASES.into(),
            vec![
                // Negative inputs - LCM is always non-negative
                (-6, 4, 12),
                (-4, -6, 12),
                // Max value cases
                (1, i64::MAX, i64::MAX),
                (i64::MAX, 1, i64::MAX),
                (i64::MAX, i64::MAX, i64::MAX),
            ],
        ]
        .concat();

        // The integer fast path and the generic path must agree.
        for (a, b, expected) in test_cases {
            let actual = lcm_signed_int(a, b).expect("should succeed");
            assert_eq!(
                actual, expected,
                "lcm_signed_int({a}, {b}) expected {expected}, actual {actual}"
            );
            let actual_generic = lcm_signed(a, b).expect("should succeed");
            assert_eq!(
                actual_generic, expected,
                "lcm_signed({a}, {b}) expected {expected}, actual {actual_generic}"
            );
        }
    }

    #[test]
    fn test_lcm_decimal() {
        let test_cases: Vec<(i256, i256, i256)> = [
            LCM_COMMON_TEST_CASES
                .iter()
                .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c)))
                .collect(),
            vec![
                // Negative inputs - LCM is always non-negative
                (i256::from(-6_i64), i256::from(4_i64), i256::from(12_i64)),
                (i256::from(-4_i64), i256::from(-6_i64), i256::from(12_i64)),
                // Max value cases
                (i256::from(1_i64), i256::MAX, i256::MAX),
                (i256::MAX, i256::from(1_i64), i256::MAX),
                (i256::MAX, i256::MAX, i256::MAX),
            ],
        ]
        .concat();

        for (a, b, expected) in test_cases {
            let actual = lcm_signed(a, b).expect("should succeed");
            assert_eq!(
                actual, expected,
                "lcm_signed({a}, {b}) expected {expected}, actual {actual}"
            );
        }
    }

    #[test]
    fn test_lcm_overflow() {
        // |i64::MIN| does not fit into i64.
        assert!(lcm_signed_int(i64::MIN, 1).is_err());
        // The product fits into u64 but not into i64.
        assert!(lcm_signed_int(i64::MAX, 2).is_err());
        // Consecutive integers are coprime, so the product overflows u64.
        assert!(lcm_signed_int(i64::MAX, i64::MAX - 1).is_err());

        // Generic path: multiplication overflow and negation overflow.
        assert!(lcm_signed(i64::MAX, 2_i64).is_err());
        assert!(lcm_signed(i64::MIN, 1_i64).is_err());
        assert!(lcm_signed(i256::MAX, i256::from(2_i64)).is_err());
        assert!(lcm_signed(i256::MIN, i256::from(1_i64)).is_err());
    }
}

0 commit comments

Comments
 (0)