Skip to content

Commit 33f514a

Browse files
authored
feat: Expand murmur3 hash support to complex types (#3077)
1 parent c4da3c9 commit 33f514a

8 files changed

Lines changed: 817 additions & 14 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,7 @@ jobs:
160160
value: |
161161
org.apache.comet.CometExpressionSuite
162162
org.apache.comet.CometExpressionCoverageSuite
163+
org.apache.comet.CometHashExpressionSuite
163164
org.apache.comet.CometTemporalExpressionSuite
164165
org.apache.comet.CometArrayExpressionSuite
165166
org.apache.comet.CometCastSuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,7 @@ jobs:
123123
value: |
124124
org.apache.comet.CometExpressionSuite
125125
org.apache.comet.CometExpressionCoverageSuite
126+
org.apache.comet.CometHashExpressionSuite
126127
org.apache.comet.CometTemporalExpressionSuite
127128
org.apache.comet.CometArrayExpressionSuite
128129
org.apache.comet.CometCastSuite

native/spark-expr/src/hash_funcs/murmur3.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,8 @@ pub fn create_murmur3_hashes<'a>(
183183
arrays,
184184
hashes_buffer,
185185
spark_compatible_murmur3_hash,
186-
create_hashes_dictionary
186+
create_hashes_dictionary,
187+
create_murmur3_hashes
187188
);
188189
Ok(hashes_buffer)
189190
}

native/spark-expr/src/hash_funcs/utils.rs

Lines changed: 157 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,62 @@ macro_rules! hash_array_decimal {
206206
};
207207
}
208208

209+
/// Hash a list array by recursively hashing each element.
210+
/// For each row, we hash all elements in the list.
211+
/// Spark hashes arrays by recursively hashing each element, where each
212+
/// element's hash is computed using the previous element's hash as the seed.
213+
/// This creates a chain: hash(elem_n, hash(elem_n-1, ... hash(elem_0, seed)...))
214+
#[macro_export]
215+
macro_rules! hash_list_array {
216+
($array_type:ident, $offset_type:ty, $column: ident, $hashes: ident, $recursive_hash_method: ident) => {
217+
let list_array = $column
218+
.as_any()
219+
.downcast_ref::<$array_type>()
220+
.unwrap_or_else(|| {
221+
panic!(
222+
"Failed to downcast column to {}. Actual data type: {:?}.",
223+
stringify!($array_type),
224+
$column.data_type()
225+
)
226+
});
227+
228+
let values = list_array.values();
229+
let offsets = list_array.offsets();
230+
231+
if list_array.null_count() == 0 {
232+
// Fast path: no nulls, skip null checks
233+
for (row_idx, hash) in $hashes.iter_mut().enumerate() {
234+
let start = offsets[row_idx] as usize;
235+
let end = offsets[row_idx + 1] as usize;
236+
let len = end - start;
237+
// Hash each element in sequence, chaining the hash values
238+
for elem_idx in 0..len {
239+
let elem_array = values.slice(start + elem_idx, 1);
240+
let mut single_hash = [*hash];
241+
$recursive_hash_method(&[elem_array], &mut single_hash)?;
242+
*hash = single_hash[0];
243+
}
244+
}
245+
} else {
246+
// Slow path: array has nulls, check each row
247+
for (row_idx, hash) in $hashes.iter_mut().enumerate() {
248+
if !list_array.is_null(row_idx) {
249+
let start = offsets[row_idx] as usize;
250+
let end = offsets[row_idx + 1] as usize;
251+
let len = end - start;
252+
// Hash each element in sequence, chaining the hash values
253+
for elem_idx in 0..len {
254+
let elem_array = values.slice(start + elem_idx, 1);
255+
let mut single_hash = [*hash];
256+
$recursive_hash_method(&[elem_array], &mut single_hash)?;
257+
*hash = single_hash[0];
258+
}
259+
}
260+
}
261+
}
262+
};
263+
}
264+
209265
/// Creates hash values for every row, based on the values in the
210266
/// columns.
211267
///
@@ -214,9 +270,10 @@ macro_rules! hash_array_decimal {
214270
///
215271
/// `hash_method` is the hash function to use.
216272
/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input.
273+
/// `recursive_hash_method` is the function to call for recursive hashing of complex types.
217274
#[macro_export]
218275
macro_rules! create_hashes_internal {
219-
($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => {
276+
($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident, $recursive_hash_method: ident) => {
220277
use arrow::datatypes::{DataType, TimeUnit};
221278
use arrow::array::{types::*, *};
222279

@@ -425,6 +482,105 @@ macro_rules! create_hashes_internal {
425482
)))
426483
}
427484
},
485+
DataType::List(_) => {
486+
$crate::hash_list_array!(ListArray, i32, col, $hashes_buffer, $recursive_hash_method);
487+
}
488+
DataType::LargeList(_) => {
489+
$crate::hash_list_array!(LargeListArray, i64, col, $hashes_buffer, $recursive_hash_method);
490+
}
491+
DataType::FixedSizeList(_, size) => {
492+
let list_array = col.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
493+
let values = list_array.values();
494+
let list_size = *size as usize;
495+
496+
if list_array.null_count() == 0 {
497+
// Fast path: no nulls, skip null checks
498+
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
499+
let start = row_idx * list_size;
500+
// Hash each element in sequence, chaining the hash values
501+
for elem_idx in 0..list_size {
502+
let elem_array = values.slice(start + elem_idx, 1);
503+
let mut single_hash = [*hash];
504+
$recursive_hash_method(&[elem_array], &mut single_hash)?;
505+
*hash = single_hash[0];
506+
}
507+
}
508+
} else {
509+
// Slow path: array has nulls, check each row
510+
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
511+
if !list_array.is_null(row_idx) {
512+
let start = row_idx * list_size;
513+
// Hash each element in sequence, chaining the hash values
514+
for elem_idx in 0..list_size {
515+
let elem_array = values.slice(start + elem_idx, 1);
516+
let mut single_hash = [*hash];
517+
$recursive_hash_method(&[elem_array], &mut single_hash)?;
518+
*hash = single_hash[0];
519+
}
520+
}
521+
}
522+
}
523+
}
524+
DataType::Struct(_) => {
525+
let struct_array = col.as_any().downcast_ref::<StructArray>().unwrap();
526+
// Hash each field of the struct - Spark hashes all fields recursively
527+
let columns: Vec<ArrayRef> = struct_array.columns().to_vec();
528+
if !columns.is_empty() {
529+
$recursive_hash_method(&columns, $hashes_buffer)?;
530+
}
531+
}
532+
DataType::Map(_, _) => {
533+
let map_array = col.as_any().downcast_ref::<MapArray>().unwrap();
534+
// For maps, Spark hashes by iterating through (key, value) pairs
535+
// For each entry, hash the key then the value
536+
let keys = map_array.keys();
537+
let values = map_array.values();
538+
let offsets = map_array.offsets();
539+
540+
if map_array.null_count() == 0 {
541+
// Fast path: no nulls, skip null checks
542+
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
543+
let start = offsets[row_idx] as usize;
544+
let end = offsets[row_idx + 1] as usize;
545+
// Hash each key-value pair in sequence
546+
for entry_idx in start..end {
547+
// Hash the key
548+
let key_array = keys.slice(entry_idx, 1);
549+
let mut single_hash = [*hash];
550+
$recursive_hash_method(&[key_array], &mut single_hash)?;
551+
*hash = single_hash[0];
552+
553+
// Hash the value
554+
let value_array = values.slice(entry_idx, 1);
555+
single_hash = [*hash];
556+
$recursive_hash_method(&[value_array], &mut single_hash)?;
557+
*hash = single_hash[0];
558+
}
559+
}
560+
} else {
561+
// Slow path: array has nulls, check each row
562+
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
563+
if !map_array.is_null(row_idx) {
564+
let start = offsets[row_idx] as usize;
565+
let end = offsets[row_idx + 1] as usize;
566+
// Hash each key-value pair in sequence
567+
for entry_idx in start..end {
568+
// Hash the key
569+
let key_array = keys.slice(entry_idx, 1);
570+
let mut single_hash = [*hash];
571+
$recursive_hash_method(&[key_array], &mut single_hash)?;
572+
*hash = single_hash[0];
573+
574+
// Hash the value
575+
let value_array = values.slice(entry_idx, 1);
576+
single_hash = [*hash];
577+
$recursive_hash_method(&[value_array], &mut single_hash)?;
578+
*hash = single_hash[0];
579+
}
580+
}
581+
}
582+
}
583+
}
428584
_ => {
429585
// This is internal because we should have caught this before.
430586
return Err(DataFusionError::Internal(format!(

native/spark-expr/src/hash_funcs/xxhash64.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,8 @@ fn create_xxhash64_hashes<'a>(
129129
arrays,
130130
hashes_buffer,
131131
spark_compatible_xxhash64,
132-
create_xxhash64_hashes_dictionary
132+
create_xxhash64_hashes_dictionary,
133+
create_xxhash64_hashes
133134
);
134135
Ok(hashes_buffer)
135136
}

spark/src/main/scala/org/apache/comet/serde/hash.scala

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
package org.apache.comet.serde
2121

2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64}
23-
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StringType}
23+
import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegerType, LongType, MapType, StringType, StructType}
2424

2525
import org.apache.comet.CometSparkSessionExtensions.withInfo
2626
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarFunctionExprToProtoWithReturnType, serializeDataType, supportedDataType}
@@ -106,18 +106,31 @@ object CometSha1 extends CometExpressionSerde[Sha1] {
106106
private object HashUtils {
107107
def isSupportedType(expr: Expression): Boolean = {
108108
for (child <- expr.children) {
109-
child.dataType match {
110-
case dt: DecimalType if dt.precision > 18 =>
111-
// Spark converts decimals with precision > 18 into
112-
// Java BigDecimal before hashing
113-
withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
114-
return false
115-
case dt if !supportedDataType(dt) =>
116-
withInfo(expr, s"Unsupported datatype $dt")
117-
return false
118-
case _ =>
109+
if (!isSupportedDataType(expr, child.dataType)) {
110+
return false
119111
}
120112
}
121113
true
122114
}
115+
116+
private def isSupportedDataType(expr: Expression, dt: DataType): Boolean = {
117+
dt match {
118+
case d: DecimalType if d.precision > 18 =>
119+
// Spark converts decimals with precision > 18 into
120+
// Java BigDecimal before hashing
121+
withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
122+
false
123+
case s: StructType =>
124+
s.fields.forall(f => isSupportedDataType(expr, f.dataType))
125+
case a: ArrayType =>
126+
isSupportedDataType(expr, a.elementType)
127+
case m: MapType =>
128+
isSupportedDataType(expr, m.keyType) && isSupportedDataType(expr, m.valueType)
129+
case _ if !supportedDataType(dt, allowComplex = true) =>
130+
withInfo(expr, s"Unsupported datatype $dt")
131+
false
132+
case _ =>
133+
true
134+
}
135+
}
123136
}

0 commit comments

Comments
 (0)