diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index cb0a519442..7662b219c4 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -26,7 +26,7 @@ import java.nio.channels.Channels import scala.jdk.CollectionConverters._ import org.apache.arrow.c.CDataDictionaryProvider -import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.ipc.ArrowStreamWriter @@ -288,7 +288,7 @@ object Utils extends CometTypeShim { _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector | - _: MapVector) => + _: MapVector | _: NullVector) => v.asInstanceOf[FieldVector] case _ => throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") diff --git a/native/core/src/execution/columnar_to_row.rs b/native/core/src/execution/columnar_to_row.rs index 78ab7637e8..66b53af2bd 100644 --- a/native/core/src/execution/columnar_to_row.rs +++ b/native/core/src/execution/columnar_to_row.rs @@ -41,16 +41,117 @@ use arrow::array::types::{ UInt64Type, UInt8Type, }; use arrow::array::*; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{ArrowNativeType, DataType, TimeUnit}; use std::sync::Arc; /// Maximum digits for decimal that can fit in a long (8 bytes). const MAX_LONG_DIGITS: u8 = 18; +/// Helper macro for downcasting arrays with consistent error messages. +macro_rules! downcast_array { + ($array:expr, $array_type:ty) => { + $array + .as_any() + .downcast_ref::<$array_type>() + .ok_or_else(|| { + CometError::Internal(format!( + "Failed to downcast to {}, actual type: {:?}", + stringify!($array_type), + $array.data_type() + )) + }) + }; +} + +/// Macro to implement is_null for typed array enums. +/// Generates a complete match expression for all variants that have an array as first field. +macro_rules! impl_is_null { + ($self:expr, $row_idx:expr, [$($variant:ident),+ $(,)?]) => { + match $self { + $(Self::$variant(arr, ..) => arr.is_null($row_idx),)+ + } + }; + // Version with special handling for Null variant + ($self:expr, $row_idx:expr, null_always_true, [$($variant:ident),+ $(,)?]) => { + match $self { + Self::Null => true, + $(Self::$variant(arr, ..) => arr.is_null($row_idx),)+ + } + }; +} + +/// Macro to generate TypedElements::from_array match arms for primitive types. +macro_rules! typed_elements_from_primitive { + ($array:expr, $element_type:expr, $(($dt:pat, $variant:ident, $arr_type:ty)),+ $(,)?) => { + match $element_type { + $( + $dt => { + if let Some(arr) = $array.as_any().downcast_ref::<$arr_type>() { + return TypedElements::$variant(arr); + } + } + )+ + _ => {} + } + }; +} + +/// Macro for write_column_fixed_width arms - handles downcast + loop pattern. +macro_rules! write_fixed_column_primitive { + ($self:expr, $array:expr, $row_size:expr, $field_offset:expr, $num_rows:expr, + $arr_type:ty, $to_i64:expr) => {{ + let arr = downcast_array!($array, $arr_type)?; + for row_idx in 0..$num_rows { + if !arr.is_null(row_idx) { + let offset = row_idx * $row_size + $field_offset; + let value: i64 = $to_i64(arr.value(row_idx)); + $self.buffer[offset..offset + 8].copy_from_slice(&value.to_le_bytes()); + } + } + Ok(()) + }}; +} + +/// Macro for get_field_value arms - handles downcast + value extraction. +macro_rules! get_field_value_primitive { + ($array:expr, $row_idx:expr, $arr_type:ty, $to_i64:expr) => {{ + let arr = downcast_array!($array, $arr_type)?; + Ok($to_i64(arr.value($row_idx))) + }}; +} + +/// Macro for write_struct_to_buffer fixed-width field extraction. +macro_rules! extract_fixed_value { + ($column:expr, $row_idx:expr, $(($dt:pat, $arr_type:ty, $to_i64:expr)),+ $(,)?) => { + match $column.data_type() { + $( + $dt => { + let arr = downcast_array!($column, $arr_type)?; + Some($to_i64(arr.value($row_idx))) + } + )+ + _ => None, + } + }; +} + +/// Writes bytes to buffer with 8-byte alignment padding. +/// Returns the unpadded length. +#[inline] +fn write_bytes_padded(buffer: &mut Vec, bytes: &[u8]) -> usize { + let len = bytes.len(); + buffer.extend_from_slice(bytes); + let padding = round_up_to_8(len) - len; + buffer.extend(std::iter::repeat_n(0u8, padding)); + len +} + /// Pre-downcast array reference to avoid type dispatch in inner loops. /// This enum holds references to concrete array types, allowing direct access /// without repeated downcast_ref calls. enum TypedArray<'a> { + Null, Boolean(&'a BooleanArray), Int8(&'a Int8Array), Int16(&'a Int16Array), @@ -65,6 +166,7 @@ enum TypedArray<'a> { LargeString(&'a LargeStringArray), Binary(&'a BinaryArray), LargeBinary(&'a LargeBinaryArray), + FixedSizeBinary(&'a FixedSizeBinaryArray), Struct( &'a StructArray, arrow::datatypes::Fields, @@ -78,119 +180,46 @@ enum TypedArray<'a> { impl<'a> TypedArray<'a> { /// Pre-downcast an ArrayRef to a TypedArray. - fn from_array(array: &'a ArrayRef, schema_type: &DataType) -> CometResult { + fn from_array(array: &'a ArrayRef) -> CometResult { let actual_type = array.data_type(); match actual_type { - DataType::Boolean => Ok(TypedArray::Boolean( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to BooleanArray".to_string()) - })?, - )), - DataType::Int8 => Ok(TypedArray::Int8( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int8Array".to_string()) - })?, - )), - DataType::Int16 => Ok(TypedArray::Int16( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int16Array".to_string()) - })?, - )), - DataType::Int32 => Ok(TypedArray::Int32( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int32Array".to_string()) - })?, - )), - DataType::Int64 => Ok(TypedArray::Int64( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int64Array".to_string()) - })?, - )), - DataType::Float32 => Ok(TypedArray::Float32( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float32Array".to_string()) - })?, - )), - DataType::Float64 => Ok(TypedArray::Float64( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float64Array".to_string()) - })?, - )), - DataType::Date32 => Ok(TypedArray::Date32( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Date32Array".to_string()) - })?, - )), + DataType::Null => { + // Verify the array is actually a NullArray, but we don't need to store the reference + // since all values are null by definition + downcast_array!(array, NullArray)?; + Ok(TypedArray::Null) + } + DataType::Boolean => Ok(TypedArray::Boolean(downcast_array!(array, BooleanArray)?)), + DataType::Int8 => Ok(TypedArray::Int8(downcast_array!(array, Int8Array)?)), + DataType::Int16 => Ok(TypedArray::Int16(downcast_array!(array, Int16Array)?)), + DataType::Int32 => Ok(TypedArray::Int32(downcast_array!(array, Int32Array)?)), + DataType::Int64 => Ok(TypedArray::Int64(downcast_array!(array, Int64Array)?)), + DataType::Float32 => Ok(TypedArray::Float32(downcast_array!(array, Float32Array)?)), + DataType::Float64 => Ok(TypedArray::Float64(downcast_array!(array, Float64Array)?)), + DataType::Date32 => Ok(TypedArray::Date32(downcast_array!(array, Date32Array)?)), DataType::Timestamp(TimeUnit::Microsecond, _) => Ok(TypedArray::TimestampMicro( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal( - "Failed to downcast to TimestampMicrosecondArray".to_string(), - ) - })?, + downcast_array!(array, TimestampMicrosecondArray)?, )), DataType::Decimal128(p, _) => Ok(TypedArray::Decimal128( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Decimal128Array".to_string()) - })?, + downcast_array!(array, Decimal128Array)?, *p, )), - DataType::Utf8 => Ok(TypedArray::String( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to StringArray".to_string()) - })?, - )), - DataType::LargeUtf8 => Ok(TypedArray::LargeString( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeStringArray".to_string()) - })?, - )), - DataType::Binary => Ok(TypedArray::Binary( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to BinaryArray".to_string()) - })?, - )), - DataType::LargeBinary => Ok(TypedArray::LargeBinary( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeBinaryArray".to_string()) - })?, - )), + DataType::Utf8 => Ok(TypedArray::String(downcast_array!(array, StringArray)?)), + DataType::LargeUtf8 => Ok(TypedArray::LargeString(downcast_array!( + array, + LargeStringArray + )?)), + DataType::Binary => Ok(TypedArray::Binary(downcast_array!(array, BinaryArray)?)), + DataType::LargeBinary => Ok(TypedArray::LargeBinary(downcast_array!( + array, + LargeBinaryArray + )?)), + DataType::FixedSizeBinary(_) => Ok(TypedArray::FixedSizeBinary(downcast_array!( + array, + FixedSizeBinaryArray + )?)), DataType::Struct(fields) => { - let struct_arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to StructArray".to_string()) - })?; + let struct_arr = downcast_array!(array, StructArray)?; // Pre-downcast all struct fields once let typed_fields: Vec = fields .iter() @@ -202,27 +231,18 @@ impl<'a> TypedArray<'a> { Ok(TypedArray::Struct(struct_arr, fields.clone(), typed_fields)) } DataType::List(field) => Ok(TypedArray::List( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to ListArray".to_string()) - })?, + downcast_array!(array, ListArray)?, Arc::clone(field), )), DataType::LargeList(field) => Ok(TypedArray::LargeList( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeListArray".to_string()) - })?, + downcast_array!(array, LargeListArray)?, Arc::clone(field), )), DataType::Map(field, _) => Ok(TypedArray::Map( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to MapArray".to_string()) - })?, + downcast_array!(array, MapArray)?, Arc::clone(field), )), - DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, schema_type.clone())), + DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, actual_type.clone())), _ => Err(CometError::Internal(format!( "Unsupported data type for pre-downcast: {:?}", actual_type @@ -233,27 +253,33 @@ impl<'a> TypedArray<'a> { /// Check if the value at the given index is null. #[inline] fn is_null(&self, row_idx: usize) -> bool { - match self { - TypedArray::Boolean(arr) => arr.is_null(row_idx), - TypedArray::Int8(arr) => arr.is_null(row_idx), - TypedArray::Int16(arr) => arr.is_null(row_idx), - TypedArray::Int32(arr) => arr.is_null(row_idx), - TypedArray::Int64(arr) => arr.is_null(row_idx), - TypedArray::Float32(arr) => arr.is_null(row_idx), - TypedArray::Float64(arr) => arr.is_null(row_idx), - TypedArray::Date32(arr) => arr.is_null(row_idx), - TypedArray::TimestampMicro(arr) => arr.is_null(row_idx), - TypedArray::Decimal128(arr, _) => arr.is_null(row_idx), - TypedArray::String(arr) => arr.is_null(row_idx), - TypedArray::LargeString(arr) => arr.is_null(row_idx), - TypedArray::Binary(arr) => arr.is_null(row_idx), - TypedArray::LargeBinary(arr) => arr.is_null(row_idx), - TypedArray::Struct(arr, _, _) => arr.is_null(row_idx), - TypedArray::List(arr, _) => arr.is_null(row_idx), - TypedArray::LargeList(arr, _) => arr.is_null(row_idx), - TypedArray::Map(arr, _) => arr.is_null(row_idx), - TypedArray::Dictionary(arr, _) => arr.is_null(row_idx), - } + impl_is_null!( + self, + row_idx, + null_always_true, + [ + Boolean, + Int8, + Int16, + Int32, + Int64, + Float32, + Float64, + Date32, + TimestampMicro, + Decimal128, + String, + LargeString, + Binary, + LargeBinary, + FixedSizeBinary, + Struct, + List, + LargeList, + Map, + Dictionary + ] + ) } /// Get the fixed-width value as i64 (for types that fit in 8 bytes). @@ -291,7 +317,8 @@ impl<'a> TypedArray<'a> { #[inline] fn is_variable_length(&self) -> bool { match self { - TypedArray::Boolean(_) + TypedArray::Null + | TypedArray::Boolean(_) | TypedArray::Int8(_) | TypedArray::Int16(_) | TypedArray::Int32(_) @@ -309,44 +336,17 @@ impl<'a> TypedArray<'a> { fn write_variable_to_buffer(&self, buffer: &mut Vec, row_idx: usize) -> CometResult { match self { TypedArray::String(arr) => { - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } TypedArray::LargeString(arr) => { - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedArray::Binary(arr) => { - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedArray::LargeBinary(arr) => { - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } + TypedArray::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), + TypedArray::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), + TypedArray::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), TypedArray::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => { let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } TypedArray::Struct(arr, fields, typed_fields) => { write_struct_to_buffer_typed(buffer, arr, row_idx, fields, typed_fields) @@ -394,6 +394,7 @@ enum TypedElements<'a> { LargeString(&'a LargeStringArray), Binary(&'a BinaryArray), LargeBinary(&'a LargeBinaryArray), + FixedSizeBinary(&'a FixedSizeBinaryArray), // For nested types, fall back to ArrayRef Other(&'a ArrayRef, DataType), } @@ -401,47 +402,26 @@ enum TypedElements<'a> { impl<'a> TypedElements<'a> { /// Create from an ArrayRef and element type. fn from_array(array: &'a ArrayRef, element_type: &DataType) -> Self { + // Try primitive types first using macro + typed_elements_from_primitive!( + array, + element_type, + (DataType::Boolean, Boolean, BooleanArray), + (DataType::Int8, Int8, Int8Array), + (DataType::Int16, Int16, Int16Array), + (DataType::Int32, Int32, Int32Array), + (DataType::Int64, Int64, Int64Array), + (DataType::Float32, Float32, Float32Array), + (DataType::Float64, Float64, Float64Array), + (DataType::Date32, Date32, Date32Array), + (DataType::Utf8, String, StringArray), + (DataType::LargeUtf8, LargeString, LargeStringArray), + (DataType::Binary, Binary, BinaryArray), + (DataType::LargeBinary, LargeBinary, LargeBinaryArray), + ); + + // Handle special cases that need extra processing match element_type { - DataType::Boolean => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Boolean(arr); - } - } - DataType::Int8 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Int8(arr); - } - } - DataType::Int16 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Int16(arr); - } - } - DataType::Int32 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Int32(arr); - } - } - DataType::Int64 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Int64(arr); - } - } - DataType::Float32 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Float32(arr); - } - } - DataType::Float64 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Float64(arr); - } - } - DataType::Date32 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Date32(arr); - } - } DataType::Timestamp(TimeUnit::Microsecond, _) => { if let Some(arr) = array.as_any().downcast_ref::() { return TypedElements::TimestampMicro(arr); @@ -452,24 +432,9 @@ impl<'a> TypedElements<'a> { return TypedElements::Decimal128(arr, *p); } } - DataType::Utf8 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::String(arr); - } - } - DataType::LargeUtf8 => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::LargeString(arr); - } - } - DataType::Binary => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::Binary(arr); - } - } - DataType::LargeBinary => { - if let Some(arr) = array.as_any().downcast_ref::() { - return TypedElements::LargeBinary(arr); + DataType::FixedSizeBinary(_) => { + if let Some(arr) = array.as_any().downcast_ref::() { + return TypedElements::FixedSizeBinary(arr); } } _ => {} @@ -510,23 +475,28 @@ impl<'a> TypedElements<'a> { /// Check if value at given index is null. #[inline] fn is_null_at(&self, idx: usize) -> bool { - match self { - TypedElements::Boolean(arr) => arr.is_null(idx), - TypedElements::Int8(arr) => arr.is_null(idx), - TypedElements::Int16(arr) => arr.is_null(idx), - TypedElements::Int32(arr) => arr.is_null(idx), - TypedElements::Int64(arr) => arr.is_null(idx), - TypedElements::Float32(arr) => arr.is_null(idx), - TypedElements::Float64(arr) => arr.is_null(idx), - TypedElements::Date32(arr) => arr.is_null(idx), - TypedElements::TimestampMicro(arr) => arr.is_null(idx), - TypedElements::Decimal128(arr, _) => arr.is_null(idx), - TypedElements::String(arr) => arr.is_null(idx), - TypedElements::LargeString(arr) => arr.is_null(idx), - TypedElements::Binary(arr) => arr.is_null(idx), - TypedElements::LargeBinary(arr) => arr.is_null(idx), - TypedElements::Other(arr, _) => arr.is_null(idx), - } + impl_is_null!( + self, + idx, + [ + Boolean, + Int8, + Int16, + Int32, + Int64, + Float32, + Float64, + Date32, + TimestampMicro, + Decimal128, + String, + LargeString, + Binary, + LargeBinary, + FixedSizeBinary, + Other + ] + ) } /// Check if this is a fixed-width type (value fits in 8-byte slot). @@ -572,55 +542,21 @@ impl<'a> TypedElements<'a> { } /// Write variable-length data to buffer. Returns length written (0 for fixed-width). - fn write_variable_value( - &self, - buffer: &mut Vec, - idx: usize, - base_offset: usize, - ) -> CometResult { + fn write_variable_value(&self, buffer: &mut Vec, idx: usize) -> CometResult { match self { - TypedElements::String(arr) => { - let bytes = arr.value(idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } + TypedElements::String(arr) => Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes())), TypedElements::LargeString(arr) => { - let bytes = arr.value(idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedElements::Binary(arr) => { - let bytes = arr.value(idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedElements::LargeBinary(arr) => { - let bytes = arr.value(idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes())) } + TypedElements::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), + TypedElements::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), + TypedElements::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), TypedElements::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => { let bytes = i128_to_spark_decimal_bytes(arr.value(idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } TypedElements::Other(arr, element_type) => { - write_nested_variable_to_buffer(buffer, element_type, arr, idx, base_offset) + write_nested_variable_to_buffer(buffer, element_type, arr, idx) } _ => Ok(0), // Fixed-width types } @@ -771,11 +707,7 @@ impl<'a> TypedElements<'a> { set_null_bit(buffer, null_bitset_start, i); } else { let bytes = i128_to_spark_decimal_bytes(arr.value(src_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, &bytes); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -790,12 +722,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes()); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -810,12 +737,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes()); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -830,12 +752,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx)); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -850,12 +767,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx)); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -872,13 +784,8 @@ impl<'a> TypedElements<'a> { set_null_bit(buffer, null_bitset_start, i); } else { let slot_offset = elements_start + i * element_size; - let var_len = write_nested_variable_to_buffer( - buffer, - element_type, - arr, - src_idx, - array_start, - )?; + let var_len = + write_nested_variable_to_buffer(buffer, element_type, arr, src_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); @@ -1035,6 +942,16 @@ impl ColumnarToRowContext { ))); } + // Unpack any dictionary arrays to their underlying value type + // This is needed because Parquet may return dictionary-encoded arrays + // even when the schema expects a specific type like Decimal128 + let arrays: Vec = arrays + .iter() + .zip(self.schema.iter()) + .map(|(arr, schema_type)| Self::maybe_cast_to_schema_type(arr, schema_type)) + .collect::>>()?; + let arrays = arrays.as_slice(); + // Clear previous data self.buffer.clear(); self.offsets.clear(); @@ -1052,8 +969,7 @@ impl ColumnarToRowContext { // Pre-downcast all arrays to avoid type dispatch in inner loop let typed_arrays: Vec = arrays .iter() - .zip(self.schema.iter()) - .map(|(arr, dt)| TypedArray::from_array(arr, dt)) + .map(TypedArray::from_array) .collect::>>()?; // Pre-compute variable-length column indices (once per batch, not per row) @@ -1079,6 +995,83 @@ impl ColumnarToRowContext { Ok((self.buffer.as_ptr(), &self.offsets, &self.lengths)) } + /// Casts an array to match the expected schema type if needed. + /// This handles cases where: + /// 1. Parquet returns dictionary-encoded arrays but the schema expects a non-dictionary type + /// 2. Parquet returns NullArray when all values are null, but the schema expects a typed array + /// 3. Parquet returns Int32/Int64 for small-precision decimals but schema expects Decimal128 + fn maybe_cast_to_schema_type( + array: &ArrayRef, + schema_type: &DataType, + ) -> CometResult { + let actual_type = array.data_type(); + + // If types already match, no cast needed + if actual_type == schema_type { + return Ok(Arc::clone(array)); + } + + match (actual_type, schema_type) { + (DataType::Dictionary(_, _), schema) + if !matches!(schema, DataType::Dictionary(_, _)) => + { + // Unpack dictionary if the schema type is not a dictionary + let options = CastOptions::default(); + cast_with_options(array, schema_type, &options).map_err(|e| { + CometError::Internal(format!( + "Failed to unpack dictionary array from {:?} to {:?}: {}", + actual_type, schema_type, e + )) + }) + } + (DataType::Null, _) => { + // Cast NullArray to the expected schema type + // This happens when all values in a column are null + let options = CastOptions::default(); + cast_with_options(array, schema_type, &options).map_err(|e| { + CometError::Internal(format!( + "Failed to cast NullArray to {:?}: {}", + schema_type, e + )) + }) + } + (DataType::Int32, DataType::Decimal128(precision, scale)) => { + // Parquet stores small-precision decimals as Int32 for efficiency. + // When COMET_USE_DECIMAL_128 is false, BatchReader produces these types. + // The Int32 value is already scaled (e.g., -1 means -0.01 for scale 2). + // We need to reinterpret (not cast) to Decimal128 preserving the value. + let int_array = array.as_any().downcast_ref::().ok_or_else(|| { + CometError::Internal("Failed to downcast to Int32Array".to_string()) + })?; + let decimal_array: Decimal128Array = int_array + .iter() + .map(|v| v.map(|x| x as i128)) + .collect::() + .with_precision_and_scale(*precision, *scale) + .map_err(|e| { + CometError::Internal(format!("Invalid decimal precision/scale: {}", e)) + })?; + Ok(Arc::new(decimal_array)) + } + (DataType::Int64, DataType::Decimal128(precision, scale)) => { + // Same as Int32 but for medium-precision decimals stored as Int64. + let int_array = array.as_any().downcast_ref::().ok_or_else(|| { + CometError::Internal("Failed to downcast to Int64Array".to_string()) + })?; + let decimal_array: Decimal128Array = int_array + .iter() + .map(|v| v.map(|x| x as i128)) + .collect::() + .with_precision_and_scale(*precision, *scale) + .map_err(|e| { + CometError::Internal(format!("Invalid decimal precision/scale: {}", e)) + })?; + Ok(Arc::new(decimal_array)) + } + _ => Ok(Arc::clone(array)), + } + } + /// Fast path for schemas with only fixed-width columns. /// Pre-allocates entire buffer and processes more efficiently. fn convert_fixed_width( @@ -1153,153 +1146,104 @@ impl ColumnarToRowContext { // Write non-null values using type-specific fast paths match data_type { DataType::Boolean => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to BooleanArray".to_string()) - })?; + // Boolean is special: writes single byte, not 8-byte i64 + let arr = downcast_array!(array, BooleanArray)?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { let offset = row_idx * row_size + field_offset_in_row; self.buffer[offset] = if arr.value(row_idx) { 1 } else { 0 }; } } + Ok(()) } - DataType::Int8 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int8Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes()); - } - } - } - DataType::Int16 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int16Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes()); - } - } - } - DataType::Int32 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int32Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes()); - } - } - } - DataType::Int64 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int64Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&arr.value(row_idx).to_le_bytes()); - } - } - } - DataType::Float32 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float32Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx).to_bits() as i64).to_le_bytes()); - } - } - } - DataType::Float64 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float64Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx).to_bits() as i64).to_le_bytes()); - } - } - } - DataType::Date32 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Date32Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes()); - } - } - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal( - "Failed to downcast to TimestampMicrosecondArray".to_string(), - ) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&arr.value(row_idx).to_le_bytes()); - } - } - } + DataType::Int8 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Int8Array, + |v: i8| v as i64 + ), + DataType::Int16 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Int16Array, + |v: i16| v as i64 + ), + DataType::Int32 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Int32Array, + |v: i32| v as i64 + ), + DataType::Int64 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Int64Array, + |v: i64| v + ), + DataType::Float32 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Float32Array, + |v: f32| v.to_bits() as i64 + ), + DataType::Float64 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Float64Array, + |v: f64| v.to_bits() as i64 + ), + DataType::Date32 => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Date32Array, + |v: i32| v as i64 + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + TimestampMicrosecondArray, + |v: i64| v + ), DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Decimal128Array".to_string()) - })?; - for row_idx in 0..num_rows { - if !arr.is_null(row_idx) { - let offset = row_idx * row_size + field_offset_in_row; - self.buffer[offset..offset + 8] - .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes()); - } - } - } - _ => { - return Err(CometError::Internal(format!( - "Unexpected non-fixed-width type in fast path: {:?}", - data_type - ))); + write_fixed_column_primitive!( + self, + array, + row_size, + field_offset_in_row, + num_rows, + Decimal128Array, + |v: i128| v as i64 + ) } + _ => Err(CometError::Internal(format!( + "Unexpected non-fixed-width type in fast path: {:?}", + data_type + ))), } - - Ok(()) } /// Writes a complete row using pre-downcast TypedArrays. @@ -1386,112 +1330,31 @@ fn get_field_value(data_type: &DataType, array: &ArrayRef, row_idx: usize) -> Co match actual_type { DataType::Boolean => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to BooleanArray for type {:?}", - actual_type - )) - })?; + let arr = downcast_array!(array, BooleanArray)?; Ok(if arr.value(row_idx) { 1i64 } else { 0i64 }) } - DataType::Int8 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Int8Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx) as i64) - } + DataType::Int8 => get_field_value_primitive!(array, row_idx, Int8Array, |v: i8| v as i64), DataType::Int16 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Int16Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx) as i64) + get_field_value_primitive!(array, row_idx, Int16Array, |v: i16| v as i64) } DataType::Int32 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Int32Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx) as i64) - } - DataType::Int64 => { - let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Int64Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx)) + get_field_value_primitive!(array, row_idx, Int32Array, |v: i32| v as i64) } + DataType::Int64 => get_field_value_primitive!(array, row_idx, Int64Array, |v: i64| v), DataType::Float32 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Float32Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx).to_bits() as i64) + get_field_value_primitive!(array, row_idx, Float32Array, |v: f32| v.to_bits() as i64) } DataType::Float64 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Float64Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx).to_bits() as i64) + get_field_value_primitive!(array, row_idx, Float64Array, |v: f64| v.to_bits() as i64) } DataType::Date32 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Date32Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx) as i64) + get_field_value_primitive!(array, row_idx, Date32Array, |v: i32| v as i64) } DataType::Timestamp(TimeUnit::Microsecond, _) => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to TimestampMicrosecondArray for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx)) + get_field_value_primitive!(array, row_idx, TimestampMicrosecondArray, |v: i64| v) } DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Decimal128Array for type {:?}", - actual_type - )) - })?; - Ok(arr.value(row_idx) as i64) + get_field_value_primitive!(array, row_idx, Decimal128Array, |v: i128| v as i64) } // Variable-length types use placeholder (will be overwritten by get_variable_length_data) DataType::Utf8 @@ -1605,72 +1468,26 @@ fn write_dictionary_to_buffer_with_key( match value_type { DataType::Utf8 => { - let string_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to StringArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = string_values.value(key_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let string_values = downcast_array!(values, StringArray)?; + Ok(write_bytes_padded( + buffer, + string_values.value(key_idx).as_bytes(), + )) } DataType::LargeUtf8 => { - let string_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to LargeStringArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = string_values.value(key_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let string_values = downcast_array!(values, LargeStringArray)?; + Ok(write_bytes_padded( + buffer, + string_values.value(key_idx).as_bytes(), + )) } DataType::Binary => { - let binary_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to BinaryArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = binary_values.value(key_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let binary_values = downcast_array!(values, BinaryArray)?; + Ok(write_bytes_padded(buffer, binary_values.value(key_idx))) } DataType::LargeBinary => { - let binary_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to LargeBinaryArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = binary_values.value(key_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let binary_values = downcast_array!(values, LargeBinaryArray)?; + Ok(write_bytes_padded(buffer, binary_values.value(key_idx))) } _ => Err(CometError::Internal(format!( "Unsupported dictionary value type for direct buffer write: {:?}", @@ -1781,7 +1598,7 @@ fn write_struct_to_buffer_typed( buffer[field_offset..field_offset + 8].copy_from_slice(&value.to_le_bytes()); } else { // Variable-length field - use pre-downcast writer - let var_len = typed_field.write_variable_value(buffer, row_idx, struct_start)?; + let var_len = typed_field.write_variable_value(buffer, row_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); let data_offset = buffer.len() - padded_len - struct_start; @@ -1829,51 +1646,35 @@ fn write_struct_to_buffer( let field_offset = struct_start + nested_bitset_width + field_idx * 8; // Inline type dispatch for fixed-width types (most common case) - let value = match data_type { - DataType::Boolean => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(if arr.value(row_idx) { 1i64 } else { 0i64 }) - } - DataType::Int8 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx) as i64) - } - DataType::Int16 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx) as i64) - } - DataType::Int32 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx) as i64) - } - DataType::Int64 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx)) - } - DataType::Float32 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some((arr.value(row_idx).to_bits() as i32) as i64) - } - DataType::Float64 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx).to_bits() as i64) - } - DataType::Date32 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some(arr.value(row_idx) as i64) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let arr = column - .as_any() - .downcast_ref::() - .unwrap(); - Some(arr.value(row_idx)) - } - DataType::Decimal128(p, _) if *p <= MAX_LONG_DIGITS => { - let arr = column.as_any().downcast_ref::().unwrap(); + let value: Option = extract_fixed_value!( + column, + row_idx, + (DataType::Boolean, BooleanArray, |v: bool| if v { + 1i64 + } else { + 0i64 + }), + (DataType::Int8, Int8Array, |v: i8| v as i64), + (DataType::Int16, Int16Array, |v: i16| v as i64), + (DataType::Int32, Int32Array, |v: i32| v as i64), + (DataType::Int64, Int64Array, |v: i64| v), + (DataType::Float32, Float32Array, |v: f32| v.to_bits() as i64), + (DataType::Float64, Float64Array, |v: f64| v.to_bits() as i64), + (DataType::Date32, Date32Array, |v: i32| v as i64), + ( + DataType::Timestamp(TimeUnit::Microsecond, _), + TimestampMicrosecondArray, + |v: i64| v + ), + ); + // Handle Decimal128 with precision guard separately + let value: Option = match (value, data_type) { + (Some(v), _) => Some(v), + (None, DataType::Decimal128(p, _)) if *p <= MAX_LONG_DIGITS => { + let arr = downcast_array!(column, Decimal128Array)?; Some(arr.value(row_idx) as i64) } - _ => None, // Variable-length type + _ => None, }; if let Some(v) = value { @@ -1881,13 +1682,7 @@ fn write_struct_to_buffer( buffer[field_offset..field_offset + 8].copy_from_slice(&v.to_le_bytes()); } else { // Variable-length field - let var_len = write_nested_variable_to_buffer( - buffer, - data_type, - column, - row_idx, - struct_start, - )?; + let var_len = write_nested_variable_to_buffer(buffer, data_type, column, row_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); let data_offset = buffer.len() - padded_len - struct_start; @@ -2016,136 +1811,45 @@ fn write_nested_variable_to_buffer( data_type: &DataType, array: &ArrayRef, row_idx: usize, - _base_offset: usize, ) -> CometResult { let actual_type = array.data_type(); match actual_type { DataType::Utf8 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to StringArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, StringArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } DataType::LargeUtf8 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeStringArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, LargeStringArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } DataType::Binary => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to BinaryArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, BinaryArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx))) } DataType::LargeBinary => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeBinaryArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, LargeBinaryArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx))) } DataType::Decimal128(precision, _) if *precision > MAX_LONG_DIGITS => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Decimal128Array for type {:?}", - actual_type - )) - })?; + let arr = downcast_array!(array, Decimal128Array)?; let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } DataType::Struct(fields) => { - let struct_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to StructArray for type {:?}", - actual_type - )) - })?; + let struct_array = downcast_array!(array, StructArray)?; write_struct_to_buffer(buffer, struct_array, row_idx, fields) } DataType::List(field) => { - let list_array = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to ListArray for type {:?}", - actual_type - )) - })?; + let list_array = downcast_array!(array, ListArray)?; write_list_to_buffer(buffer, list_array, row_idx, field) } DataType::LargeList(field) => { - let list_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeListArray for type {:?}", - actual_type - )) - })?; + let list_array = downcast_array!(array, LargeListArray)?; write_large_list_to_buffer(buffer, list_array, row_idx, field) } DataType::Map(field, _) => { - let map_array = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to MapArray for type {:?}", - actual_type - )) - })?; + let map_array = downcast_array!(array, MapArray)?; write_map_to_buffer(buffer, map_array, row_idx, field) } DataType::Dictionary(key_type, value_type) => { @@ -2748,4 +2452,163 @@ mod tests { assert_eq!(value, i as i32, "element {} should be {}", i, i); } } + + #[test] + fn test_convert_fixed_size_binary_array() { + // FixedSizeBinary(3) - each value is exactly 3 bytes + let schema = vec![DataType::FixedSizeBinary(3)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let array: ArrayRef = Arc::new(FixedSizeBinaryArray::from(vec![ + Some(&[1u8, 2, 3][..]), + Some(&[4u8, 5, 6][..]), + None, // Test null handling + ])); + let arrays = vec![array]; + + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Row 0: 8 (bitset) + 8 (field slot) + 8 (aligned 3-byte data) = 24 + // Row 1: 8 (bitset) + 8 (field slot) + 8 (aligned 3-byte data) = 24 + // Row 2: 8 (bitset) + 8 (field slot) = 16 (null, no variable data) + assert_eq!(lengths[0], 24); + assert_eq!(lengths[1], 24); + assert_eq!(lengths[2], 16); + + // Verify the data is correct for non-null rows + unsafe { + let row0 = + std::slice::from_raw_parts(ptr.add(offsets[0] as usize), lengths[0] as usize); + // Variable data starts at offset 16 (8 bitset + 8 field slot) + assert_eq!(&row0[16..19], &[1u8, 2, 3]); + + let row1 = + std::slice::from_raw_parts(ptr.add(offsets[1] as usize), lengths[1] as usize); + assert_eq!(&row1[16..19], &[4u8, 5, 6]); + } + } + + #[test] + fn test_convert_dictionary_decimal_array() { + // Test that dictionary-encoded decimals are correctly unpacked and converted + // This tests the fix for casting to schema_type instead of value_type + use arrow::datatypes::Int8Type; + + // Create a dictionary array with Decimal128 values + // Values: [-0.01, -0.02, -0.03] represented as [-1, -2, -3] with scale 2 + let values = Decimal128Array::from(vec![-1i128, -2, -3]) + .with_precision_and_scale(5, 2) + .unwrap(); + + // Keys: [0, 1, 2, 0, 1, 2] - each value appears twice + let keys = Int8Array::from(vec![0i8, 1, 2, 0, 1, 2]); + + let dict_array: ArrayRef = + Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()); + + // Schema expects Decimal128(5, 2) - not a dictionary type + let schema = vec![DataType::Decimal128(5, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![dict_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 6).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 6); + assert_eq!(lengths.len(), 6); + + // Verify the decimal values are correct (not doubled or otherwise corrupted) + // Fixed-width decimal is stored directly in the 8-byte field slot + unsafe { + for (i, expected) in [-1i64, -2, -3, -1, -2, -3].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } + + #[test] + fn test_convert_int32_to_decimal128() { + // Test that Int32 arrays are correctly cast to Decimal128 when schema expects Decimal128. + // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces + // Int32 for small-precision decimals. + + // Create an Int32 array representing decimals: [-1, -2, -3] which at scale 2 means + // [-0.01, -0.02, -0.03] + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![-1i32, -2, -3])); + + // Schema expects Decimal128(5, 2) + let schema = vec![DataType::Decimal128(5, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![int_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Verify the decimal values are correct after casting + // Fixed-width decimal is stored directly in the 8-byte field slot + unsafe { + for (i, expected) in [-1i64, -2, -3].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } + + #[test] + fn test_convert_int64_to_decimal128() { + // Test that Int64 arrays are correctly cast to Decimal128 when schema expects Decimal128. + // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces + // Int64 for medium-precision decimals. + + // Create an Int64 array representing decimals + let int_array: ArrayRef = Arc::new(Int64Array::from(vec![-100i64, -200, -300])); + + // Schema expects Decimal128(10, 2) + let schema = vec![DataType::Decimal128(10, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![int_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Verify the decimal values are correct after casting + unsafe { + for (i, expected) in [-100i64, -200, -300].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } } diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7402a83248..d1c3b07677 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -22,13 +22,14 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometScanExec, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.comet.CometConf +import org.apache.comet.parquet.CometParquetScan // This rule is responsible for eliminating redundant transitions between row-based and // columnar-based operators for Comet. Currently, three potential redundant transitions are: @@ -139,7 +140,8 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa private def createColumnarToRowExec(child: SparkPlan): SparkPlan = { val schema = child.schema val useNative = CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.get() && - CometNativeColumnarToRowExec.supportsSchema(schema) + CometNativeColumnarToRowExec.supportsSchema(schema) && + !hasScanUsingMutableBuffers(child) if (useNative) { CometNativeColumnarToRowExec(child) @@ -147,4 +149,30 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa CometColumnarToRowExec(child) } } + + /** + * Checks if the plan contains a scan that uses mutable buffers. Native C2R is not compatible + * with such scans because the buffers may be modified after C2R reads them. + * + * This includes: + * - CometScanExec with native_comet scan implementation (V1 path) - uses BatchReader + * - CometScanExec with native_iceberg_compat and partition columns - uses + * ConstantColumnReader + * - CometBatchScanExec with CometParquetScan (V2 Parquet path) - uses BatchReader + */ + private def hasScanUsingMutableBuffers(op: SparkPlan): Boolean = { + op match { + case c: QueryStageExec => hasScanUsingMutableBuffers(c.plan) + case c: ReusedExchangeExec => hasScanUsingMutableBuffers(c.child) + case _ => + op.exists { + case scan: CometScanExec => + scan.scanImpl == CometConf.SCAN_NATIVE_COMET || + (scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT && + scan.relation.partitionSchema.nonEmpty) + case scan: CometBatchScanExec => scan.scan.isInstanceOf[CometParquetScan] + case _ => false + } + } + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala index 93526573c0..a520098ed1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala @@ -19,15 +19,25 @@ package org.apache.spark.sql.comet -import org.apache.spark.TaskContext +import java.util.UUID +import java.util.concurrent.{Future, TimeoutException, TimeUnit} + +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{broadcast, SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkFatalException, Utils} import org.apache.comet.{CometConf, NativeColumnarToRowConverter} @@ -64,6 +74,116 @@ case class CometNativeColumnarToRowExec(child: SparkPlan) "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"), "convertTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time in conversion")) + @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + @transient + private val timeout: Long = conf.broadcastTimeout + + private val runId: UUID = UUID.randomUUID + + private lazy val cometBroadcastExchange = findCometBroadcastExchange(child) + + @transient + lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( + session, + CometBroadcastExchangeExec.executionContext) { + try { + // Setup a job group here so later it may get cancelled by groupId if necessary. + sparkContext.setJobGroup( + runId.toString, + s"CometNativeColumnarToRow broadcast exchange (runId $runId)", + interruptOnCancel = true) + + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + val localSchema = this.schema + val batchSize = CometConf.COMET_BATCH_SIZE.get() + val broadcastColumnar = child.executeBroadcast() + val serializedBatches = + broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]] + + // Use native converter to convert columnar data to rows + val converter = new NativeColumnarToRowConverter(localSchema, batchSize) + try { + val rows = serializedBatches.iterator + .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName)) + .flatMap { batch => + numInputBatches += 1 + numOutputRows += batch.numRows() + val result = converter.convert(batch) + // Wrap iterator to close batch after consumption + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val hasMore = result.hasNext + if (!hasMore) { + batch.close() + } + hasMore + } + override def next(): InternalRow = result.next() + } + } + + val mode = cometBroadcastExchange.get.mode + val relation = mode.transform(rows, Some(numOutputRows.value)) + val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.trySuccess(broadcasted) + broadcasted + } finally { + converter.close() + } + } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. + case oe: OutOfMemoryError => + val ex = new SparkFatalException(oe) + promise.tryFailure(ex) + throw ex + case e if !NonFatal(e) => + val ex = new SparkFatalException(e) + promise.tryFailure(ex) + throw ex + case e: Throwable => + promise.tryFailure(e) + throw e + } + } + } + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + if (cometBroadcastExchange.isEmpty) { + throw new SparkException( + "CometNativeColumnarToRowExec only supports doExecuteBroadcast when child contains a " + + "CometBroadcastExchange, but got " + child) + } + + try { + relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in $timeout secs.", ex) + if (!relationFuture.isDone) { + sparkContext.cancelJobGroup(runId.toString) + relationFuture.cancel(true) + } + throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex)) + } + } + + private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = { + op match { + case b: CometBroadcastExchangeExec => Some(b) + case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan) + case b: ReusedExchangeExec => findCometBroadcastExchange(b.child) + case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange)) + } + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") @@ -91,7 +211,17 @@ case class CometNativeColumnarToRowExec(child: SparkPlan) val result = converter.convert(batch) convertTime += System.nanoTime() - startTime - result + // Wrap iterator to close batch after consumption + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val hasMore = result.hasNext + if (!hasMore) { + batch.close() + } + hasMore + } + override def next(): InternalRow = result.next() + } } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e0a5c43aef..fe5ea77a89 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -30,8 +30,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps -import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec} -import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.comet.{CometNativeColumnarToRowExec, CometProjectExec} +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1020,11 +1020,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val query = sql(s"select cast(id as string) from $table") val (_, cometPlan) = checkSparkAnswerAndOperator(query) val project = cometPlan - .asInstanceOf[WholeStageCodegenExec] - .child - .asInstanceOf[CometColumnarToRowExec] - .child - .asInstanceOf[InputAdapter] + .asInstanceOf[CometNativeColumnarToRowExec] .child .asInstanceOf[CometProjectExec] val id = project.expressions.head diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 1b2373ad71..696a12d4a2 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, He import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} @@ -864,9 +864,11 @@ class CometExecSuite extends CometTestBase { checkSparkAnswerAndOperator(df) // Before AQE: one CometBroadcastExchange, no CometColumnarToRow - var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect { - case s: CometColumnarToRowExec => s - } + var columnarToRowExec: Seq[SparkPlan] = + stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometColumnarToRowExec => s + case s: CometNativeColumnarToRowExec => s + } assert(columnarToRowExec.isEmpty) // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and @@ -880,14 +882,25 @@ class CometExecSuite extends CometTestBase { // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark // BroadcastHashJoin. val plan = stripAQEPlan(df.queryExecution.executedPlan) - columnarToRowExec = plan.collect { case s: CometColumnarToRowExec => - s + columnarToRowExec = plan.collect { + case s: CometColumnarToRowExec => s + case s: CometNativeColumnarToRowExec => s } assert(columnarToRowExec.length == 1) - // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec - val parent = plan.find(_.children.contains(columnarToRowExec.head)) - assert(parent.get.isInstanceOf[BroadcastHashJoinExec]) + // This ColumnarToRowExec should be a descendant of BroadcastHashJoinExec (possibly + // wrapped by InputAdapter for codegen). + val broadcastJoins = plan.collect { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.nonEmpty, s"Expected BroadcastHashJoinExec in plan:\n$plan") + val hasC2RDescendant = broadcastJoins.exists { join => + join.find { + case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => true + case _ => false + }.isDefined + } + assert( + hasC2RDescendant, + "BroadcastHashJoinExec should have a columnar-to-row descendant") // There should be a CometBroadcastExchangeExec under CometColumnarToRowExec val broadcastQueryStage = diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 89249240cf..8a2f8af5c2 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -80,6 +80,7 @@ abstract class CometTestBase conf.set(CometConf.COMET_ONHEAP_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.key, "true") conf.set(CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.key, "true") conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true") conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala index 7caac71351..c8c4baff4a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala @@ -46,7 +46,7 @@ trait CometPlanChecker { case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec | _: CometIcebergNativeScanExec => case _: CometSinkPlaceHolder | _: CometScanWrapper => - case _: CometColumnarToRowExec => + case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => case _: CometSparkToColumnarExec => case _: CometExec | _: CometShuffleExchangeExec => case _: CometBroadcastExchangeExec =>