Skip to content

Commit d4fb7ef

Browse files
authored
functions: Add dict support for get field (#21115)
## Which issue does this PR close?

Closes #21113

## What changes are included in this PR?

Support dictionary-encoded structs in `get_field`.

## Are these changes tested?

Yes, see the tests. I also replaced this in our code base, and it had the effect we wanted.

## Are there any user-facing changes?

No, this just adds support for dictionary-encoded structs to existing APIs.

@alamb
1 parent 9aa1413 commit d4fb7ef

3 files changed

Lines changed: 437 additions & 5 deletions

File tree

datafusion/functions/src/core/getfield.rs

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use std::sync::Arc;
1919

2020
use arrow::array::{
21-
Array, BooleanArray, Capacities, MutableArrayData, Scalar, make_array,
21+
Array, BooleanArray, Capacities, MutableArrayData, Scalar, cast::AsArray, make_array,
2222
make_comparator,
2323
};
2424
use arrow::compute::SortOptions;
@@ -27,7 +27,7 @@ use arrow_buffer::NullBuffer;
2727

2828
use datafusion_common::cast::{as_map_array, as_struct_array};
2929
use datafusion_common::{
30-
Result, ScalarValue, exec_err, internal_err, plan_datafusion_err,
30+
Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, plan_datafusion_err,
3131
};
3232
use datafusion_expr::expr::ScalarFunction;
3333
use datafusion_expr::simplify::ExprSimplifyResult;
@@ -198,6 +198,24 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
198198
let string_value = name.try_as_str().flatten().map(|s| s.to_string());
199199

200200
match (array.data_type(), name, string_value) {
201+
// Dictionary-encoded struct: extract the field from the dictionary's
202+
// values (the deduplicated struct array) and rebuild a dictionary with
203+
// the same keys. This preserves dictionary encoding without expanding.
204+
(DataType::Dictionary(_, value_type), _, Some(field_name))
205+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
206+
{
207+
let dict = array.as_any_dictionary();
208+
let values_struct = dict.values().as_struct();
209+
let field_col =
210+
values_struct.column_by_name(&field_name).ok_or_else(|| {
211+
exec_datafusion_err!(
212+
"Field {field_name} not found in dictionary struct"
213+
)
214+
})?;
215+
Ok(ColumnarValue::Array(
216+
dict.with_values(Arc::clone(field_col)),
217+
))
218+
}
201219
(DataType::Map(_, _), ScalarValue::List(arr), _) => {
202220
let key_array: Arc<dyn Array> = arr;
203221
process_map_array(&array, key_array)
@@ -333,6 +351,42 @@ impl ScalarUDFImpl for GetFieldFunc {
333351
}
334352
}
335353
}
354+
// Dictionary-encoded struct: resolve the child field from
355+
// the underlying struct, then wrap the result back in the
356+
// same Dictionary type so the promised type matches execution.
357+
DataType::Dictionary(key_type, value_type)
358+
if matches!(value_type.as_ref(), DataType::Struct(_)) =>
359+
{
360+
let DataType::Struct(fields) = value_type.as_ref() else {
361+
unreachable!()
362+
};
363+
let field_name = sv
364+
.as_ref()
365+
.and_then(|sv| {
366+
sv.try_as_str().flatten().filter(|s| !s.is_empty())
367+
})
368+
.ok_or_else(|| {
369+
exec_datafusion_err!("Field name must be a non-empty string")
370+
})?;
371+
372+
let child_field = fields
373+
.iter()
374+
.find(|f| f.name() == field_name)
375+
.ok_or_else(|| {
376+
plan_datafusion_err!("Field {field_name} not found in struct")
377+
})?;
378+
379+
let dict_type = DataType::Dictionary(
380+
key_type.clone(),
381+
Box::new(child_field.data_type().clone()),
382+
);
383+
let mut new_field =
384+
child_field.as_ref().clone().with_data_type(dict_type);
385+
if current_field.is_nullable() {
386+
new_field = new_field.with_nullable(true);
387+
}
388+
current_field = Arc::new(new_field);
389+
}
336390
DataType::Struct(fields) => {
337391
let field_name = sv
338392
.as_ref()
@@ -560,6 +614,133 @@ mod tests {
560614
Ok(())
561615
}
562616

617+
#[test]
/// Extracting a field from a dictionary-encoded struct must return a
/// dictionary that reuses the original keys over the extracted child column,
/// rather than expanding the dictionary into a plain array.
fn test_get_field_dict_encoded_struct() -> Result<()> {
    use arrow::array::{DictionaryArray, StringArray, UInt32Array};
    use arrow::datatypes::UInt32Type;

    // Three distinct struct values make up the dictionary's value array.
    let value_fields = Fields::from(vec![
        Field::new("name", DataType::Utf8, false),
        Field::new("id", DataType::Int32, false),
    ]);
    let name_column: ArrayRef = Arc::new(StringArray::from(vec!["main", "foo", "bar"]));
    let id_column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let dictionary_values: ArrayRef = Arc::new(StructArray::new(
        value_fields,
        vec![name_column, id_column],
        None,
    ));

    // Five rows referencing the three values, with repeats.
    let dictionary = DictionaryArray::<UInt32Type>::try_new(
        UInt32Array::from(vec![0u32, 1, 2, 0, 1]),
        dictionary_values,
    )?;

    let extracted = extract_single_field(
        ColumnarValue::Array(Arc::new(dictionary)),
        ScalarValue::Utf8(Some("name".to_string())),
    )?;
    let result_array = extracted.into_array(5)?;

    assert!(
        matches!(result_array.data_type(), DataType::Dictionary(_, _)),
        "expected dictionary output, got {:?}",
        result_array.data_type()
    );

    // The value array keeps its 3 entries while the keys still cover 5 rows.
    let result_dict = result_array
        .as_any()
        .downcast_ref::<DictionaryArray<UInt32Type>>()
        .unwrap();
    assert_eq!(result_dict.values().len(), 3);
    assert_eq!(result_dict.len(), 5);

    // Decode the dictionary and compare every logical row.
    let decoded = arrow::compute::cast(&result_array, &DataType::Utf8)?;
    let decoded_names = decoded.as_any().downcast_ref::<StringArray>().unwrap();
    let expected_names = ["main", "foo", "bar", "main", "foo"];
    for (row, expected) in expected_names.iter().enumerate() {
        assert_eq!(decoded_names.value(row), *expected);
    }

    Ok(())
}
666+
667+
#[test]
/// Chained extraction through `Struct { num, function: Dictionary<Struct> }`:
/// pulling `function` out of the outer struct yields the dictionary column,
/// and pulling `name` out of that dictionary yields another dictionary.
fn test_get_field_nested_dict_struct() -> Result<()> {
    use arrow::array::{DictionaryArray, StringArray, UInt32Array};
    use arrow::datatypes::UInt32Type;

    // Inner level: a dictionary over two distinct `function` structs.
    let function_fields = Fields::from(vec![
        Field::new("name", DataType::Utf8, false),
        Field::new("file", DataType::Utf8, false),
    ]);
    let function_names: ArrayRef = Arc::new(StringArray::from(vec!["main", "foo"]));
    let function_files: ArrayRef = Arc::new(StringArray::from(vec!["main.c", "foo.c"]));
    let function_values: ArrayRef = Arc::new(StructArray::new(
        function_fields.clone(),
        vec![function_names, function_files],
        None,
    ));
    let function_column: ArrayRef = Arc::new(DictionaryArray::<UInt32Type>::try_new(
        UInt32Array::from(vec![0u32, 1, 0]),
        function_values,
    )?);

    // Outer level: a plain struct holding the dictionary column as a child.
    let function_type = DataType::Dictionary(
        Box::new(DataType::UInt32),
        Box::new(DataType::Struct(function_fields)),
    );
    let line_fields = Fields::from(vec![
        Field::new("num", DataType::Int32, false),
        Field::new("function", function_type, false),
    ]);
    let num_column: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
    let lines = StructArray::new(line_fields, vec![num_column, function_column], None);

    // Step 1: outer_struct["function"] -> dictionary-encoded struct.
    let function_array = extract_single_field(
        ColumnarValue::Array(Arc::new(lines)),
        ScalarValue::Utf8(Some("function".to_string())),
    )?
    .into_array(3)?;
    assert!(
        matches!(function_array.data_type(), DataType::Dictionary(_, _)),
        "expected dictionary for function, got {:?}",
        function_array.data_type()
    );

    // Step 2: dictionary_struct["name"] -> dictionary-encoded strings.
    let name_array = extract_single_field(
        ColumnarValue::Array(function_array),
        ScalarValue::Utf8(Some("name".to_string())),
    )?
    .into_array(3)?;
    assert!(
        matches!(name_array.data_type(), DataType::Dictionary(_, _)),
        "expected dictionary for name, got {:?}",
        name_array.data_type()
    );

    // Two dictionary values are still shared by three rows.
    let name_dict = name_array
        .as_any()
        .downcast_ref::<DictionaryArray<UInt32Type>>()
        .unwrap();
    assert_eq!(name_dict.values().len(), 2);
    assert_eq!(name_dict.len(), 3);

    // Decode and verify each logical row.
    let decoded = arrow::compute::cast(&name_array, &DataType::Utf8)?;
    let decoded_names = decoded.as_any().downcast_ref::<StringArray>().unwrap();
    for (row, expected) in ["main", "foo", "main"].iter().enumerate() {
        assert_eq!(decoded_names.value(row), *expected);
    }

    Ok(())
}
743+
563744
#[test]
564745
fn test_placement_literal_key() {
565746
let func = GetFieldFunc::new();

datafusion/sqllogictest/src/test_context.rs

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ use std::sync::Arc;
2323
use std::vec;
2424

2525
use arrow::array::{
26-
Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray,
27-
LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
26+
Array, ArrayRef, BinaryArray, DictionaryArray, Float64Array, Int32Array,
27+
LargeBinaryArray, LargeStringArray, StringArray, StructArray,
28+
TimestampNanosecondArray, UInt32Array, UnionArray,
2829
};
2930
use arrow::buffer::ScalarBuffer;
3031
use arrow::datatypes::{
31-
DataType, Field, FieldRef, Schema, SchemaRef, TimeUnit, UnionFields,
32+
DataType, Field, FieldRef, Fields, Schema, SchemaRef, TimeUnit, UInt32Type,
33+
UnionFields,
3234
};
3335
use arrow::record_batch::RecordBatch;
3436
use datafusion::catalog::{
@@ -173,6 +175,10 @@ impl TestContext {
173175
info!("Registering table with union column");
174176
register_union_table(test_ctx.session_ctx())
175177
}
178+
"dictionary_struct.slt" => {
179+
info!("Registering table with dictionary-encoded struct column");
180+
register_dictionary_struct_table(test_ctx.session_ctx());
181+
}
176182
"async_udf.slt" => {
177183
info!("Registering dummy async udf");
178184
register_async_abs_udf(test_ctx.session_ctx())
@@ -575,6 +581,97 @@ fn register_union_table(ctx: &SessionContext) {
575581
ctx.register_batch("union_table", batch).unwrap();
576582
}
577583

584+
fn register_dictionary_struct_table(ctx: &SessionContext) {
585+
// Build deduplicated struct values: 3 unique structs
586+
let names = Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])) as ArrayRef;
587+
let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
588+
589+
let struct_fields: Fields = vec![
590+
Field::new("name", DataType::Utf8, false),
591+
Field::new("id", DataType::Int32, false),
592+
]
593+
.into();
594+
595+
let values_struct = Arc::new(
596+
StructArray::try_new(struct_fields.clone(), vec![names, ids], None).unwrap(),
597+
) as ArrayRef;
598+
599+
// Dictionary keys index into the 3-element struct array.
600+
// 5 rows with repeated references to test dictionary deduplication.
601+
let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]);
602+
let dict =
603+
DictionaryArray::<UInt32Type>::try_new(keys, Arc::clone(&values_struct)).unwrap();
604+
605+
// Also build a non-dictionary plain struct column for comparison.
606+
let plain_names = Arc::new(StringArray::from(vec![
607+
"Alice", "Bob", "Carol", "Alice", "Bob",
608+
])) as ArrayRef;
609+
let plain_ids = Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2])) as ArrayRef;
610+
let plain_struct =
611+
StructArray::try_new(struct_fields.clone(), vec![plain_names, plain_ids], None)
612+
.unwrap();
613+
614+
let dict_type = DataType::Dictionary(
615+
Box::new(DataType::UInt32),
616+
Box::new(DataType::Struct(struct_fields.clone())),
617+
);
618+
619+
let schema = Schema::new(vec![
620+
Field::new("dict_struct", dict_type, false),
621+
Field::new(
622+
"plain_struct",
623+
DataType::Struct(struct_fields.clone()),
624+
false,
625+
),
626+
]);
627+
628+
let batch = RecordBatch::try_new(
629+
Arc::new(schema),
630+
vec![
631+
Arc::new(dict) as ArrayRef,
632+
Arc::new(plain_struct) as ArrayRef,
633+
],
634+
)
635+
.unwrap();
636+
637+
ctx.register_batch("dict_struct_table", batch).unwrap();
638+
639+
// Second table: dictionary-encoded struct with nullable entries
640+
let names_nullable = Arc::new(StringArray::from(vec!["X", "Y"])) as ArrayRef;
641+
let ids_nullable = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef;
642+
let struct_fields_nullable: Fields = vec![
643+
Field::new("name", DataType::Utf8, false),
644+
Field::new("id", DataType::Int32, false),
645+
]
646+
.into();
647+
let values_struct_nullable = Arc::new(
648+
StructArray::try_new(
649+
struct_fields_nullable.clone(),
650+
vec![names_nullable, ids_nullable],
651+
None,
652+
)
653+
.unwrap(),
654+
) as ArrayRef;
655+
let keys_nullable = UInt32Array::from(vec![Some(0), None, Some(1), None]);
656+
let dict_nullable =
657+
DictionaryArray::<UInt32Type>::try_new(keys_nullable, values_struct_nullable)
658+
.unwrap();
659+
660+
let dict_type_nullable = DataType::Dictionary(
661+
Box::new(DataType::UInt32),
662+
Box::new(DataType::Struct(struct_fields_nullable)),
663+
);
664+
665+
let schema_nullable = Schema::new(vec![Field::new("ds", dict_type_nullable, true)]);
666+
let batch_nullable = RecordBatch::try_new(
667+
Arc::new(schema_nullable),
668+
vec![Arc::new(dict_nullable) as ArrayRef],
669+
)
670+
.unwrap();
671+
ctx.register_batch("dict_struct_nullable", batch_nullable)
672+
.unwrap();
673+
}
674+
578675
fn register_async_abs_udf(ctx: &SessionContext) {
579676
#[derive(Debug, PartialEq, Eq, Hash)]
580677
struct AsyncAbs {

0 commit comments

Comments
 (0)