-
Notifications
You must be signed in to change notification settings - Fork 2k
feat: change approx percentile/median UDFs to return floats #21074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
33ccecc
3b48e04
4637df2
968a851
d2bccf3
6fd2148
ff1eb8b
c04052c
0a17a19
add22d8
98cdc25
09c0b90
abd8b09
2ec7332
f19de5f
2fd9259
9d2790d
498a825
4995989
ae04043
793609e
a08f2ff
b7b3020
462bfa8
e242290
aa2c8c1
e6752c2
858bbd8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,23 +23,20 @@ use arrow::array::{Array, Float16Array}; | |
| use arrow::compute::{filter, is_not_null}; | ||
| use arrow::datatypes::FieldRef; | ||
| use arrow::{ | ||
| array::{ | ||
| ArrayRef, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, | ||
| Int64Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array, | ||
| }, | ||
| array::{ArrayRef, Float32Array, Float64Array}, | ||
| datatypes::{DataType, Field}, | ||
| }; | ||
| use datafusion_common::types::{NativeType, logical_float64}; | ||
| use datafusion_common::{ | ||
| DataFusionError, Result, ScalarValue, downcast_value, internal_err, not_impl_err, | ||
| plan_err, | ||
| }; | ||
| use datafusion_expr::expr::{AggregateFunction, Sort}; | ||
| use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; | ||
| use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; | ||
| use datafusion_expr::utils::format_state_name; | ||
| use datafusion_expr::{ | ||
| Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, | ||
| Volatility, | ||
| Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature, | ||
| TypeSignature, TypeSignatureClass, Volatility, | ||
| }; | ||
| use datafusion_functions_aggregate_common::tdigest::{DEFAULT_MAX_SIZE, TDigest}; | ||
| use datafusion_macros::user_doc; | ||
|
|
@@ -132,22 +129,44 @@ impl Default for ApproxPercentileCont { | |
| impl ApproxPercentileCont { | ||
| /// Create a new [`ApproxPercentileCont`] aggregate function. | ||
| pub fn new() -> Self { | ||
| let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); | ||
| // Accept any numeric value paired with a float64 percentile | ||
| for num in NUMERICS { | ||
| variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); | ||
| // Additionally accept an integer number of centroids for T-Digest | ||
| for int in INTEGERS { | ||
| variants.push(TypeSignature::Exact(vec![ | ||
| num.clone(), | ||
| DataType::Float64, | ||
| int.clone(), | ||
| ])) | ||
| } | ||
| } | ||
| Self { | ||
| signature: Signature::one_of(variants, Volatility::Immutable), | ||
| } | ||
| let signature = Signature::one_of( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we're now coercing to floats, we can remove some of the implementation code that handles integer types.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed, removed some code. It also affected […]. Clarified the scope of this PR in the description; it changes more than expected. |
||
| vec![ | ||
| // 2 args - numeric, percentile (float) | ||
| TypeSignature::Coercible(vec![ | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Float, | ||
| vec![TypeSignatureClass::Numeric], | ||
| NativeType::Float64, | ||
| ), | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_float64()), | ||
| vec![TypeSignatureClass::Numeric], | ||
| NativeType::Float64, | ||
| ), | ||
| ]), | ||
| // 3 args - numeric, percentile (float), number of centroids for T-Digest (integer) | ||
| TypeSignature::Coercible(vec![ | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Float, | ||
| vec![TypeSignatureClass::Numeric], | ||
| NativeType::Float64, | ||
| ), | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_float64()), | ||
| vec![TypeSignatureClass::Numeric], | ||
| NativeType::Float64, | ||
| ), | ||
| Coercion::new_implicit( | ||
| TypeSignatureClass::Integer, | ||
| vec![TypeSignatureClass::Numeric], | ||
| NativeType::Int64, | ||
| ), | ||
| ]), | ||
| ], | ||
| Volatility::Immutable, | ||
| ); | ||
| Self { signature } | ||
| } | ||
|
|
||
| pub(crate) fn create_accumulator( | ||
|
|
@@ -177,17 +196,7 @@ impl ApproxPercentileCont { | |
|
|
||
| let data_type = args.expr_fields[0].data_type(); | ||
| let accumulator: ApproxPercentileAccumulator = match data_type { | ||
| DataType::UInt8 | ||
| | DataType::UInt16 | ||
| | DataType::UInt32 | ||
| | DataType::UInt64 | ||
| | DataType::Int8 | ||
| | DataType::Int16 | ||
| | DataType::Int32 | ||
| | DataType::Int64 | ||
| | DataType::Float16 | ||
| | DataType::Float32 | ||
| | DataType::Float64 => { | ||
| DataType::Float16 | DataType::Float32 | DataType::Float64 => { | ||
| if let Some(max_size) = tdigest_max_size { | ||
| ApproxPercentileAccumulator::new_with_max_size( | ||
| percentile, | ||
|
|
@@ -374,38 +383,6 @@ impl ApproxPercentileAccumulator { | |
| .map(|v| v.to_f64()) | ||
| .collect::<Vec<_>>()) | ||
| } | ||
| DataType::Int64 => { | ||
| let array = downcast_value!(values, Int64Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::Int32 => { | ||
| let array = downcast_value!(values, Int32Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::Int16 => { | ||
| let array = downcast_value!(values, Int16Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::Int8 => { | ||
| let array = downcast_value!(values, Int8Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::UInt64 => { | ||
| let array = downcast_value!(values, UInt64Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::UInt32 => { | ||
| let array = downcast_value!(values, UInt32Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::UInt16 => { | ||
| let array = downcast_value!(values, UInt16Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| DataType::UInt8 => { | ||
| let array = downcast_value!(values, UInt8Array); | ||
| Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>()) | ||
| } | ||
| e => internal_err!( | ||
| "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" | ||
| ), | ||
|
|
@@ -439,14 +416,6 @@ impl Accumulator for ApproxPercentileAccumulator { | |
| // These acceptable return types MUST match the validation in | ||
| // ApproxPercentile::create_accumulator. | ||
| Ok(match &self.return_type { | ||
| DataType::Int8 => ScalarValue::Int8(Some(q as i8)), | ||
| DataType::Int16 => ScalarValue::Int16(Some(q as i16)), | ||
| DataType::Int32 => ScalarValue::Int32(Some(q as i32)), | ||
| DataType::Int64 => ScalarValue::Int64(Some(q as i64)), | ||
| DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), | ||
| DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), | ||
| DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), | ||
| DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), | ||
| DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))), | ||
| DataType::Float32 => ScalarValue::Float32(Some(q as f32)), | ||
| DataType::Float64 => ScalarValue::Float64(Some(q)), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.