Skip to content

Commit 29f1acd

Browse files
authored
feat: change approx percentile/median UDFs to return floats (#21074)
## Which issue does this PR close? - Closes #18092. ## Rationale for this change 1. Migrating to the modern TypeSignature API: [264030c/datafusion/expr-common/src/signature.rs](https://github.com/apache/datafusion/blob/264030cca76d0bdb4d8809f252b422e72624a345/datafusion/expr-common/src/signature.rs) 2. Coercing types of `approx_percentile_cont`, `approx_percentile_cont_with_weight`, `approx_median` to floats. It matches PostgreSQL, DuckDB, and ClickHouse behaviour, except for Spark. ## What changes are included in this PR? - Port remaining UDFs (approx_percentile_cont, approx_percentile_cont_with_weight, approx_median, stub functions) to signature APIs - Deprecate INTEGERS and NUMERICS arrays in favour of using the TypeSignature API - They are not removed yet, but marked as deprecated to avoid breaking downstream - Fix up a SLT for approx_percentile_cont, approx_median to make sure it returns a float ## Are these changes tested? - Tests are passing - Updated tests to expect floats in return types ## Are there any user-facing changes? - Signatures of `approx_percentile_cont`, `approx_percentile_cont_with_weight`, `approx_median` changed, so they now return floats instead of integers (as seen in tests)
1 parent afc0784 commit 29f1acd

File tree

15 files changed

+281
-222
lines changed

15 files changed

+281
-222
lines changed

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ async fn test_fn_approx_median() -> Result<()> {
402402
+-----------------------+
403403
| approx_median(test.b) |
404404
+-----------------------+
405-
| 10 |
405+
| 10.0 |
406406
+-----------------------+
407407
");
408408

@@ -422,7 +422,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
422422
+---------------------------------------------------------------------------+
423423
| approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] |
424424
+---------------------------------------------------------------------------+
425-
| 10 |
425+
| 10.0 |
426426
+---------------------------------------------------------------------------+
427427
");
428428

@@ -437,7 +437,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
437437
+----------------------------------------------------------------------------+
438438
| approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] |
439439
+----------------------------------------------------------------------------+
440-
| 100 |
440+
| 100.0 |
441441
+----------------------------------------------------------------------------+
442442
");
443443

@@ -457,7 +457,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
457457
+--------------------------------------------------------------------+
458458
| approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] |
459459
+--------------------------------------------------------------------+
460-
| 10 |
460+
| 10.0 |
461461
+--------------------------------------------------------------------+
462462
"
463463
);
@@ -477,7 +477,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
477477
+---------------------------------------------------------------------+
478478
| approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] |
479479
+---------------------------------------------------------------------+
480-
| 100 |
480+
| 100.0 |
481481
+---------------------------------------------------------------------+
482482
"
483483
);
@@ -494,7 +494,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
494494
+------------------------------------------------------------------------------------+
495495
| approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] |
496496
+------------------------------------------------------------------------------------+
497-
| 30 |
497+
| 30.25 |
498498
+------------------------------------------------------------------------------------+
499499
");
500500

@@ -510,7 +510,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
510510
+-------------------------------------------------------------------------------------+
511511
| approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] |
512512
+-------------------------------------------------------------------------------------+
513-
| 69 |
513+
| 69.85 |
514514
+-------------------------------------------------------------------------------------+
515515
");
516516

datafusion/core/tests/dataframe/mod.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,26 +1204,26 @@ async fn window_using_aggregates() -> Result<()> {
12041204
| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |
12051205
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
12061206
| | | | | | | | 1 | -85 |
1207-
| -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 |
1208-
| -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 |
1209-
| -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 |
1210-
| -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 |
1211-
| -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 |
1212-
| -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 |
1213-
| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |
1214-
| -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 |
1215-
| -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 |
1216-
| -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 |
1217-
| -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 |
1218-
| -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 |
1219-
| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |
1220-
| -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 |
1221-
| -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 |
1222-
| -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 |
1223-
| -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 |
1224-
| -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 |
1225-
| -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 |
1226-
| -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 |
1207+
| -85 | -101 | 14 | -12.0 | -12 | 83 | -101 | 4 | -54 |
1208+
| -85 | -101 | 17 | -25.0 | -25 | 83 | -101 | 5 | -31 |
1209+
| -85 | -12 | 10 | -32.75 | -34 | 83 | -85 | 3 | 13 |
1210+
| -85 | -25 | 3 | -56.0 | -56 | -25 | -85 | 1 | -5 |
1211+
| -85 | -31 | 18 | -29.75 | -28 | 83 | -101 | 5 | 36 |
1212+
| -85 | -38 | 16 | -25.0 | -25 | 83 | -101 | 4 | 65 |
1213+
| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |
1214+
| -85 | -48 | 6 | -35.75 | -36 | 83 | -85 | 2 | -43 |
1215+
| -85 | -5 | 4 | -37.75 | -40 | -5 | -85 | 1 | 83 |
1216+
| -85 | -54 | 15 | -17.0 | -18 | 83 | -101 | 4 | -38 |
1217+
| -85 | -56 | 2 | -70.5 | -70 | -56 | -85 | 1 | -25 |
1218+
| -85 | -72 | 9 | -43.0 | -43 | 83 | -85 | 3 | -12 |
1219+
| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |
1220+
| -85 | 13 | 11 | -17.0 | -18 | 83 | -85 | 3 | 14 |
1221+
| -85 | 13 | 11 | -25.0 | -25 | 83 | -85 | 3 | 13 |
1222+
| -85 | 14 | 12 | -12.0 | -12 | 83 | -85 | 3 | 17 |
1223+
| -85 | 17 | 13 | -11.25 | -8 | 83 | -85 | 4 | -101 |
1224+
| -85 | 45 | 8 | -34.5 | -34 | 83 | -85 | 3 | -72 |
1225+
| -85 | 65 | 17 | -17.0 | -18 | 83 | -101 | 5 | -101 |
1226+
| -85 | 83 | 5 | -25.0 | -25 | 83 | -85 | 2 | -48 |
12271227
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
12281228
"
12291229
);

datafusion/expr-common/src/signature.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use std::fmt::Display;
2121
use std::hash::Hash;
2222
use std::sync::Arc;
2323

24-
use crate::type_coercion::aggregates::NUMERICS;
2524
use arrow::datatypes::{
2625
DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType,
2726
Decimal128Type, DecimalType, Field, IntervalUnit, TimeUnit,
@@ -596,6 +595,20 @@ impl Display for ArrayFunctionArgument {
596595
}
597596
}
598597

598+
static NUMERICS: &[DataType] = &[
599+
DataType::Int8,
600+
DataType::Int16,
601+
DataType::Int32,
602+
DataType::Int64,
603+
DataType::UInt8,
604+
DataType::UInt16,
605+
DataType::UInt32,
606+
DataType::UInt64,
607+
DataType::Float16,
608+
DataType::Float32,
609+
DataType::Float64,
610+
];
611+
599612
impl TypeSignature {
600613
pub fn to_string_repr(&self) -> Vec<String> {
601614
match self {

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ use arrow::datatypes::{DataType, FieldRef};
2020

2121
use datafusion_common::{Result, internal_err, plan_err};
2222

23-
// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
24-
// see https://github.com/apache/datafusion/issues/18092
23+
#[deprecated(since = "54.0.0", note = "Use functions signatures")]
2524
pub static INTEGERS: &[DataType] = &[
2625
DataType::Int8,
2726
DataType::Int16,
@@ -33,6 +32,7 @@ pub static INTEGERS: &[DataType] = &[
3332
DataType::UInt64,
3433
];
3534

35+
#[deprecated(since = "54.0.0", note = "Use functions signatures")]
3636
pub static NUMERICS: &[DataType] = &[
3737
DataType::Int8,
3838
DataType::Int16,

datafusion/expr/src/test/function_stub.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ use datafusion_common::plan_err;
2929
use datafusion_common::{Result, exec_err, not_impl_err, utils::take_function_args};
3030

3131
use crate::Volatility::Immutable;
32-
use crate::type_coercion::aggregates::NUMERICS;
3332
use crate::{
34-
Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
33+
Accumulator, AggregateUDFImpl, Coercion, Expr, GroupsAccumulator, ReversedUDAF,
34+
Signature, TypeSignature, TypeSignatureClass,
3535
expr::AggregateFunction,
3636
function::{AccumulatorArgs, StateFieldsArgs},
3737
utils::AggregateOrderSensitivity,
3838
};
39+
use datafusion_common::types::{NativeType, logical_float64};
3940

4041
macro_rules! create_func {
4142
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
@@ -444,9 +445,22 @@ pub struct Avg {
444445

445446
impl Avg {
446447
pub fn new() -> Self {
448+
let signature = Signature::one_of(
449+
vec![
450+
TypeSignature::Coercible(vec![Coercion::new_exact(
451+
TypeSignatureClass::Decimal,
452+
)]),
453+
TypeSignature::Coercible(vec![Coercion::new_implicit(
454+
TypeSignatureClass::Native(logical_float64()),
455+
vec![TypeSignatureClass::Integer, TypeSignatureClass::Float],
456+
NativeType::Float64,
457+
)]),
458+
],
459+
Immutable,
460+
);
447461
Self {
448462
aliases: vec![String::from("mean")],
449-
signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
463+
signature,
450464
}
451465
}
452466
}

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,11 @@ impl ApproxMedian {
7474
pub fn new() -> Self {
7575
Self {
7676
signature: Signature::one_of(
77-
vec![
78-
TypeSignature::Coercible(vec![Coercion::new_exact(
79-
TypeSignatureClass::Integer,
80-
)]),
81-
TypeSignature::Coercible(vec![Coercion::new_implicit(
82-
TypeSignatureClass::Float,
83-
vec![TypeSignatureClass::Decimal],
84-
NativeType::Float64,
85-
)]),
86-
],
77+
vec![TypeSignature::Coercible(vec![Coercion::new_implicit(
78+
TypeSignatureClass::Float,
79+
vec![TypeSignatureClass::Numeric],
80+
NativeType::Float64,
81+
)])],
8782
Volatility::Immutable,
8883
),
8984
}

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 42 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,20 @@ use arrow::array::{Array, Float16Array};
2323
use arrow::compute::{filter, is_not_null};
2424
use arrow::datatypes::FieldRef;
2525
use arrow::{
26-
array::{
27-
ArrayRef, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
28-
Int64Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
29-
},
26+
array::{ArrayRef, Float32Array, Float64Array},
3027
datatypes::{DataType, Field},
3128
};
29+
use datafusion_common::types::{NativeType, logical_float64};
3230
use datafusion_common::{
3331
DataFusionError, Result, ScalarValue, downcast_value, internal_err, not_impl_err,
3432
plan_err,
3533
};
3634
use datafusion_expr::expr::{AggregateFunction, Sort};
3735
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
38-
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
3936
use datafusion_expr::utils::format_state_name;
4037
use datafusion_expr::{
41-
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
42-
Volatility,
38+
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature,
39+
TypeSignature, TypeSignatureClass, Volatility,
4340
};
4441
use datafusion_functions_aggregate_common::tdigest::{DEFAULT_MAX_SIZE, TDigest};
4542
use datafusion_macros::user_doc;
@@ -132,22 +129,44 @@ impl Default for ApproxPercentileCont {
132129
impl ApproxPercentileCont {
133130
/// Create a new [`ApproxPercentileCont`] aggregate function.
134131
pub fn new() -> Self {
135-
let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
136132
// Accept any numeric value paired with a float64 percentile
137-
for num in NUMERICS {
138-
variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
139-
// Additionally accept an integer number of centroids for T-Digest
140-
for int in INTEGERS {
141-
variants.push(TypeSignature::Exact(vec![
142-
num.clone(),
143-
DataType::Float64,
144-
int.clone(),
145-
]))
146-
}
147-
}
148-
Self {
149-
signature: Signature::one_of(variants, Volatility::Immutable),
150-
}
133+
let signature = Signature::one_of(
134+
vec![
135+
// 2 args - numeric, percentile (float)
136+
TypeSignature::Coercible(vec![
137+
Coercion::new_implicit(
138+
TypeSignatureClass::Float,
139+
vec![TypeSignatureClass::Numeric],
140+
NativeType::Float64,
141+
),
142+
Coercion::new_implicit(
143+
TypeSignatureClass::Native(logical_float64()),
144+
vec![TypeSignatureClass::Numeric],
145+
NativeType::Float64,
146+
),
147+
]),
148+
// 3 args - numeric, percentile (float), number of centroid for T-Digest (integer)
149+
TypeSignature::Coercible(vec![
150+
Coercion::new_implicit(
151+
TypeSignatureClass::Float,
152+
vec![TypeSignatureClass::Numeric],
153+
NativeType::Float64,
154+
),
155+
Coercion::new_implicit(
156+
TypeSignatureClass::Native(logical_float64()),
157+
vec![TypeSignatureClass::Numeric],
158+
NativeType::Float64,
159+
),
160+
Coercion::new_implicit(
161+
TypeSignatureClass::Integer,
162+
vec![TypeSignatureClass::Numeric],
163+
NativeType::Int64,
164+
),
165+
]),
166+
],
167+
Volatility::Immutable,
168+
);
169+
Self { signature }
151170
}
152171

153172
pub(crate) fn create_accumulator(
@@ -177,17 +196,7 @@ impl ApproxPercentileCont {
177196

178197
let data_type = args.expr_fields[0].data_type();
179198
let accumulator: ApproxPercentileAccumulator = match data_type {
180-
DataType::UInt8
181-
| DataType::UInt16
182-
| DataType::UInt32
183-
| DataType::UInt64
184-
| DataType::Int8
185-
| DataType::Int16
186-
| DataType::Int32
187-
| DataType::Int64
188-
| DataType::Float16
189-
| DataType::Float32
190-
| DataType::Float64 => {
199+
DataType::Float16 | DataType::Float32 | DataType::Float64 => {
191200
if let Some(max_size) = tdigest_max_size {
192201
ApproxPercentileAccumulator::new_with_max_size(
193202
percentile,
@@ -374,38 +383,6 @@ impl ApproxPercentileAccumulator {
374383
.map(|v| v.to_f64())
375384
.collect::<Vec<_>>())
376385
}
377-
DataType::Int64 => {
378-
let array = downcast_value!(values, Int64Array);
379-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
380-
}
381-
DataType::Int32 => {
382-
let array = downcast_value!(values, Int32Array);
383-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
384-
}
385-
DataType::Int16 => {
386-
let array = downcast_value!(values, Int16Array);
387-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
388-
}
389-
DataType::Int8 => {
390-
let array = downcast_value!(values, Int8Array);
391-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
392-
}
393-
DataType::UInt64 => {
394-
let array = downcast_value!(values, UInt64Array);
395-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
396-
}
397-
DataType::UInt32 => {
398-
let array = downcast_value!(values, UInt32Array);
399-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
400-
}
401-
DataType::UInt16 => {
402-
let array = downcast_value!(values, UInt16Array);
403-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
404-
}
405-
DataType::UInt8 => {
406-
let array = downcast_value!(values, UInt8Array);
407-
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
408-
}
409386
e => internal_err!(
410387
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
411388
),
@@ -439,14 +416,6 @@ impl Accumulator for ApproxPercentileAccumulator {
439416
// These acceptable return types MUST match the validation in
440417
// ApproxPercentile::create_accumulator.
441418
Ok(match &self.return_type {
442-
DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
443-
DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
444-
DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
445-
DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
446-
DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
447-
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
448-
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
449-
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
450419
DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
451420
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
452421
DataType::Float64 => ScalarValue::Float64(Some(q)),

0 commit comments

Comments
 (0)