diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 1c46bf5e81337..d618bfbcb0b75 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -19,10 +19,14 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{ - Array, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, make_array, + Array, ArrayAccessor, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, + cast::AsArray, make_array, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Field, Int8Type, Int16Type, Int32Type, Int64Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, cast::as_map_array, exec_err}; use datafusion_expr::{ @@ -130,11 +134,41 @@ impl ScalarUDFImpl for MapExtract { } } +/// Fast path for key types that support direct typed value comparison. +/// +/// This avoids the generic single-element slice comparison used by +/// `general_map_extract_inner`. +fn specialized_map_extract_inner( + map_array: &MapArray, + keys: A, + query_keys: A, +) -> Result +where + A: ArrayAccessor, + A::Item: PartialEq, +{ + map_extract_with_match(map_array, move |row_index, key_index| { + query_keys.is_valid(row_index) + && keys.is_valid(key_index) + && keys.value(key_index) == query_keys.value(row_index) + }) +} + fn general_map_extract_inner( map_array: &MapArray, query_keys_array: &dyn Array, ) -> Result { - let keys = map_array.keys(); + map_extract_with_match(map_array, |row_index, key_index| { + let query_key = query_keys_array.slice(row_index, 1); + query_keys_array.is_valid(row_index) + && map_array.keys().slice(key_index, 1).as_ref() == query_key.as_ref() + }) +} + +fn map_extract_with_match( + map_array: &MapArray, + mut key_matches: impl FnMut(usize, usize) -> bool, +) -> Result { let mut offsets = vec![0_i32]; let values = map_array.values(); @@ -147,16 +181,12 @@ fn general_map_extract_inner( for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() { let start = offset_window[0] as usize; let end = offset_window[1] as usize; - let len = end - start; - - let query_key = query_keys_array.slice(row_index, 1); - let value_index = - (0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref()); + (start..end).find(|&key_index| key_matches(row_index, key_index)); match value_index { Some(index) => { - mutable.extend(0, start + index, start + index + 1); + mutable.extend(0, index, index + 1); } None => { mutable.extend_nulls(1); @@ -193,5 +223,87 @@ fn map_extract_inner(args: &[ArrayRef]) -> Result { ); } - general_map_extract_inner(map_array, key_arg) + match key_type { + DataType::Int8 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Int16 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Int32 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Int64 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::UInt8 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::UInt16 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::UInt32 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::UInt64 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Date32 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Date64 => specialized_map_extract_inner( + map_array, + map_array.keys().as_primitive::(), + key_arg.as_primitive::(), + ), + DataType::Utf8 => specialized_map_extract_inner( + map_array, + map_array.keys().as_string::(), + key_arg.as_string::(), + ), + DataType::LargeUtf8 => specialized_map_extract_inner( + map_array, + map_array.keys().as_string::(), + key_arg.as_string::(), + ), + DataType::Utf8View => specialized_map_extract_inner( + map_array, + map_array.keys().as_string_view(), + key_arg.as_string_view(), + ), + DataType::Binary => specialized_map_extract_inner( + map_array, + map_array.keys().as_binary::(), + key_arg.as_binary::(), + ), + DataType::LargeBinary => specialized_map_extract_inner( + map_array, + map_array.keys().as_binary::(), + key_arg.as_binary::(), + ), + DataType::BinaryView => specialized_map_extract_inner( + map_array, + map_array.keys().as_binary_view(), + key_arg.as_binary_view(), + ), + _ => general_map_extract_inner(map_array, key_arg), + } } diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 62e70e6080bab..34cd80c86ecfa 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -642,6 +642,22 @@ select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3 ---- [1] [1] [1] [NULL] [1] +# binary and binary view keys +query ???? +select map_extract(MAP {arrow_cast('a', 'Binary'): 1, arrow_cast('b', 'Binary'): 2}, arrow_cast('b', 'Binary')), + map_extract(MAP {arrow_cast('a', 'LargeBinary'): 1, arrow_cast('b', 'LargeBinary'): 2}, arrow_cast('c', 'LargeBinary')), + map_extract(MAP {arrow_cast('a', 'BinaryView'): 1, arrow_cast('b', 'BinaryView'): 2}, arrow_cast('b', 'BinaryView')), + map_extract(MAP {arrow_cast('a', 'BinaryView'): 1, arrow_cast('b', 'BinaryView'): 2}, arrow_cast(NULL, 'BinaryView')); +---- +[2] [NULL] [2] [NULL] + +# nested datatype keys +query ?? +select map_extract(MAP {[1, 2]: 10, [3, 4]: 20}, [3, 4]), + map_extract(MAP {[1, 2]: 10, [3, 4]: 20}, [5, 6]); +---- +[20] [NULL] + # map_extract with columns query ??? select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1;