Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ public URI pathUri() throws URISyntaxException {
protected static final BufferAllocator ALLOCATOR = new RootAllocator();
private NativeUtil nativeUtil = new NativeUtil();

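/**
* Thread-local holding the native BatchContext handle of the reader that most recently loaded a
* batch on the calling thread. nextBatch() sets it after every non-empty batch load, whether or
* not passthrough mode is in use, so that CometBatchIterator.advancePassthrough() can retrieve
* it. The initial value is 0, and the value is only meaningful when read on the same thread that
* called nextBatch().
*/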
public static final ThreadLocal<Long> CURRENT_READER_HANDLE = ThreadLocal.withInitial(() -> 0L);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude really loves thread local variables and I always need to sit and think for a long time how it could blow up in our faces.


protected Configuration conf;
protected int capacity;
protected boolean isCaseSensitive;
Expand Down Expand Up @@ -888,6 +895,10 @@ private boolean containsPath(Type parquetType, String[] path, int depth) {
return false;
}

/**
 * Returns the handle stored on this reader. This is the same value that nextBatch() publishes
 * through CURRENT_READER_HANDLE.
 *
 * @return the reader's handle
 */
public long getHandle() {
  return handle;
}

/**
 * Replaces the Spark schema held by this reader.
 *
 * @param schema the schema to store
 */
public void setSparkSchema(StructType schema) {
  sparkSchema = schema;
}
Expand Down Expand Up @@ -956,6 +967,11 @@ public boolean nextBatch() throws IOException {

if (batchSize == 0) return false;

// Set the thread-local handle so CometBatchIterator.advancePassthrough() can retrieve it.
// This is always set after a successful loadNextBatch() regardless of whether passthrough
// mode will be used — the Rust ScanExec decides whether to use it.
CURRENT_READER_HANDLE.set(this.handle);

long totalDecodeTime = 0, totalLoadTime = 0;
for (int i = 0; i < columnReaders.length; i++) {
AbstractColumnReader reader = columnReaders[i];
Expand Down
123 changes: 117 additions & 6 deletions native/core/src/execution/operators/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::execution::operators::{copy_array, copy_or_unpack_array, CopyMode};
use crate::parquet::get_batch_context;
use crate::{
errors::CometError,
execution::{
Expand Down Expand Up @@ -79,6 +80,11 @@ pub struct ScanExec {
baseline_metrics: BaselineMetrics,
/// Whether native code can assume ownership of batches that it receives
arrow_ffi_safe: bool,
/// When true, data columns are read directly from the native reader's
/// BatchContext instead of through JVM FFI (zero-copy).
native_batch_passthrough: bool,
/// Number of data columns from native reader. Remaining are partition columns.
num_data_columns: usize,
}

impl ScanExec {
Expand All @@ -88,6 +94,8 @@ impl ScanExec {
input_source_description: &str,
data_types: Vec<DataType>,
arrow_ffi_safe: bool,
native_batch_passthrough: bool,
num_data_columns: usize,
) -> Result<Self, CometError> {
let metrics_set = ExecutionPlanMetricsSet::default();
let baseline_metrics = BaselineMetrics::new(&metrics_set, 0);
Expand Down Expand Up @@ -115,6 +123,8 @@ impl ScanExec {
baseline_metrics,
schema,
arrow_ffi_safe,
native_batch_passthrough,
num_data_columns,
})
}

Expand Down Expand Up @@ -143,12 +153,21 @@ impl ScanExec {

let mut current_batch = self.batch.try_lock().unwrap();
if current_batch.is_none() {
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.data_types.len(),
self.arrow_ffi_safe,
)?;
let next_batch = if self.native_batch_passthrough {
ScanExec::get_next_passthrough(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.num_data_columns,
self.data_types.len(),
)?
} else {
ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.data_types.len(),
self.arrow_ffi_safe,
)?
};
*current_batch = Some(next_batch);
}

Expand Down Expand Up @@ -259,6 +278,98 @@ impl ScanExec {
Ok(InputBatch::new(inputs, Some(actual_num_rows)))
}

/// Passthrough mode: data columns are read directly from native BatchContext
/// (zero-copy Arc::clone). Only partition columns are imported from JVM via FFI.
fn get_next_passthrough(
exec_context_id: i64,
iter: &JObject,
num_data_cols: usize,
num_total_cols: usize,
) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
return Ok(InputBatch::EOF);
}

if iter.is_null() {
return Err(CometError::from(ExecutionError::GeneralError(format!(
"Null batch iterator object. Plan id: {exec_context_id}"
))));
}

let mut env = JVMClasses::get_env()?;

// 1. Advance reader; get native batch handle (data stays in Rust)
let handle: i64 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).advance_passthrough() -> i64)?
};
if handle == 0 {
return Ok(InputBatch::EOF);
}

// 2. Get data columns from native BatchContext (zero-copy)
let context = get_batch_context(handle)?;
let batch = context.current_batch.as_ref().ok_or_else(|| {
CometError::from(ExecutionError::GeneralError(
"No current batch in BatchContext".to_string(),
))
})?;

let num_rows = batch.num_rows();
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_total_cols);

for i in 0..num_data_cols {
// Zero-copy: just increment the Arc reference count
inputs.push(Arc::clone(batch.column(i)));
}

// 3. Import partition columns from JVM FFI (if any)
let num_partition_cols = num_total_cols - num_data_cols;
if num_partition_cols > 0 {
let mut array_addrs = Vec::with_capacity(num_partition_cols);
let mut schema_addrs = Vec::with_capacity(num_partition_cols);

for _ in 0..num_partition_cols {
let arrow_array = Rc::new(FFI_ArrowArray::empty());
let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
array_addrs.push(Rc::into_raw(arrow_array) as i64);
schema_addrs.push(Rc::into_raw(arrow_schema) as i64);
}

let long_array_addrs = env.new_long_array(num_partition_cols as jsize)?;
let long_schema_addrs = env.new_long_array(num_partition_cols as jsize)?;
env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;

let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);
let num_data_cols_jint = num_data_cols as i32;

let _part_rows: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next_partition_columns_only(
JValueGen::Object(array_obj.as_ref()),
JValueGen::Object(schema_obj.as_ref()),
JValueGen::Int(num_data_cols_jint)
) -> i32)?
};

for i in 0..num_partition_cols {
let array_data = ArrayData::from_spark((array_addrs[i], schema_addrs[i]))?;
let array = make_array(array_data);
// Partition columns come from JVM mutable buffers, must copy
inputs.push(copy_array(&array));

unsafe {
Rc::from_raw(array_addrs[i] as *const FFI_ArrowArray);
Rc::from_raw(schema_addrs[i] as *const FFI_ArrowSchema);
}
}
}

Ok(InputBatch::new(inputs, Some(num_rows)))
}

/// Allocates Arrow FFI structures and calls JNI to get the next batch data.
/// Returns the number of rows and the allocated array/schema addresses.
fn allocate_and_fetch_batch(
Expand Down
12 changes: 12 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,8 @@ impl PhysicalPlanner {
&scan.source,
data_types,
scan.arrow_ffi_safe,
scan.native_batch_passthrough,
scan.num_data_columns as usize,
)?;

Ok((
Expand Down Expand Up @@ -3473,6 +3475,8 @@ mod tests {
}],
source: "".to_string(),
arrow_ffi_safe: false,
native_batch_passthrough: false,
num_data_columns: 0,
})),
};

Expand Down Expand Up @@ -3547,6 +3551,8 @@ mod tests {
}],
source: "".to_string(),
arrow_ffi_safe: false,
native_batch_passthrough: false,
num_data_columns: 0,
})),
};

Expand Down Expand Up @@ -3754,6 +3760,8 @@ mod tests {
fields: vec![create_proto_datatype()],
source: "".to_string(),
arrow_ffi_safe: false,
native_batch_passthrough: false,
num_data_columns: 0,
})),
}
}
Expand Down Expand Up @@ -3797,6 +3805,8 @@ mod tests {
],
source: "".to_string(),
arrow_ffi_safe: false,
native_batch_passthrough: false,
num_data_columns: 0,
})),
};

Expand Down Expand Up @@ -3913,6 +3923,8 @@ mod tests {
],
source: "".to_string(),
arrow_ffi_safe: false,
native_batch_passthrough: false,
num_data_columns: 0,
})),
};

Expand Down
16 changes: 16 additions & 0 deletions native/core/src/jvm_bridge/batch_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ pub struct CometBatchIterator<'a> {
pub method_has_selection_vectors_ret: ReturnType,
pub method_export_selection_indices: JMethodID,
pub method_export_selection_indices_ret: ReturnType,
pub method_advance_passthrough: JMethodID,
pub method_advance_passthrough_ret: ReturnType,
pub method_next_partition_columns_only: JMethodID,
pub method_next_partition_columns_only_ret: ReturnType,
}

impl<'a> CometBatchIterator<'a> {
Expand All @@ -61,6 +65,18 @@ impl<'a> CometBatchIterator<'a> {
"([J[J)I",
)?,
method_export_selection_indices_ret: ReturnType::Primitive(Primitive::Int),
method_advance_passthrough: env.get_method_id(
Self::JVM_CLASS,
"advancePassthrough",
"()J",
)?,
method_advance_passthrough_ret: ReturnType::Primitive(Primitive::Long),
method_next_partition_columns_only: env.get_method_id(
Self::JVM_CLASS,
"nextPartitionColumnsOnly",
"([J[JI)I",
)?,
method_next_partition_columns_only_ret: ReturnType::Primitive(Primitive::Int),
})
}
}
6 changes: 3 additions & 3 deletions native/core/src/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,16 +601,16 @@ enum ParquetReaderState {
Complete,
}
/// Parquet read context maintained across multiple JNI calls.
struct BatchContext {
pub struct BatchContext {
native_plan: Arc<SparkPlan>,
metrics_node: Arc<GlobalRef>,
batch_stream: Option<SendableRecordBatchStream>,
current_batch: Option<RecordBatch>,
pub current_batch: Option<RecordBatch>,
reader_state: ParquetReaderState,
}

#[inline]
fn get_batch_context<'a>(handle: jlong) -> Result<&'a mut BatchContext, CometError> {
pub fn get_batch_context<'a>(handle: i64) -> Result<&'a mut BatchContext, CometError> {
unsafe {
(handle as *mut BatchContext)
.as_mut()
Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ message Scan {
string source = 2;
// Whether native code can assume ownership of batches that it receives
bool arrow_ffi_safe = 3;
// When true, data columns are read directly from the native reader's
// BatchContext instead of through JVM FFI. Only partition columns
// cross the JVM boundary.
bool native_batch_passthrough = 4;
// Number of data columns (from native reader). Remaining columns are partition cols.
int32 num_data_columns = 5;
}

message NativeScan {
Expand Down
42 changes: 42 additions & 0 deletions spark/src/main/java/org/apache/comet/CometBatchIterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

import org.apache.spark.sql.vectorized.ColumnarBatch;

import org.apache.comet.parquet.NativeBatchReader;
import org.apache.comet.vector.CometSelectionVector;
import org.apache.comet.vector.CometVector;
import org.apache.comet.vector.NativeUtil;

/**
Expand Down Expand Up @@ -111,6 +113,46 @@ public boolean hasSelectionVectors() {
return true;
}

/**
 * Advance to next batch in passthrough mode. Data columns stay in native BatchContext; only
 * partition columns are exported via FFI.
 *
 * <p>The handle is read from NativeBatchReader.CURRENT_READER_HANDLE on the calling thread. Native
 * code treats a return value of 0 as end of input, so a batch with no published handle is
 * reported as an error rather than returned as 0, which would silently drop the batch.
 *
 * @return native reader handle, or 0 for EOF
 * @throws IllegalStateException if a batch is available but the thread-local reader handle is 0
 */
public long advancePassthrough() {
  previousBatch = null;

  if (currentBatch == null && input.hasNext()) {
    currentBatch = input.next();
  }
  if (currentBatch == null) {
    return 0; // EOF
  }

  long handle = NativeBatchReader.CURRENT_READER_HANDLE.get();
  if (handle == 0) {
    // 0 is reserved for EOF; returning it here would make native code stop reading early.
    throw new IllegalStateException(
        "Passthrough batch is available but no native reader handle was set on this thread");
  }

  // Keep the batch reachable so nextPartitionColumnsOnly() can export its partition columns.
  previousBatch = currentBatch;
  currentBatch = null;
  return handle;
}

/**
 * Exports the partition columns of the batch retained by the previous advance, i.e. the columns
 * at indices {@code numDataCols} and above. The k-th partition column is written to
 * {@code arrayAddrs[k]} and {@code schemaAddrs[k]}.
 *
 * @param arrayAddrs The addresses of the ArrowArray structures for partition columns
 * @param schemaAddrs The addresses of the ArrowSchema structures for partition columns
 * @param numDataCols Number of data columns to skip
 * @return the number of rows, or -1 if no batch
 */
public int nextPartitionColumnsOnly(long[] arrayAddrs, long[] schemaAddrs, int numDataCols) {
  if (previousBatch == null) {
    return -1;
  }

  int slot = 0;
  int col = numDataCols;
  while (col < previousBatch.numCols()) {
    CometVector partitionVector = (CometVector) previousBatch.column(col);
    nativeUtil.exportSingleVector(partitionVector, arrayAddrs[slot], schemaAddrs[slot]);
    col++;
    slot++;
  }
  return previousBatch.numRows();
}

/**
* Export selection indices for all columns when they are selection vectors.
*
Expand Down
Loading
Loading