Skip to content

Commit b1c7656

Browse files
committed
fix compile.
1 parent 0cc1e0c commit b1c7656

10 files changed

Lines changed: 137 additions & 54 deletions

File tree

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ use arrow::{
2929
};
3030
use datafusion::error::Result;
3131
use datafusion::prelude::*;
32-
use datafusion_common::{cast::as_float64_array, ScalarValue};
32+
use datafusion_common::{cast::as_float64_array, DataFusionError, ScalarValue};
3333
use datafusion_expr::{
34-
function::{AccumulatorArgs, StateFieldsArgs},
35-
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
34+
function::{AccumulatorArgs, StateFieldsArgs}, groups_accumulator::GroupIndices, Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature
3635
};
3736

3837
/// This example shows how to use the full AggregateUDFImpl API to implement a user
@@ -272,11 +271,19 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
272271
fn update_batch(
273272
&mut self,
274273
values: &[ArrayRef],
275-
group_indices: &[usize],
274+
group_indices: GroupIndices<'_>,
276275
opt_filter: Option<&arrow::array::BooleanArray>,
277276
total_num_groups: usize,
278277
) -> Result<()> {
279278
assert_eq!(values.len(), 1, "single argument to update_batch");
279+
280+
let group_indices = match group_indices {
281+
GroupIndices::Flat(idxs) => idxs,
282+
GroupIndices::Blocked(_) => return Err(DataFusionError::NotImplemented(
283+
"blocked states management is not supported".to_string()),
284+
),
285+
};
286+
280287
let values = values[0].as_primitive::<Float64Type>();
281288

282289
// increment counts, update sums
@@ -303,11 +310,19 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
303310
fn merge_batch(
304311
&mut self,
305312
values: &[ArrayRef],
306-
group_indices: &[usize],
313+
group_indices: GroupIndices<'_>,
307314
opt_filter: Option<&arrow::array::BooleanArray>,
308315
total_num_groups: usize,
309316
) -> Result<()> {
310317
assert_eq!(values.len(), 2, "two arguments to merge_batch");
318+
319+
let group_indices = match group_indices {
320+
GroupIndices::Flat(idxs) => idxs,
321+
GroupIndices::Blocked(_) => return Err(DataFusionError::NotImplemented(
322+
"blocked states management is not supported".to_string()),
323+
),
324+
};
325+
311326
// first batch is counts, second is partial sums
312327
let partial_prods = values[0].as_primitive::<Float64Type>();
313328
let partial_counts = values[1].as_primitive::<UInt32Type>();

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use datafusion::{
4949
scalar::ScalarValue,
5050
};
5151
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
52+
use datafusion_expr::groups_accumulator::GroupIndices;
5253
use datafusion_expr::{
5354
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
5455
LogicalPlanBuilder, SimpleAggregateUDF,
@@ -832,7 +833,7 @@ impl GroupsAccumulator for TestGroupsAccumulator {
832833
fn update_batch(
833834
&mut self,
834835
_values: &[ArrayRef],
835-
_group_indices: &[usize],
836+
_group_indices: GroupIndices<'_>,
836837
_opt_filter: Option<&arrow_array::BooleanArray>,
837838
_total_num_groups: usize,
838839
) -> Result<()> {
@@ -856,7 +857,7 @@ impl GroupsAccumulator for TestGroupsAccumulator {
856857
fn merge_batch(
857858
&mut self,
858859
_values: &[ArrayRef],
859-
_group_indices: &[usize],
860+
_group_indices: GroupIndices<'_>,
860861
_opt_filter: Option<&arrow_array::BooleanArray>,
861862
_total_num_groups: usize,
862863
) -> Result<()> {

datafusion/expr-common/src/groups_accumulator.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,21 @@ pub enum GroupIndices<'a> {
7070
Blocked(&'a [u64]),
7171
}
7272

73+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
74+
pub enum GroupIndicesType {
75+
Flat,
76+
Blocked,
77+
}
78+
79+
impl GroupIndicesType {
80+
pub fn typed_group_indices<'a>(&self, indices: &'a [u64]) -> GroupIndices<'a> {
81+
match self {
82+
GroupIndicesType::Flat => GroupIndices::Flat(indices),
83+
GroupIndicesType::Blocked => GroupIndices::Blocked(indices),
84+
}
85+
}
86+
}
87+
7388
/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and
7489
/// stores the state for *all* groups internally.
7590
///

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,9 @@ mod test {
665665
let mut accumulated_values = vec![];
666666
let mut null_state = NullState::new();
667667

668+
let group_indices = group_indices.iter().map(|idx| *idx as u64).collect::<Vec<_>>();
668669
null_state.accumulate(
669-
group_indices,
670+
&group_indices,
670671
values,
671672
opt_filter,
672673
total_num_groups,
@@ -683,8 +684,8 @@ mod test {
683684
None => group_indices.iter().zip(values.iter()).for_each(
684685
|(&group_index, value)| {
685686
if let Some(value) = value {
686-
mock.saw_value(group_index);
687-
expected_values.push((group_index, value));
687+
mock.saw_value(group_index as usize);
688+
expected_values.push((group_index as usize, value));
688689
}
689690
},
690691
),
@@ -697,8 +698,8 @@ mod test {
697698
// if value passed filter
698699
if let Some(true) = is_included {
699700
if let Some(value) = value {
700-
mock.saw_value(group_index);
701-
expected_values.push((group_index, value));
701+
mock.saw_value(group_index as usize);
702+
expected_values.push((group_index as usize, value));
702703
}
703704
}
704705
});
@@ -727,7 +728,8 @@ mod test {
727728
) {
728729
let mut accumulated_values = vec![];
729730

730-
accumulate_indices(group_indices, nulls, opt_filter, |group_index| {
731+
let group_indices = group_indices.iter().map(|idx| *idx as u64).collect::<Vec<_>>();
732+
accumulate_indices(&group_indices, nulls, opt_filter, |group_index| {
731733
accumulated_values.push(group_index);
732734
});
733735

@@ -736,19 +738,19 @@ mod test {
736738

737739
match (nulls, opt_filter) {
738740
(None, None) => group_indices.iter().for_each(|&group_index| {
739-
expected_values.push(group_index);
741+
expected_values.push(group_index as usize);
740742
}),
741743
(Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each(
742744
|(&group_index, is_valid)| {
743745
if is_valid {
744-
expected_values.push(group_index);
746+
expected_values.push(group_index as usize);
745747
}
746748
},
747749
),
748750
(None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each(
749751
|(&group_index, is_included)| {
750752
if let Some(true) = is_included {
751-
expected_values.push(group_index);
753+
expected_values.push(group_index as usize);
752754
}
753755
},
754756
),
@@ -760,7 +762,7 @@ mod test {
760762
.for_each(|((&group_index, is_valid), is_included)| {
761763
// if value passed filter
762764
if let (true, Some(true)) = (is_valid, is_included) {
763-
expected_values.push(group_index);
765+
expected_values.push(group_index as usize);
764766
}
765767
});
766768
}
@@ -781,8 +783,9 @@ mod test {
781783
let mut accumulated_values = vec![];
782784
let mut null_state = NullState::new();
783785

786+
let group_indices = group_indices.iter().map(|idx| *idx as u64).collect::<Vec<_>>();
784787
null_state.accumulate_boolean(
785-
group_indices,
788+
&group_indices,
786789
values,
787790
opt_filter,
788791
total_num_groups,
@@ -799,8 +802,8 @@ mod test {
799802
None => group_indices.iter().zip(values.iter()).for_each(
800803
|(&group_index, value)| {
801804
if let Some(value) = value {
802-
mock.saw_value(group_index);
803-
expected_values.push((group_index, value));
805+
mock.saw_value(group_index as usize);
806+
expected_values.push((group_index as usize, value));
804807
}
805808
},
806809
),
@@ -813,8 +816,8 @@ mod test {
813816
// if value passed filter
814817
if let Some(true) = is_included {
815818
if let Some(value) = value {
816-
mock.saw_value(group_index);
817-
expected_values.push((group_index, value));
819+
mock.saw_value(group_index as usize);
820+
expected_values.push((group_index as usize, value));
818821
}
819822
}
820823
});

datafusion/physical-plan/src/aggregates/group_values/bytes.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::aggregates::group_values::GroupValues;
1919
use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
2020
use datafusion_common::{DataFusionError, Result};
21-
use datafusion_expr::EmitTo;
21+
use datafusion_expr::{groups_accumulator::GroupIndicesType, EmitTo};
2222
use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
2323

2424
/// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values
@@ -45,9 +45,17 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
4545
fn intern(
4646
&mut self,
4747
cols: &[ArrayRef],
48-
groups: &mut Vec<usize>,
48+
groups: &mut Vec<u64>,
49+
group_type: GroupIndicesType
4950
) -> datafusion_common::Result<()> {
5051
assert_eq!(cols.len(), 1);
52+
53+
if group_type == GroupIndicesType::Blocked {
54+
return Err(DataFusionError::NotImplemented(
55+
"blocked group values management is not supported".to_string()),
56+
);
57+
}
58+
5159

5260
// look up / add entries in the table
5361
let arr = &cols[0];
@@ -64,7 +72,7 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
6472
},
6573
// called for each group
6674
|group_idx| {
67-
groups.push(group_idx);
75+
groups.push(group_idx as u64);
6876
},
6977
);
7078

@@ -109,7 +117,8 @@ impl<O: OffsetSizeTrait> GroupValues for GroupValuesByes<O> {
109117

110118
self.num_groups = 0;
111119
let mut group_indexes = vec![];
112-
self.intern(&[remaining_group_values], &mut group_indexes)?;
120+
// FIXME: When impl blocked GroupValuesByes, we should consider a way to get right `group_type` here.
121+
self.intern(&[remaining_group_values], &mut group_indexes, GroupIndicesType::Flat)?;
113122

114123
// Verify that the group indexes were assigned in the correct order
115124
assert_eq!(0, group_indexes[0]);

datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::aggregates::group_values::GroupValues;
1919
use arrow_array::{Array, ArrayRef, RecordBatch};
2020
use datafusion_common::{DataFusionError, Result};
21-
use datafusion_expr::EmitTo;
21+
use datafusion_expr::{groups_accumulator::GroupIndicesType, EmitTo};
2222
use datafusion_physical_expr::binary_map::OutputType;
2323
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;
2424

@@ -46,10 +46,17 @@ impl GroupValues for GroupValuesBytesView {
4646
fn intern(
4747
&mut self,
4848
cols: &[ArrayRef],
49-
groups: &mut Vec<usize>,
49+
groups: &mut Vec<u64>,
50+
group_type: GroupIndicesType,
5051
) -> datafusion_common::Result<()> {
5152
assert_eq!(cols.len(), 1);
5253

54+
if group_type == GroupIndicesType::Blocked {
55+
return Err(DataFusionError::NotImplemented(
56+
"blocked group values management is not supported".to_string()),
57+
);
58+
}
59+
5360
// look up / add entries in the table
5461
let arr = &cols[0];
5562

@@ -65,7 +72,7 @@ impl GroupValues for GroupValuesBytesView {
6572
},
6673
// called for each group
6774
|group_idx| {
68-
groups.push(group_idx);
75+
groups.push(group_idx as u64);
6976
},
7077
);
7178

@@ -110,7 +117,8 @@ impl GroupValues for GroupValuesBytesView {
110117

111118
self.num_groups = 0;
112119
let mut group_indexes = vec![];
113-
self.intern(&[remaining_group_values], &mut group_indexes)?;
120+
// FIXME: When impl blocked GroupValuesBytesView, we should consider a way to get right `group_type` here.
121+
self.intern(&[remaining_group_values], &mut group_indexes, GroupIndicesType::Flat)?;
114122

115123
// Verify that the group indexes were assigned in the correct order
116124
assert_eq!(0, group_indexes[0]);

datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use bytes_view::GroupValuesBytesView;
2222
use datafusion_common::Result;
2323

2424
pub(crate) mod primitive;
25-
use datafusion_expr::EmitTo;
25+
use datafusion_expr::{groups_accumulator::GroupIndicesType, EmitTo};
2626
use primitive::GroupValuesPrimitive;
2727

2828
mod row;
@@ -36,7 +36,7 @@ use datafusion_physical_expr::binary_map::OutputType;
3636
/// An interning store for group keys
3737
pub trait GroupValues: Send {
3838
/// Calculates the `groups` for each input row of `cols`
39-
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>;
39+
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<u64>, group_type: GroupIndicesType) -> Result<()>;
4040

4141
/// Returns the number of bytes used by this [`GroupValues`]
4242
fn size(&self) -> usize;

datafusion/physical-plan/src/aggregates/group_values/primitive.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
2727
use arrow_schema::DataType;
2828
use datafusion_common::{DataFusionError, Result};
2929
use datafusion_execution::memory_pool::proxy::VecAllocExt;
30+
use datafusion_expr::groups_accumulator::GroupIndicesType;
3031
use datafusion_expr::EmitTo;
3132
use half::f16;
3233
use hashbrown::raw::RawTable;
@@ -111,8 +112,20 @@ impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
111112
where
112113
T::Native: HashValue,
113114
{
114-
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
115+
fn intern(
116+
&mut self,
117+
cols: &[ArrayRef],
118+
groups: &mut Vec<u64>,
119+
group_type: GroupIndicesType,
120+
) -> Result<()> {
115121
assert_eq!(cols.len(), 1);
122+
123+
if group_type == GroupIndicesType::Blocked {
124+
return Err(DataFusionError::NotImplemented(
125+
"blocked group values management is not supported".to_string()),
126+
);
127+
}
128+
116129
groups.clear();
117130

118131
for v in cols[0].as_primitive::<T>() {
@@ -145,7 +158,7 @@ where
145158
}
146159
}
147160
};
148-
groups.push(group_id)
161+
groups.push(group_id as u64)
149162
}
150163
Ok(())
151164
}

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use arrow_schema::{DataType, SchemaRef};
2525
use datafusion_common::hash_utils::create_hashes;
2626
use datafusion_common::{DataFusionError, Result};
2727
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
28+
use datafusion_expr::groups_accumulator::GroupIndicesType;
2829
use datafusion_expr::EmitTo;
2930
use hashbrown::raw::RawTable;
3031

@@ -99,7 +100,13 @@ impl GroupValuesRows {
99100
}
100101

101102
impl GroupValues for GroupValuesRows {
102-
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
103+
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<u64>, group_type: GroupIndicesType) -> Result<()> {
104+
if group_type == GroupIndicesType::Blocked {
105+
return Err(DataFusionError::NotImplemented(
106+
"blocked group values management is not supported".to_string()),
107+
);
108+
}
109+
103110
// Convert the group keys into the row format
104111
let group_rows = &mut self.rows_buffer;
105112
group_rows.clear();
@@ -151,7 +158,7 @@ impl GroupValues for GroupValuesRows {
151158
group_idx
152159
}
153160
};
154-
groups.push(group_idx);
161+
groups.push(group_idx as u64);
155162
}
156163

157164
self.group_values = Some(group_values);

0 commit comments

Comments
 (0)