Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr_build_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ jobs:
value: |
org.apache.comet.CometExpressionSuite
org.apache.comet.CometExpressionCoverageSuite
org.apache.comet.CometHashExpressionSuite
org.apache.comet.CometTemporalExpressionSuite
org.apache.comet.CometArrayExpressionSuite
org.apache.comet.CometCastSuite
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr_build_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ jobs:
value: |
org.apache.comet.CometExpressionSuite
org.apache.comet.CometExpressionCoverageSuite
org.apache.comet.CometHashExpressionSuite
org.apache.comet.CometTemporalExpressionSuite
org.apache.comet.CometArrayExpressionSuite
org.apache.comet.CometCastSuite
Expand Down
3 changes: 2 additions & 1 deletion native/spark-expr/src/hash_funcs/murmur3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ pub fn create_murmur3_hashes<'a>(
arrays,
hashes_buffer,
spark_compatible_murmur3_hash,
create_hashes_dictionary
create_hashes_dictionary,
create_murmur3_hashes
);
Ok(hashes_buffer)
}
Expand Down
158 changes: 157 additions & 1 deletion native/spark-expr/src/hash_funcs/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,62 @@ macro_rules! hash_array_decimal {
};
}

/// Hashes a list column by folding every element of each row into that row's hash.
///
/// Spark hashes arrays by recursively hashing each element, where each
/// element's hash is computed using the previous element's hash as the seed.
/// This creates a chain: hash(elem_n, hash(elem_n-1, ... hash(elem_0, seed)...))
///
/// Arguments:
/// - `$array_type`: concrete list array type that `$column` is downcast to.
/// - `$offset_type`: offset integer type of the list array.
///   NOTE(review): accepted but never referenced in the body; offsets are
///   converted with `as usize` regardless of this argument.
/// - `$column`: the column to hash; the body calls `as_any()` and `data_type()` on it.
/// - `$hashes`: per-row hash buffer, read as the seed and updated in place.
/// - `$recursive_hash_method`: invoked as `f(&[array], &mut [hash])?` once per element.
///
/// Panics if `$column` cannot be downcast to `$array_type`. The expansion uses `?`,
/// so it has to be placed in a function whose return type can carry the error
/// produced by `$recursive_hash_method`. Null rows leave their entry in `$hashes`
/// unchanged.
#[macro_export]
macro_rules! hash_list_array {
($array_type:ident, $offset_type:ty, $column: ident, $hashes: ident, $recursive_hash_method: ident) => {
// A failed downcast is reported by panicking with the expected and actual types.
let list_array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap_or_else(|| {
panic!(
"Failed to downcast column to {}. Actual data type: {:?}.",
stringify!($array_type),
$column.data_type()
)
});

// `values` holds the child elements of all rows; `offsets[i]..offsets[i + 1]`
// is the range of child elements that belongs to row `i`.
let values = list_array.values();
let offsets = list_array.offsets();

if list_array.null_count() == 0 {
// Fast path: the list column has no null rows, so the per-row null check is skipped.
for (row_idx, hash) in $hashes.iter_mut().enumerate() {
let start = offsets[row_idx] as usize;
let end = offsets[row_idx + 1] as usize;
let len = end - start;
// Feed the elements through the hash in order; the running hash of the row
// is the seed for the next element.
for elem_idx in 0..len {
// One-element view of the child array, paired with a one-element hash
// buffer seeded with the row's current hash.
let elem_array = values.slice(start + elem_idx, 1);
let mut single_hash = [*hash];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is single_hash an array?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

single_hash is an array because the recursive hash method interface expects a slice of hashes and this allows us to reuse that rather than add another version of the code

$recursive_hash_method(&[elem_array], &mut single_hash)?;
// Write the updated value back so the next element continues the chain.
*hash = single_hash[0];
}
}
} else {
// Slow path: at least one row is null, so every row is checked and null rows
// keep their incoming hash value.
for (row_idx, hash) in $hashes.iter_mut().enumerate() {
if !list_array.is_null(row_idx) {
let start = offsets[row_idx] as usize;
let end = offsets[row_idx + 1] as usize;
let len = end - start;
// Same element-by-element chaining as in the fast path.
for elem_idx in 0..len {
let elem_array = values.slice(start + elem_idx, 1);
let mut single_hash = [*hash];
$recursive_hash_method(&[elem_array], &mut single_hash)?;
*hash = single_hash[0];
}
}
}
}
};
}

/// Creates hash values for every row, based on the values in the
/// columns.
///
Expand All @@ -214,9 +270,10 @@ macro_rules! hash_array_decimal {
///
/// `hash_method` is the hash function to use.
/// `create_dictionary_hash_method` is the function used to create hashes for dictionary array inputs.
/// `recursive_hash_method` is the function to call for recursive hashing of complex types.
#[macro_export]
macro_rules! create_hashes_internal {
($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => {
($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident, $recursive_hash_method: ident) => {
use arrow::datatypes::{DataType, TimeUnit};
use arrow::array::{types::*, *};

Expand Down Expand Up @@ -425,6 +482,105 @@ macro_rules! create_hashes_internal {
)))
}
},
DataType::List(_) => {
$crate::hash_list_array!(ListArray, i32, col, $hashes_buffer, $recursive_hash_method);
}
DataType::LargeList(_) => {
$crate::hash_list_array!(LargeListArray, i64, col, $hashes_buffer, $recursive_hash_method);
}
DataType::FixedSizeList(_, size) => {
let list_array = col.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
let values = list_array.values();
let list_size = *size as usize;

if list_array.null_count() == 0 {
// Fast path: no nulls, skip null checks
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
let start = row_idx * list_size;
// Hash each element in sequence, chaining the hash values
for elem_idx in 0..list_size {
let elem_array = values.slice(start + elem_idx, 1);
let mut single_hash = [*hash];
$recursive_hash_method(&[elem_array], &mut single_hash)?;
*hash = single_hash[0];
}
}
} else {
// Slow path: array has nulls, check each row
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
if !list_array.is_null(row_idx) {
let start = row_idx * list_size;
// Hash each element in sequence, chaining the hash values
for elem_idx in 0..list_size {
let elem_array = values.slice(start + elem_idx, 1);
let mut single_hash = [*hash];
$recursive_hash_method(&[elem_array], &mut single_hash)?;
*hash = single_hash[0];
}
}
}
}
}
DataType::Struct(_) => {
let struct_array = col.as_any().downcast_ref::<StructArray>().unwrap();
// Hash each field of the struct - Spark hashes all fields recursively
let columns: Vec<ArrayRef> = struct_array.columns().to_vec();
if !columns.is_empty() {
$recursive_hash_method(&columns, $hashes_buffer)?;
}
}
DataType::Map(_, _) => {
let map_array = col.as_any().downcast_ref::<MapArray>().unwrap();
// For maps, Spark hashes by iterating through (key, value) pairs
// For each entry, hash the key then the value
let keys = map_array.keys();
let values = map_array.values();
let offsets = map_array.offsets();

if map_array.null_count() == 0 {
// Fast path: no nulls, skip null checks
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
let start = offsets[row_idx] as usize;
let end = offsets[row_idx + 1] as usize;
// Hash each key-value pair in sequence
for entry_idx in start..end {
// Hash the key
let key_array = keys.slice(entry_idx, 1);
let mut single_hash = [*hash];
$recursive_hash_method(&[key_array], &mut single_hash)?;
*hash = single_hash[0];

// Hash the value
let value_array = values.slice(entry_idx, 1);
single_hash = [*hash];
$recursive_hash_method(&[value_array], &mut single_hash)?;
*hash = single_hash[0];
}
}
} else {
// Slow path: array has nulls, check each row
for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() {
if !map_array.is_null(row_idx) {
let start = offsets[row_idx] as usize;
let end = offsets[row_idx + 1] as usize;
// Hash each key-value pair in sequence
for entry_idx in start..end {
// Hash the key
let key_array = keys.slice(entry_idx, 1);
let mut single_hash = [*hash];
$recursive_hash_method(&[key_array], &mut single_hash)?;
*hash = single_hash[0];

// Hash the value
let value_array = values.slice(entry_idx, 1);
single_hash = [*hash];
$recursive_hash_method(&[value_array], &mut single_hash)?;
*hash = single_hash[0];
}
}
}
}
}
_ => {
// This is internal because we should have caught this before.
return Err(DataFusionError::Internal(format!(
Expand Down
3 changes: 2 additions & 1 deletion native/spark-expr/src/hash_funcs/xxhash64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ fn create_xxhash64_hashes<'a>(
arrays,
hashes_buffer,
spark_compatible_xxhash64,
create_xxhash64_hashes_dictionary
create_xxhash64_hashes_dictionary,
create_xxhash64_hashes
);
Ok(hashes_buffer)
}
Expand Down
35 changes: 24 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/hash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64}
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StringType}
import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegerType, LongType, MapType, StringType, StructType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarFunctionExprToProtoWithReturnType, serializeDataType, supportedDataType}
Expand Down Expand Up @@ -106,18 +106,31 @@ object CometSha1 extends CometExpressionSerde[Sha1] {
private object HashUtils {
def isSupportedType(expr: Expression): Boolean = {
for (child <- expr.children) {
child.dataType match {
case dt: DecimalType if dt.precision > 18 =>
// Spark converts decimals with precision > 18 into
// Java BigDecimal before hashing
withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
return false
case dt if !supportedDataType(dt) =>
withInfo(expr, s"Unsupported datatype $dt")
return false
case _ =>
if (!isSupportedDataType(expr, child.dataType)) {
return false
}
}
true
}

/**
 * Decides whether values of type `dt` can be hashed, descending into the field, element, key
 * and value types of structs, arrays and maps.
 *
 * Whenever a type is rejected, `withInfo` is called on `expr` with a message naming the
 * offending type before `false` is returned.
 *
 * @param expr
 *   the hash expression that the message is attached to
 * @param dt
 *   the data type to inspect
 * @return
 *   `true` when `dt` and every type nested inside it passes the checks
 */
private def isSupportedDataType(expr: Expression, dt: DataType): Boolean = dt match {
  case decimal: DecimalType if decimal.precision > 18 =>
    // Spark converts decimals with precision > 18 into
    // Java BigDecimal before hashing
    withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
    false
  case struct: StructType =>
    // Stops at the first rejected field, exactly like a short-circuiting conjunction.
    struct.fields.forall(field => isSupportedDataType(expr, field.dataType))
  case ArrayType(elementType, _) =>
    isSupportedDataType(expr, elementType)
  case MapType(keyType, valueType, _) =>
    // The value type is only inspected when the key type was accepted.
    isSupportedDataType(expr, keyType) && isSupportedDataType(expr, valueType)
  case other =>
    val supported = supportedDataType(other, allowComplex = true)
    if (!supported) {
      withInfo(expr, s"Unsupported datatype $dt")
    }
    supported
}
}
Loading
Loading