Skip to content

Commit b41ea60

Browse files
fix: median returns Float64 for integer inputs to avoid truncation
1 parent f043092 commit b41ea60

7 files changed

Lines changed: 132 additions & 113 deletions

File tree

datafusion/core/tests/dataframe/describe.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async fn describe() -> Result<()> {
4444
| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |
4545
| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |
4646
| max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 |
47-
| median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 |
47+
| median | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.5 | 7.0 |
4848
+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+
4949
");
5050
Ok(())

datafusion/core/tests/dataframe/mod.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,26 +1204,26 @@ async fn window_using_aggregates() -> Result<()> {
12041204
| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |
12051205
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
12061206
| | | | | | | | 1 | -85 |
1207-
| -85 | -101 | 14 | -12.0 | -12 | 83 | -101 | 4 | -54 |
1208-
| -85 | -101 | 17 | -25.0 | -25 | 83 | -101 | 5 | -31 |
1209-
| -85 | -12 | 10 | -32.75 | -34 | 83 | -85 | 3 | 13 |
1210-
| -85 | -25 | 3 | -56.0 | -56 | -25 | -85 | 1 | -5 |
1211-
| -85 | -31 | 18 | -29.75 | -28 | 83 | -101 | 5 | 36 |
1212-
| -85 | -38 | 16 | -25.0 | -25 | 83 | -101 | 4 | 65 |
1213-
| -85 | -43 | 7 | -43.0 | -43 | 83 | -85 | 2 | 45 |
1214-
| -85 | -48 | 6 | -35.75 | -36 | 83 | -85 | 2 | -43 |
1215-
| -85 | -5 | 4 | -37.75 | -40 | -5 | -85 | 1 | 83 |
1216-
| -85 | -54 | 15 | -17.0 | -18 | 83 | -101 | 4 | -38 |
1217-
| -85 | -56 | 2 | -70.5 | -70 | -56 | -85 | 1 | -25 |
1218-
| -85 | -72 | 9 | -43.0 | -43 | 83 | -85 | 3 | -12 |
1219-
| -85 | -85 | 1 | -85.0 | -85 | -85 | -85 | 1 | -56 |
1220-
| -85 | 13 | 11 | -17.0 | -18 | 83 | -85 | 3 | 14 |
1221-
| -85 | 13 | 11 | -25.0 | -25 | 83 | -85 | 3 | 13 |
1222-
| -85 | 14 | 12 | -12.0 | -12 | 83 | -85 | 3 | 17 |
1223-
| -85 | 17 | 13 | -11.25 | -8 | 83 | -85 | 4 | -101 |
1224-
| -85 | 45 | 8 | -34.5 | -34 | 83 | -85 | 3 | -72 |
1225-
| -85 | 65 | 17 | -17.0 | -18 | 83 | -101 | 5 | -101 |
1226-
| -85 | 83 | 5 | -25.0 | -25 | 83 | -85 | 2 | -48 |
1207+
| -85 | -101 | 14 | -12.0 | -12.0 | 83 | -101 | 4 | -54 |
1208+
| -85 | -101 | 17 | -25.0 | -25.0 | 83 | -101 | 5 | -31 |
1209+
| -85 | -12 | 10 | -32.75 | -34.0 | 83 | -85 | 3 | 13 |
1210+
| -85 | -25 | 3 | -56.0 | -56.0 | -25 | -85 | 1 | -5 |
1211+
| -85 | -31 | 18 | -29.75 | -28.0 | 83 | -101 | 5 | 36 |
1212+
| -85 | -38 | 16 | -25.0 | -25.0 | 83 | -101 | 4 | 65 |
1213+
| -85 | -43 | 7 | -43.0 | -43.0 | 83 | -85 | 2 | 45 |
1214+
| -85 | -48 | 6 | -35.75 | -36.5 | 83 | -85 | 2 | -43 |
1215+
| -85 | -5 | 4 | -37.75 | -40.5 | -5 | -85 | 1 | 83 |
1216+
| -85 | -54 | 15 | -17.0 | -18.5 | 83 | -101 | 4 | -38 |
1217+
| -85 | -56 | 2 | -70.5 | -70.5 | -56 | -85 | 1 | -25 |
1218+
| -85 | -72 | 9 | -43.0 | -43.0 | 83 | -85 | 3 | -12 |
1219+
| -85 | -85 | 1 | -85.0 | -85.0 | -85 | -85 | 1 | -56 |
1220+
| -85 | 13 | 11 | -17.0 | -18.5 | 83 | -85 | 3 | 14 |
1221+
| -85 | 13 | 11 | -25.0 | -25.0 | 83 | -85 | 3 | 13 |
1222+
| -85 | 14 | 12 | -12.0 | -12.0 | 83 | -85 | 3 | 17 |
1223+
| -85 | 17 | 13 | -11.25 | -8.5 | 83 | -85 | 4 | -101 |
1224+
| -85 | 45 | 8 | -34.5 | -34.0 | 83 | -85 | 3 | -72 |
1225+
| -85 | 65 | 17 | -17.0 | -18.5 | 83 | -101 | 5 | -101 |
1226+
| -85 | 83 | 5 | -25.0 | -25.0 | 83 | -85 | 2 | -48 |
12271227
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
12281228
"
12291229
);

datafusion/core/tests/sql/aggregates/dict_nulls.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> {
9191
+----------------+--------------+
9292
| dict_null_vals | median_value |
9393
+----------------+--------------+
94-
| | 3 |
95-
| group_x | 1 |
96-
| group_y | 5 |
97-
| group_z | 7 |
94+
| | 3.0 |
95+
| group_x | 1.0 |
96+
| group_y | 5.0 |
97+
| group_z | 7.0 |
9898
+----------------+--------------+
9999
");
100100

@@ -437,16 +437,16 @@ async fn test_median_distinct_with_fuzz_table_dict_nulls() -> Result<()> {
437437
assert_snapshot!(
438438
batches_to_string(&results),
439439
@r"
440-
+--------+---------------------+------+------+------+--------+--------+
441-
| u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 |
442-
+--------+---------------------+------+------+------+--------+--------+
443-
| 50 | | | 30 | | 987.65 | 400000 |
444-
| 50 | group_three | 5000 | 50 | 5000 | 555.55 | 500000 |
445-
| 75 | | 4000 | | 4000 | | 450000 |
446-
| 100 | group_one | 1100 | 11 | 1000 | 123.45 | 110000 |
447-
| 100 | group_two | 1500 | 15 | 1500 | 111.11 | 150000 |
448-
| 200 | | 2500 | 22 | 2500 | 506.11 | 250000 |
449-
+--------+---------------------+------+------+------+--------+--------+
440+
+--------+---------------------+--------+------+--------+--------+----------+
441+
| u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 |
442+
+--------+---------------------+--------+------+--------+--------+----------+
443+
| 50 | | | 30.0 | | 987.65 | 400000.0 |
444+
| 50 | group_three | 5000.0 | 50.0 | 5000.0 | 555.55 | 500000.0 |
445+
| 75 | | 4000.0 | | 4000.0 | | 450000.0 |
446+
| 100 | group_one | 1100.0 | 11.0 | 1000.0 | 123.45 | 110000.0 |
447+
| 100 | group_two | 1500.0 | 15.0 | 1500.0 | 111.11 | 150000.0 |
448+
| 200 | | 2500.0 | 22.5 | 2500.0 | 506.11 | 250000.0 |
449+
+--------+---------------------+--------+------+--------+--------+----------+
450450
"
451451
);
452452

datafusion/functions-aggregate/src/median.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,15 @@ use arrow::datatypes::{
3939
ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
4040
};
4141

42+
use datafusion_common::types::{NativeType, logical_float64};
4243
use datafusion_common::{
4344
DataFusionError, Result, ScalarValue, assert_eq_or_internal_err,
4445
internal_datafusion_err,
4546
};
4647
use datafusion_expr::function::StateFieldsArgs;
4748
use datafusion_expr::{
48-
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
49-
function::AccumulatorArgs, utils::format_state_name,
49+
Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature,
50+
TypeSignatureClass, Volatility, function::AccumulatorArgs, utils::format_state_name,
5051
};
5152
use datafusion_expr::{EmitTo, GroupsAccumulator};
5253
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
@@ -99,7 +100,25 @@ impl Default for Median {
99100
impl Median {
100101
pub fn new() -> Self {
101102
Self {
102-
signature: Signature::numeric(1, Volatility::Immutable),
103+
// Integer inputs are coerced to Float64 so the average of the two
104+
// middle values is not truncated. This matches DuckDB / PostgreSQL / Spark.
105+
// Float and Decimal inputs preserve their type.
106+
signature: Signature::one_of(
107+
vec![
108+
TypeSignature::Coercible(vec![Coercion::new_exact(
109+
TypeSignatureClass::Decimal,
110+
)]),
111+
TypeSignature::Coercible(vec![Coercion::new_exact(
112+
TypeSignatureClass::Float,
113+
)]),
114+
TypeSignature::Coercible(vec![Coercion::new_implicit(
115+
TypeSignatureClass::Native(logical_float64()),
116+
vec![TypeSignatureClass::Integer],
117+
NativeType::Float64,
118+
)]),
119+
],
120+
Volatility::Immutable,
121+
),
103122
}
104123
}
105124
}

0 commit comments

Comments
 (0)