Skip to content

Commit d929212

Browse files
committed
avoid allocation in dedup_first_wins
1 parent 0c2a55a commit d929212

1 file changed

Lines changed: 30 additions & 43 deletions

File tree

datafusion/functions-aggregate/src/map_agg.rs

Lines changed: 30 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! `MAP_AGG` aggregate implementation: [`MapAgg`]
1919
20-
use std::collections::VecDeque;
20+
use std::collections::{HashSet, VecDeque};
2121
use std::mem::{size_of, size_of_val, take};
2222
use std::sync::Arc;
2323

@@ -26,7 +26,7 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer};
2626
use arrow::compute::SortOptions;
2727
use arrow::datatypes::{DataType, Field, FieldRef, Fields};
2828

29-
use datafusion_common::utils::{compare_rows, get_row_at_idx};
29+
use datafusion_common::utils::{SingleRowListArrayBuilder, compare_rows, get_row_at_idx};
3030
use datafusion_common::{Result, ScalarValue, exec_err};
3131
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3232
use datafusion_expr::utils::format_state_name;
@@ -150,13 +150,11 @@ impl AggregateUDFImpl for MapAgg {
150150
}
151151

152152
fn map_type(key_type: &DataType, value_type: &DataType) -> DataType {
153-
let key_field = Arc::new(Field::new("key", key_type.clone(), false));
154-
let value_field = Arc::new(Field::new("value", value_type.clone(), true));
155-
let entries_field = Arc::new(Field::new(
156-
"entries",
157-
DataType::Struct(Fields::from(vec![key_field, value_field])),
158-
false,
159-
));
153+
let fields = Fields::from(vec![
154+
Field::new("key", key_type.clone(), false),
155+
Field::new("value", value_type.clone(), true),
156+
]);
157+
let entries_field = Arc::new(Field::new("entries", DataType::Struct(fields), false));
160158
DataType::Map(entries_field, false)
161159
}
162160

@@ -168,14 +166,13 @@ fn build_single_map(
168166
) -> Result<ArrayRef> {
169167
debug_assert_eq!(keys.len(), values.len());
170168

171-
let key_field = Arc::new(Field::new("key", key_type.clone(), false));
172-
let value_field = Arc::new(Field::new("value", value_type.clone(), true));
169+
let fields = Fields::from(vec![
170+
Field::new("key", key_type.clone(), false),
171+
Field::new("value", value_type.clone(), true),
172+
]);
173173
let entries_field = Arc::new(Field::new(
174174
"entries",
175-
DataType::Struct(Fields::from(vec![
176-
Arc::clone(&key_field),
177-
Arc::clone(&value_field),
178-
])),
175+
DataType::Struct(fields.clone()),
179176
false,
180177
));
181178

@@ -191,11 +188,7 @@ fn build_single_map(
191188
ScalarValue::iter_to_array(values)?
192189
};
193190

194-
let entries = StructArray::try_new(
195-
Fields::from(vec![key_field, value_field]),
196-
vec![key_array, value_array],
197-
None,
198-
)?;
191+
let entries = StructArray::try_new(fields, vec![key_array, value_array], None)?;
199192

200193
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, len as i32]));
201194
Ok(Arc::new(MapArray::try_new(
@@ -207,27 +200,26 @@ fn build_single_map(
207200
)?))
208201
}
209202

210-
/// De-duplicates parallel key/value vectors keeping the first value seen for
211-
/// each key.
212203
fn dedup_first_wins(
213204
keys: Vec<ScalarValue>,
214205
values: Vec<ScalarValue>,
215206
) -> (Vec<ScalarValue>, Vec<ScalarValue>) {
216-
use std::collections::HashSet;
217-
218-
let mut seen: HashSet<ScalarValue> = HashSet::with_capacity(keys.len());
219-
let mut out_keys: Vec<ScalarValue> = Vec::with_capacity(keys.len());
220-
let mut out_vals: Vec<ScalarValue> = Vec::with_capacity(keys.len());
221-
222-
for (k, v) in keys.into_iter().zip(values) {
223-
// Keep only the first occurrence of each key; later ones are dropped.
224-
if seen.insert(k.clone()) {
225-
out_keys.push(k);
226-
out_vals.push(v);
227-
}
228-
}
229-
230-
(out_keys, out_vals)
207+
// First pass: mark each position that is the first occurrence of its key.
208+
let mut seen = HashSet::with_capacity(keys.len());
209+
let keep: Vec<bool> = keys.iter().map(|k| seen.insert(k)).collect();
210+
211+
// Second pass: keep only the first-occurrence positions.
212+
let out_keys = keys
213+
.into_iter()
214+
.zip(&keep)
215+
.filter_map(|(k, &keep)| keep.then_some(k))
216+
.collect();
217+
let out_values = values
218+
.into_iter()
219+
.zip(&keep)
220+
.filter_map(|(v, &keep)| keep.then_some(v))
221+
.collect();
222+
(out_keys, out_values)
231223
}
232224

233225
/// Plain accumulator used when there is no `ORDER BY`.
@@ -388,12 +380,7 @@ impl OrderSensitiveMapAggAccumulator {
388380
}
389381

390382
let struct_array = StructArray::try_new(struct_field, column_wise, None)?;
391-
Ok(
392-
datafusion_common::utils::SingleRowListArrayBuilder::new(Arc::new(
393-
struct_array,
394-
))
395-
.build_list_scalar(),
396-
)
383+
Ok(SingleRowListArrayBuilder::new(Arc::new(struct_array)).build_list_scalar())
397384
}
398385
}
399386

0 commit comments

Comments
 (0)