Skip to content

Commit 10b89a7

Browse files
committed
Merge apache/main into shuffle-complex-type-perf
2 parents d103497 + 9909535 commit 10b89a7

8 files changed

Lines changed: 98 additions & 2 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ use datafusion_spark::function::math::hex::SparkHex;
5656
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
5757
use datafusion_spark::function::string::char::CharFunc;
5858
use datafusion_spark::function::string::concat::SparkConcat;
59+
use datafusion_spark::function::string::luhn_check::SparkLuhnCheck;
5960
use datafusion_spark::function::string::space::SparkSpace;
6061
use futures::poll;
6162
use futures::stream::StreamExt;
@@ -403,6 +404,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
403404
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
404405
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
405406
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default()));
407+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLuhnCheck::default()));
406408
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSpace::default()));
407409
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitCount::default()));
408410
}
@@ -849,6 +851,13 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
849851
tracing_enabled != JNI_FALSE,
850852
|| {
851853
// SAFETY: JVM unsafe memory allocation is aligned with long.
854+
debug_assert!(address != 0, "sortRowPartitionsNative: null address");
855+
debug_assert!(size >= 0, "sortRowPartitionsNative: negative size {size}");
856+
debug_assert_eq!(
857+
(address as usize) % std::mem::align_of::<i64>(),
858+
0,
859+
"sortRowPartitionsNative: address not aligned to i64"
860+
);
852861
let array =
853862
unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) };
854863
array.rdxsort();

native/core/src/execution/shuffle/spark_unsafe/list.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ impl SparkUnsafeArray {
104104
pub fn new(addr: i64) -> Self {
105105
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
106106
// The first 8 bytes contain the element count as a little-endian i64.
107+
debug_assert!(addr != 0, "SparkUnsafeArray::new: null address");
107108
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
108109
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
109110

@@ -138,6 +139,11 @@ impl SparkUnsafeArray {
138139
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
139140
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
140141
// index < num_elements, so word_offset is within the bitset region.
142+
debug_assert!(
143+
index < self.num_elements,
144+
"is_null_at: index {index} >= num_elements {}",
145+
self.num_elements
146+
);
141147
unsafe {
142148
let mask: i64 = 1i64 << (index & 0x3f);
143149
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;

native/core/src/execution/shuffle/spark_unsafe/map.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ impl SparkUnsafeMap {
3232
pub(crate) fn new(addr: i64, size: i32) -> Self {
3333
// SAFETY: addr points to valid Spark UnsafeMap data from the JVM.
3434
// The first 8 bytes contain the key array size as a little-endian i64.
35+
debug_assert!(addr != 0, "SparkUnsafeMap::new: null address");
36+
debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}");
3537
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
3638
let key_array_size = i64::from_le_bytes(slice.try_into().unwrap());
3739

native/core/src/execution/shuffle/spark_unsafe/row.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ pub trait SparkUnsafeObject {
9393
let addr = self.get_element_offset(index, 1);
9494
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
9595
// The caller ensures index is within bounds.
96+
debug_assert!(
97+
!addr.is_null(),
98+
"get_boolean: null pointer at index {index}"
99+
);
96100
unsafe { *addr != 0 }
97101
}
98102

@@ -101,6 +105,7 @@ pub trait SparkUnsafeObject {
101105
fn get_byte(&self, index: usize) -> i8 {
102106
let addr = self.get_element_offset(index, 1);
103107
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
108+
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
104109
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
105110
i8::from_le_bytes(slice.try_into().unwrap())
106111
}
@@ -110,6 +115,7 @@ pub trait SparkUnsafeObject {
110115
fn get_short(&self, index: usize) -> i16 {
111116
let addr = self.get_element_offset(index, 2);
112117
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
118+
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
113119
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
114120
i16::from_le_bytes(slice.try_into().unwrap())
115121
}
@@ -119,6 +125,7 @@ pub trait SparkUnsafeObject {
119125
fn get_int(&self, index: usize) -> i32 {
120126
let addr = self.get_element_offset(index, 4);
121127
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
128+
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
122129
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
123130
i32::from_le_bytes(slice.try_into().unwrap())
124131
}
@@ -128,6 +135,7 @@ pub trait SparkUnsafeObject {
128135
fn get_long(&self, index: usize) -> i64 {
129136
let addr = self.get_element_offset(index, 8);
130137
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
138+
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
131139
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
132140
i64::from_le_bytes(slice.try_into().unwrap())
133141
}
@@ -137,6 +145,7 @@ pub trait SparkUnsafeObject {
137145
fn get_float(&self, index: usize) -> f32 {
138146
let addr = self.get_element_offset(index, 4);
139147
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
148+
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
140149
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
141150
f32::from_le_bytes(slice.try_into().unwrap())
142151
}
@@ -146,6 +155,7 @@ pub trait SparkUnsafeObject {
146155
fn get_double(&self, index: usize) -> f64 {
147156
let addr = self.get_element_offset(index, 8);
148157
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
158+
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
149159
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
150160
f64::from_le_bytes(slice.try_into().unwrap())
151161
}
@@ -156,6 +166,11 @@ pub trait SparkUnsafeObject {
156166
let addr = self.get_row_addr() + offset as i64;
157167
// SAFETY: addr points to valid UTF-8 string data within the variable-length region.
158168
// Offset and length are read from the fixed-length portion of the row/array.
169+
debug_assert!(addr != 0, "get_string: null address at index {index}");
170+
debug_assert!(
171+
len >= 0,
172+
"get_string: negative length {len} at index {index}"
173+
);
159174
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };
160175

161176
from_utf8(slice).unwrap()
@@ -167,6 +182,11 @@ pub trait SparkUnsafeObject {
167182
let addr = self.get_row_addr() + offset as i64;
168183
// SAFETY: addr points to valid binary data within the variable-length region.
169184
// Offset and length are read from the fixed-length portion of the row/array.
185+
debug_assert!(addr != 0, "get_binary: null address at index {index}");
186+
debug_assert!(
187+
len >= 0,
188+
"get_binary: negative length {len} at index {index}"
189+
);
170190
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
171191
}
172192

@@ -175,6 +195,7 @@ pub trait SparkUnsafeObject {
175195
fn get_date(&self, index: usize) -> i32 {
176196
let addr = self.get_element_offset(index, 4);
177197
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
198+
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
178199
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
179200
i32::from_le_bytes(slice.try_into().unwrap())
180201
}
@@ -184,6 +205,10 @@ pub trait SparkUnsafeObject {
184205
fn get_timestamp(&self, index: usize) -> i64 {
185206
let addr = self.get_element_offset(index, 8);
186207
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
208+
debug_assert!(
209+
!addr.is_null(),
210+
"get_timestamp: null pointer at index {index}"
211+
);
187212
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
188213
i64::from_le_bytes(slice.try_into().unwrap())
189214
}
@@ -296,6 +321,7 @@ impl SparkUnsafeRow {
296321
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
297322
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
298323
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
324+
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
299325
unsafe {
300326
let mask: i64 = 1i64 << (index & 0x3f);
301327
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
@@ -310,6 +336,7 @@ impl SparkUnsafeRow {
310336
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
311337
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
312338
// Writing is safe because we have mutable access and the memory is owned by the JVM.
339+
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
313340
unsafe {
314341
let mask: i64 = 1i64 << (index & 0x3f);
315342
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;

native/core/src/execution/utils.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ impl SparkArrowConvert for ArrayData {
9797
}
9898
} else {
9999
// SAFETY: `array_ptr` and `schema_ptr` are aligned correctly.
100+
debug_assert_eq!(
101+
array_ptr.align_offset(array_align),
102+
0,
103+
"move_to_spark: array_ptr not aligned"
104+
);
105+
debug_assert_eq!(
106+
schema_ptr.align_offset(schema_align),
107+
0,
108+
"move_to_spark: schema_ptr not aligned"
109+
);
100110
unsafe {
101111
std::ptr::write(array_ptr, FFI_ArrowArray::new(self));
102112
std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);

native/core/src/jvm_bridge/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,19 @@ impl JVMClasses<'_> {
263263
}
264264

265265
pub fn get() -> &'static JVMClasses<'static> {
266+
debug_assert!(
267+
JVM_CLASSES.get().is_some(),
268+
"JVMClasses::get: not initialized"
269+
);
266270
unsafe { JVM_CLASSES.get_unchecked() }
267271
}
268272

269273
/// Gets the JNIEnv for the current thread.
270274
pub fn get_env() -> CometResult<AttachGuard<'static>> {
275+
debug_assert!(
276+
JAVA_VM.get().is_some(),
277+
"JVMClasses::get_env: JAVA_VM not initialized"
278+
);
271279
unsafe {
272280
let java_vm = JAVA_VM.get_unchecked();
273281
java_vm.attach_current_thread().map_err(|e| {

spark/src/main/scala/org/apache/comet/serde/statics.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
package org.apache.comet.serde
2121

22-
import org.apache.spark.sql.catalyst.expressions.Attribute
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils}
2323
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2424
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
2525

@@ -34,7 +34,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
3434
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
3535
Map(
3636
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
37-
"read_side_padding"))
37+
"read_side_padding"),
38+
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check"))
3839

3940
override def convert(
4041
expr: StaticInvoke,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing,
12+
-- software distributed under the License is distributed on an
13+
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
-- KIND, either express or implied. See the License for the
15+
-- specific language governing permissions and limitations
16+
-- under the License.
17+
18+
-- MinSparkVersion: 3.5
19+
-- ConfigMatrix: parquet.enable.dictionary=false,true
20+
21+
statement
22+
CREATE TABLE test_luhn(s string) USING parquet
23+
24+
statement
25+
INSERT INTO test_luhn VALUES ('79927398710'), ('79927398713'), ('1234567812345670'), ('0'), (''), ('abc'), (NULL)
26+
27+
-- column input
28+
query
29+
SELECT luhn_check(s) FROM test_luhn
30+
31+
-- literal arguments
32+
query
33+
SELECT luhn_check('79927398713'), luhn_check('79927398710'), luhn_check(''), luhn_check(NULL)

0 commit comments

Comments
 (0)