Skip to content

Commit 9786e43

Browse files
committed
feat: implement map_sort native function
1 parent cb49ff6 commit 9786e43

4 files changed

Lines changed: 214 additions & 3 deletions

File tree

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mu
2121
use crate::math_funcs::modulo_expr::spark_modulo;
2222
use crate::{
2323
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
24-
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
25-
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
26-
SparkStringSpace,
24+
spark_isnan, spark_lpad, spark_make_decimal, spark_map_sort, spark_read_side_padding,
25+
spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount,
26+
SparkDateTrunc, SparkSizeFunc, SparkStringSpace,
2727
};
2828
use arrow::datatypes::DataType;
2929
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -181,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
181181
let func = Arc::new(abs);
182182
make_comet_scalar_udf!("abs", func, without data_type)
183183
}
184+
"map_sort" => {
185+
let func = Arc::new(spark_map_sort);
186+
make_comet_scalar_udf!("map_sort", func, without data_type)
187+
}
184188
_ => registry.udf(fun_name).map_err(|e| {
185189
DataFusionError::Execution(format!(
186190
"Function {fun_name} not found in the registry: {e}",

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ mod bitwise_funcs;
4444
mod comet_scalar_funcs;
4545
pub mod hash_funcs;
4646

47+
mod map_funcs;
4748
mod string_funcs;
4849

4950
mod datetime_funcs;
@@ -63,6 +64,7 @@ pub use array_funcs::*;
6364
pub use bitwise_funcs::*;
6465
pub use conditional_funcs::*;
6566
pub use conversion_funcs::*;
67+
pub use map_funcs::*;
6668
pub use nondetermenistic_funcs::*;
6769

6870
pub use comet_scalar_funcs::{
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef, MapArray};
19+
use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
20+
use arrow::datatypes::{DataType, Field};
21+
use datafusion::common::{exec_err, DataFusionError, ScalarValue};
22+
use datafusion::logical_expr::ColumnarValue;
23+
use std::sync::Arc;
24+
25+
pub fn spark_map_sort(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
26+
if args.len() != 1 {
27+
return exec_err!("map_sort function takes exactly one argument");
28+
}
29+
30+
match &args[0] {
31+
ColumnarValue::Array(array) => {
32+
let result = spark_map_sort_array(array)?;
33+
Ok(ColumnarValue::Array(result))
34+
}
35+
ColumnarValue::Scalar(scalar) => {
36+
let result = spark_map_sort_scalar(scalar)?;
37+
Ok(ColumnarValue::Scalar(result))
38+
}
39+
}
40+
}
41+
42+
fn spark_map_sort_array(array: &ArrayRef) -> Result<ArrayRef, DataFusionError> {
43+
let map_array = array
44+
.as_any()
45+
.downcast_ref::<MapArray>()
46+
.ok_or_else(|| DataFusionError::Internal("Expected MapArray".to_string()))?;
47+
48+
let entries = map_array.entries();
49+
let struct_array = entries
50+
.as_any()
51+
.downcast_ref::<arrow::array::StructArray>()
52+
.ok_or_else(|| DataFusionError::Internal("Expected StructArray for entries".to_string()))?;
53+
54+
if struct_array.num_columns() != 2 {
55+
return exec_err!("Map entries must have exactly 2 columns (keys and values)");
56+
}
57+
58+
let keys = struct_array.column(0);
59+
let values = struct_array.column(1);
60+
let offsets = map_array.offsets();
61+
62+
let mut sorted_keys_arrays = Vec::new();
63+
let mut sorted_values_arrays = Vec::new();
64+
let mut new_offsets = Vec::with_capacity(map_array.len() + 1);
65+
new_offsets.push(0i32);
66+
67+
for row_idx in 0..map_array.len() {
68+
let start = offsets[row_idx] as usize;
69+
let end = offsets[row_idx + 1] as usize;
70+
let len = end - start;
71+
72+
if len == 0 {
73+
new_offsets.push(new_offsets[row_idx]);
74+
continue;
75+
}
76+
77+
let row_keys = keys.slice(start, len);
78+
let row_values = values.slice(start, len);
79+
80+
if len == 1 {
81+
sorted_keys_arrays.push(row_keys);
82+
sorted_values_arrays.push(row_values);
83+
new_offsets.push(new_offsets[row_idx] + len as i32);
84+
continue;
85+
}
86+
87+
let sort_columns = vec![SortColumn {
88+
values: Arc::clone(&row_keys),
89+
options: Some(SortOptions {
90+
descending: false,
91+
nulls_first: false,
92+
}),
93+
}];
94+
95+
let indices = lexsort_to_indices(&sort_columns, None)?;
96+
let sorted_keys = take(&row_keys, &indices, None)?;
97+
let sorted_values = take(&row_values, &indices, None)?;
98+
99+
sorted_keys_arrays.push(sorted_keys);
100+
sorted_values_arrays.push(sorted_values);
101+
new_offsets.push(new_offsets[row_idx] + len as i32);
102+
}
103+
104+
if sorted_keys_arrays.is_empty() {
105+
let key_field = Arc::new(Field::new(
106+
"key",
107+
keys.data_type().clone(),
108+
keys.is_nullable(),
109+
));
110+
let value_field = Arc::new(Field::new(
111+
"value",
112+
values.data_type().clone(),
113+
values.is_nullable(),
114+
));
115+
let entries_field = Arc::new(Field::new(
116+
"entries",
117+
DataType::Struct(vec![Arc::clone(&key_field), Arc::clone(&value_field)].into()),
118+
false,
119+
));
120+
121+
let empty_keys = arrow::array::new_empty_array(keys.data_type());
122+
let empty_values = arrow::array::new_empty_array(values.data_type());
123+
let empty_entries = arrow::array::StructArray::new(
124+
vec![key_field, value_field].into(),
125+
vec![empty_keys, empty_values],
126+
None,
127+
);
128+
129+
return Ok(Arc::new(MapArray::new(
130+
entries_field,
131+
arrow::buffer::OffsetBuffer::new(vec![0i32; map_array.len() + 1].into()),
132+
empty_entries,
133+
map_array.nulls().cloned(),
134+
false,
135+
)));
136+
}
137+
138+
let sorted_keys_refs: Vec<&dyn Array> = sorted_keys_arrays.iter().map(|a| a.as_ref()).collect();
139+
let sorted_values_refs: Vec<&dyn Array> =
140+
sorted_values_arrays.iter().map(|a| a.as_ref()).collect();
141+
142+
let concatenated_keys = arrow::compute::concat(&sorted_keys_refs)?;
143+
let concatenated_values = arrow::compute::concat(&sorted_values_refs)?;
144+
145+
let key_field = Arc::new(Field::new(
146+
"key",
147+
keys.data_type().clone(),
148+
keys.is_nullable(),
149+
));
150+
let value_field = Arc::new(Field::new(
151+
"value",
152+
values.data_type().clone(),
153+
values.is_nullable(),
154+
));
155+
156+
let sorted_entries = arrow::array::StructArray::new(
157+
vec![Arc::clone(&key_field), Arc::clone(&value_field)].into(),
158+
vec![concatenated_keys, concatenated_values],
159+
None,
160+
);
161+
162+
let entries_field = Arc::new(Field::new(
163+
"entries",
164+
DataType::Struct(vec![key_field, value_field].into()),
165+
false,
166+
));
167+
168+
Ok(Arc::new(MapArray::new(
169+
entries_field,
170+
arrow::buffer::OffsetBuffer::new(new_offsets.into()),
171+
sorted_entries,
172+
map_array.nulls().cloned(),
173+
false,
174+
)))
175+
}
176+
177+
fn spark_map_sort_scalar(scalar: &ScalarValue) -> Result<ScalarValue, DataFusionError> {
178+
match scalar {
179+
ScalarValue::Null => Ok(ScalarValue::Null),
180+
_ => exec_err!(
181+
"map_sort scalar function only supports map types, got: {:?}",
182+
scalar
183+
),
184+
}
185+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod map_sort;
19+
20+
pub use map_sort::spark_map_sort;

0 commit comments

Comments
 (0)