Skip to content

Commit 4795aa5

Browse files
committed
Refactor min_max and display helpers
Share array-extrema scan in min_max.rs. Collapse dictionary scalar handling into a single normalized path. Simplify key ordering, indentation, and stringification helpers in display.rs. Reduce duplicated dictionary test setup in min_max.rs and basic.rs.
1 parent b4938c1 commit 4795aa5

File tree

4 files changed

+152
-147
lines changed

4 files changed

+152
-147
lines changed

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -448,28 +448,35 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
448448
let ctx =
449449
SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2));
450450

451-
let dict_type =
452-
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
453-
let schema = Arc::new(Schema::new(vec![Field::new(
454-
"dict",
455-
dict_type.clone(),
456-
true,
457-
)]));
458-
459-
let batch1 = RecordBatch::try_new(
460-
schema.clone(),
461-
vec![Arc::new(DictionaryArray::new(
462-
Int32Array::from(vec![Some(1), Some(1), None]),
463-
Arc::new(StringArray::from(vec!["a", "z", "zz_unused"])),
464-
))],
465-
)?;
466-
let batch2 = RecordBatch::try_new(
451+
fn dictionary_schema() -> Arc<Schema> {
452+
let dict_type =
453+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
454+
Arc::new(Schema::new(vec![Field::new("dict", dict_type, true)]))
455+
}
456+
457+
fn dictionary_record_batch(
458+
schema: Arc<Schema>,
459+
values: &[&str],
460+
keys: &[Option<i32>],
461+
) -> Result<RecordBatch> {
462+
Ok(RecordBatch::try_new(
463+
schema,
464+
vec![Arc::new(DictionaryArray::new(
465+
Int32Array::from(keys.to_vec()),
466+
Arc::new(StringArray::from(values.to_vec())),
467+
))],
468+
)?)
469+
}
470+
471+
let schema = dictionary_schema();
472+
473+
let batch1 = dictionary_record_batch(
467474
schema.clone(),
468-
vec![Arc::new(DictionaryArray::new(
469-
Int32Array::from(vec![Some(0), Some(1)]),
470-
Arc::new(StringArray::from(vec!["a", "d"])),
471-
))],
475+
&["a", "z", "zz_unused"],
476+
&[Some(1), Some(1), None],
472477
)?;
478+
let batch2 =
479+
dictionary_record_batch(schema.clone(), &["a", "d"], &[Some(0), Some(1)])?;
473480
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
474481
ctx.register_table("t", Arc::new(provider))?;
475482

datafusion/expr/src/logical_plan/display.rs

Lines changed: 38 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -283,38 +283,45 @@ pub struct PgJsonVisitor<'a, 'b> {
283283
parent_ids: Vec<u32>,
284284
}
285285

286+
const NODE_TYPE_KEY: &str = "Node Type";
287+
const PLANS_KEY: &str = "Plans";
288+
const OUTPUT_KEY: &str = "Output";
289+
286290
fn ordered_keys(map: &serde_json::Map<String, serde_json::Value>) -> Vec<&str> {
287-
let mut ordered_keys = vec![];
288-
289-
// pgjson node objects are always emitted as:
290-
// Node Type, node-specific keys, Plans, Output
291-
if map.contains_key("Node Type") {
292-
ordered_keys.push("Node Type");
293-
294-
let mut middle_keys = map
295-
.keys()
296-
.map(String::as_str)
297-
.filter(|k| *k != "Node Type" && *k != "Plans" && *k != "Output")
298-
.collect::<Vec<_>>();
299-
middle_keys.sort_unstable();
300-
ordered_keys.extend(middle_keys);
301-
302-
if map.contains_key("Plans") {
303-
ordered_keys.push("Plans");
304-
}
291+
let mut keys = map.keys().map(String::as_str).collect::<Vec<_>>();
292+
keys.sort_unstable();
305293

306-
if map.contains_key("Output") {
307-
ordered_keys.push("Output");
308-
}
309-
} else {
310-
let mut keys = map.keys().map(String::as_str).collect::<Vec<_>>();
311-
keys.sort_unstable();
312-
ordered_keys.extend(keys);
294+
if !map.contains_key(NODE_TYPE_KEY) {
295+
return keys;
313296
}
314297

298+
let mut ordered_keys = Vec::with_capacity(keys.len());
299+
ordered_keys.push(NODE_TYPE_KEY);
300+
ordered_keys.extend(
301+
keys.iter()
302+
.copied()
303+
.filter(|key| !matches!(*key, NODE_TYPE_KEY | PLANS_KEY | OUTPUT_KEY)),
304+
);
305+
ordered_keys.extend(
306+
[PLANS_KEY, OUTPUT_KEY]
307+
.into_iter()
308+
.filter(|key| map.contains_key(*key)),
309+
);
315310
ordered_keys
316311
}
317312

313+
fn json_to_string(value: &serde_json::Value) -> datafusion_common::Result<String> {
314+
serde_json::to_string(value).map_err(|e| DataFusionError::External(Box::new(e)))
315+
}
316+
317+
fn json_key_to_string(key: &str) -> datafusion_common::Result<String> {
318+
serde_json::to_string(key).map_err(|e| DataFusionError::External(Box::new(e)))
319+
}
320+
321+
fn push_indent(buf: &mut String, indent: usize) {
322+
buf.push_str(&" ".repeat(indent * 2));
323+
}
324+
318325
fn write_ordered_json(
319326
value: &serde_json::Value,
320327
buf: &mut String,
@@ -325,9 +332,7 @@ fn write_ordered_json(
325332
| serde_json::Value::Bool(_)
326333
| serde_json::Value::Number(_)
327334
| serde_json::Value::String(_) => {
328-
let scalar = serde_json::to_string(value)
329-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
330-
buf.push_str(&scalar);
335+
buf.push_str(&json_to_string(value)?);
331336
}
332337
serde_json::Value::Array(values) => {
333338
if values.is_empty() {
@@ -338,14 +343,14 @@ fn write_ordered_json(
338343
buf.push('[');
339344
buf.push('\n');
340345
for (idx, value) in values.iter().enumerate() {
341-
buf.push_str(&" ".repeat((indent + 1) * 2));
346+
push_indent(buf, indent + 1);
342347
write_ordered_json(value, buf, indent + 1)?;
343348
if idx + 1 != values.len() {
344349
buf.push(',');
345350
}
346351
buf.push('\n');
347352
}
348-
buf.push_str(&" ".repeat(indent * 2));
353+
push_indent(buf, indent);
349354
buf.push(']');
350355
}
351356
serde_json::Value::Object(map) => {
@@ -363,10 +368,8 @@ fn write_ordered_json(
363368
.get(*key)
364369
.ok_or_else(|| internal_datafusion_err!("Missing key in object!"))?;
365370

366-
buf.push_str(&" ".repeat((indent + 1) * 2));
367-
let escaped_key = serde_json::to_string(key)
368-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
369-
buf.push_str(&escaped_key);
371+
push_indent(buf, indent + 1);
372+
buf.push_str(&json_key_to_string(key)?);
370373
buf.push_str(": ");
371374
write_ordered_json(value, buf, indent + 1)?;
372375

@@ -376,7 +379,7 @@ fn write_ordered_json(
376379
buf.push('\n');
377380
}
378381

379-
buf.push_str(&" ".repeat(indent * 2));
382+
push_indent(buf, indent);
380383
buf.push('}');
381384
}
382385
}

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 32 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -413,28 +413,20 @@ macro_rules! min_max {
413413
min_max_generic!(lhs, rhs, $OP)
414414
}
415415

416-
(
417-
ScalarValue::Dictionary(key_type, lhs_inner),
418-
ScalarValue::Dictionary(_, rhs_inner),
419-
) => {
420-
wrap_dictionary_scalar(
421-
key_type.as_ref(),
422-
min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP),
423-
)
424-
}
425-
426-
(
427-
ScalarValue::Dictionary(_, lhs_inner),
428-
rhs,
429-
) => {
430-
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
431-
}
432-
433416
(
434417
lhs,
435-
ScalarValue::Dictionary(_, rhs_inner),
436-
) => {
437-
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
418+
rhs,
419+
) if matches!(lhs, ScalarValue::Dictionary(_, _))
420+
|| matches!(rhs, ScalarValue::Dictionary(_, _)) =>
421+
{
422+
let (lhs_key_type, lhs) = unpack_dictionary_scalar(lhs);
423+
let (rhs_key_type, rhs) = unpack_dictionary_scalar(rhs);
424+
425+
let value = min_max_generic!(lhs, rhs, $OP);
426+
match (lhs_key_type, rhs_key_type) {
427+
(Some(key_type), Some(_)) => wrap_dictionary_scalar(key_type, value),
428+
_ => value,
429+
}
438430
}
439431

440432
e => {
@@ -447,10 +439,15 @@ macro_rules! min_max {
447439
}};
448440
}
449441

450-
fn dictionary_batch_extreme(
451-
values: &ArrayRef,
452-
ordering: Ordering,
453-
) -> Result<ScalarValue> {
442+
fn unpack_dictionary_scalar(value: &ScalarValue) -> (Option<&DataType>, &ScalarValue) {
443+
if let ScalarValue::Dictionary(key_type, inner) = value {
444+
(Some(key_type.as_ref()), inner.as_ref())
445+
} else {
446+
(None, value)
447+
}
448+
}
449+
450+
fn scalar_extreme(values: &ArrayRef, ordering: Ordering) -> Result<Option<ScalarValue>> {
454451
let mut extreme: Option<ScalarValue> = None;
455452

456453
for i in 0..values.len() {
@@ -465,7 +462,15 @@ fn dictionary_batch_extreme(
465462
}
466463
}
467464

468-
extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok)
465+
Ok(extreme)
466+
}
467+
468+
fn dictionary_batch_extreme(
469+
values: &ArrayRef,
470+
ordering: Ordering,
471+
) -> Result<ScalarValue> {
472+
scalar_extreme(values, ordering)?
473+
.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok)
469474
}
470475

471476
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
@@ -824,23 +829,8 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
824829

825830
/// Generic min/max implementation for complex types
826831
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
827-
let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i));
828-
let Some(first_idx) = non_null_indices.next() else {
829-
return ScalarValue::try_from(array.data_type());
830-
};
831-
832-
let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
833-
for i in non_null_indices {
834-
let current = ScalarValue::try_from_array(array, i)?;
835-
if current.is_null() {
836-
continue;
837-
}
838-
if extreme.is_null() || extreme.try_cmp(&current)? == ordering {
839-
extreme = current;
840-
}
841-
}
842-
843-
Ok(extreme)
832+
scalar_extreme(array, ordering)?
833+
.map_or_else(|| ScalarValue::try_from(array.data_type()), Ok)
844834
}
845835

846836
/// dynamically-typed max(array) -> ScalarValue

0 commit comments

Comments (0)