Skip to content

Commit 1620e33

Browse files
committed
refactor: address review feedback on MapSort
- Build one global index permutation from per-map sort_to_indices results and issue a single take, instead of per-map take+concat - Add early-out fast paths (empty array, all-null, is_sorted=true) - Fall back to Spark for floating-point map keys when strictFloatingPoint=true - Clean up Arc::clone calls and replace .unwrap() on downcasts with .expect - Document MapSort behavior in map expressions compatibility guide
1 parent bb0678f commit 1620e33

3 files changed

Lines changed: 52 additions & 29 deletions

File tree

docs/source/user-guide/latest/compatibility/expressions/map.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,16 @@ under the License.
1919

2020
# Map Expressions
2121

22+
## MapSort (Spark 4.0+)
23+
24+
Spark 4.0 inserts `MapSort` to sort map entries by key when maps appear in shuffle hash partitioning
25+
keys, in `try_element_at`, and in other contexts where map ordering must be deterministic. Comet
26+
runs `MapSort` natively, so map shuffle and group-by-on-map stay on Comet under Spark 4.0.
27+
28+
When `spark.comet.exec.strictFloatingPoint=true`, `MapSort` falls back to Spark for maps whose
29+
keys contain `Float` or `Double` (consistent with `SortOrder` and `SortArray`). Arrow's sort uses
30+
IEEE total ordering for floating-point, which differs from Spark's `Double.compare` semantics for
31+
`NaN` and `-0.0`.
32+
2233
<!--BEGIN:EXPR_COMPAT[map]-->
2334
<!--END:EXPR_COMPAT-->

native/spark-expr/src/map_funcs/map_sort.rs

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayRef, MapArray, StructArray};
19-
use arrow::compute::{concat, sort_to_indices, take, SortOptions};
18+
use arrow::array::{Array, ArrayRef, MapArray, StructArray, UInt32Array};
19+
use arrow::compute::{sort_to_indices, take, SortOptions};
2020
use arrow::datatypes::DataType;
2121
use datafusion::common::{exec_err, DataFusionError};
2222
use datafusion::physical_plan::ColumnarValue;
@@ -31,59 +31,62 @@ pub fn spark_map_sort(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusio
3131
}
3232

3333
let arr_arg: ArrayRef = match &args[0] {
34-
ColumnarValue::Array(array) => Arc::<dyn Array>::clone(array),
34+
ColumnarValue::Array(array) => Arc::clone(array),
3535
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?,
3636
};
3737

3838
let (maps_arg, map_field, is_sorted) = match arr_arg.data_type() {
3939
DataType::Map(map_field, is_sorted) => {
40-
let maps_arg = arr_arg.as_any().downcast_ref::<MapArray>().unwrap();
40+
let maps_arg = arr_arg
41+
.as_any()
42+
.downcast_ref::<MapArray>()
43+
.expect("invariant: array data type is Map but downcast to MapArray failed");
4144
(maps_arg, map_field, is_sorted)
4245
}
4346
_ => return exec_err!("spark_map_sort expects Map type as argument"),
4447
};
4548

49+
// Fast paths: nothing to sort, all maps null, or input already declared sorted.
50+
if maps_arg.is_empty() || maps_arg.null_count() == maps_arg.len() || *is_sorted {
51+
return Ok(ColumnarValue::Array(arr_arg));
52+
}
53+
4654
let maps_arg_entries = maps_arg.entries();
4755
let maps_arg_offsets = maps_arg.offsets();
4856

49-
let mut sorted_map_entries_vec: Vec<ArrayRef> = Vec::with_capacity(maps_arg.len());
57+
let sort_options = SortOptions {
58+
descending: false,
59+
nulls_first: true,
60+
};
61+
62+
// Build one global permutation over the full entries struct, respecting per-map boundaries,
63+
// then issue a single `take`. This avoids per-map struct copies and a final `concat`.
64+
let mut global_indices: Vec<u32> = Vec::with_capacity(maps_arg_entries.len());
5065

5166
for idx in 0..maps_arg.len() {
5267
let map_start = maps_arg_offsets[idx] as usize;
5368
let map_end = maps_arg_offsets[idx + 1] as usize;
54-
let map_len = map_end - map_start;
55-
56-
let map_entries = maps_arg_entries.slice(map_start, map_len);
57-
58-
if map_len == 0 {
59-
sorted_map_entries_vec.push(Arc::new(map_entries));
69+
if map_end == map_start {
6070
continue;
6171
}
6272

63-
let map_keys = map_entries.column(0);
64-
let sort_options = SortOptions {
65-
descending: false,
66-
nulls_first: true,
67-
};
68-
let sorted_indices = sort_to_indices(&map_keys, Some(sort_options), None)?;
69-
70-
let sorted_map_entries = take(&map_entries, &sorted_indices, None)?;
71-
sorted_map_entries_vec.push(sorted_map_entries);
73+
let map_keys = maps_arg_entries
74+
.column(0)
75+
.slice(map_start, map_end - map_start);
76+
let local_indices = sort_to_indices(&map_keys, Some(sort_options), None)?;
77+
global_indices.extend(local_indices.values().iter().map(|i| map_start as u32 + *i));
7278
}
7379

74-
let sorted_map_entries_arr: Vec<&dyn Array> = sorted_map_entries_vec
75-
.iter()
76-
.map(|arr| arr.as_ref())
77-
.collect();
78-
let combined_sorted_map_entries = concat(&sorted_map_entries_arr)?;
79-
let sorted_map_struct = combined_sorted_map_entries
80+
let indices = UInt32Array::from(global_indices);
81+
let sorted_entries = take(maps_arg_entries, &indices, None)?;
82+
let sorted_map_struct = sorted_entries
8083
.as_any()
8184
.downcast_ref::<StructArray>()
82-
.unwrap();
85+
.expect("invariant: take on StructArray must return StructArray");
8386

8487
// Preserve the original is_sorted flag to keep schema consistent
8588
let sorted_map_arr = Arc::new(MapArray::try_new(
86-
Arc::<arrow::datatypes::Field>::clone(map_field),
89+
Arc::clone(map_field),
8790
maps_arg.offsets().clone(),
8891
sorted_map_struct.clone(),
8992
maps_arg.nulls().cloned(),

spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.internal.types.StringTypeWithCollation
2727
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, MapType, StringType}
2828

29+
import org.apache.comet.CometConf
2930
import org.apache.comet.CometSparkSessionExtensions.withInfo
3031
import org.apache.comet.expressions.{CometCast, CometEvalMode}
31-
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
32+
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible, SupportLevel}
3233
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
3334
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType}
3435

@@ -149,6 +150,14 @@ trait CometExprShim extends CommonStringExprs {
149150
if (!supportedScalarSortElementType(keyType)) {
150151
withInfo(ms, s"MapSort on map with key type $keyType is not supported")
151152
None
153+
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
154+
SupportLevel.containsFloatingPoint(keyType)) {
155+
withInfo(
156+
ms,
157+
"MapSort on floating-point key is not 100% compatible with Spark, and Comet is " +
158+
s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
159+
s"${CometConf.COMPAT_GUIDE}")
160+
None
152161
} else {
153162
val childExpr = exprToProtoInternal(ms.child, inputs, binding)
154163
val mapSortExpr = scalarFunctionExprToProtoWithReturnType(

0 commit comments

Comments
 (0)