Skip to content

Commit 0ca3069

Browse files
authored
Switch executeQueryPhaseAsync and streamNext from block_on to spawn (#20892)
* Switch executeQueryPhaseAsync and streamNext from block_on to spawn. This makes the JVM thread return immediately after submitting work. To prevent thread leaks in tests, runtime_manager.rs now splits io_runtime into a cheap Handle (for sharing) and a Mutex<Option<Runtime>> (for owned shutdown). shutdown() now calls io_runtime.shutdown_timeout(10s) after cpu_executor.join_blocking(), blocking until all datafusion-io-* threads have fully terminated. Signed-off-by: Aravind Sagar <sagarara@amazon.com> * Make some variable names clearer Signed-off-by: Aravind Sagar <sagarara@amazon.com> --------- Signed-off-by: Aravind Sagar <sagarara@amazon.com>
1 parent e4a1193 commit 0ca3069

3 files changed

Lines changed: 126 additions & 80 deletions

File tree

plugins/engine-datafusion/jni/src/lib.rs

Lines changed: 93 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ use std::num::NonZeroUsize;
1212
* compatible open source license.
1313
*/
1414
use std::ptr::addr_of_mut;
15-
use jni::objects::{JByteArray, JClass, JMap, JObject};
15+
use jni::objects::{GlobalRef, JByteArray, JClass, JMap, JObject};
1616
use jni::objects::JLongArray;
1717
use jni::sys::{jboolean, jbyteArray, jint, jlong, jstring};
1818
use jni::{JNIEnv, JavaVM};
19+
use std::future::Future;
1920
use std::sync::{Arc, OnceLock};
2021
use arrow_array::{Array, RecordBatch, StructArray};
2122
use arrow_array::ffi::FFI_ArrowArray;
@@ -67,7 +68,7 @@ use tokio::runtime::Runtime;
6768
use std::result;
6869
use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
6970
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
70-
use futures::TryStreamExt;
71+
use futures::{TryStreamExt, FutureExt};
7172

7273
pub type Result<T, E = DataFusionError> = result::Result<T, E>;
7374

@@ -128,6 +129,59 @@ where
128129
})
129130
}
130131

132+
/// Extract a human-readable message from a panic payload.
///
/// The payload is probed for the two types handled here, in order:
/// a `&'static str` and then an owned `String`. Any other payload type
/// yields the fixed text `"unknown panic payload"`.
///
/// # Parameters
/// - `payload`: the boxed panic payload, e.g. the `Err` value produced by
///   `catch_unwind`.
///
/// # Returns
/// An owned copy of the panic message, or the fallback text when the payload
/// is neither a `&str` nor a `String`.
fn panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        // `s` is `&&str`; dereference once so `to_string` copies the text itself.
        .map(|s| (*s).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
142+
143+
/// Spawn an async task on `runtime` that calls an ActionListener exactly once.
144+
///
145+
/// The entire `task` future runs inside `catch_unwind`. Any panic is converted
146+
/// to a `DataFusionError` and surfaced to the Java caller via `listener_ref`.
147+
/// This ensures the `CompletableFuture` on the Java side is always completed,
148+
/// never left hanging.
149+
///
150+
/// `on_ok` receives the success value and is responsible for calling the
151+
/// appropriate `set_action_listener_ok_*` variant. `T` is inferred from
152+
/// the closure, which in turn pins the `Output` type of `task`.
153+
fn spawn_jni_task<Fut, T, FOk>(
154+
runtime: &tokio::runtime::Handle,
155+
task_name: &'static str,
156+
listener_ref: GlobalRef,
157+
task: Fut,
158+
on_ok: FOk,
159+
)
160+
where
161+
Fut: Future<Output = Result<T, DataFusionError>> + Send + 'static,
162+
T: Send + 'static,
163+
FOk: FnOnce(&mut JNIEnv, &GlobalRef, T) + Send + 'static,
164+
{
165+
let _ = runtime.spawn(async move {
166+
let result = std::panic::AssertUnwindSafe(task)
167+
.catch_unwind()
168+
.await
169+
.unwrap_or_else(|panic| {
170+
let msg = panic_message(&panic);
171+
log_error!("{} panicked: {}", task_name, msg);
172+
Err(DataFusionError::Execution(format!("{} panicked: {}", task_name, msg)))
173+
});
174+
175+
with_jni_env(|env| match result {
176+
Ok(value) => on_ok(env, &listener_ref, value),
177+
Err(e) => {
178+
log_error!("{} failed: {}", task_name, e);
179+
set_action_listener_error_global(env, &listener_ref, &e);
180+
}
181+
});
182+
});
183+
}
184+
131185
/// Initialize the logger for Rust->Java logging bridge.
132186
/// This should be called once when the native library is loaded.
133187
#[no_mangle]
@@ -619,9 +673,11 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
619673
let table_path = shard_view.table_path();
620674
let files_meta = shard_view.files_metadata();
621675

622-
io_runtime.block_on(async move {
623-
624-
let result = query_executor::execute_query_with_cross_rt_stream(
676+
spawn_jni_task(
677+
&io_runtime,
678+
"executeQueryPhaseAsync",
679+
listener_ref,
680+
query_executor::execute_query_with_cross_rt_stream(
625681
table_path,
626682
files_meta,
627683
table_name,
@@ -630,22 +686,9 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_executeQu
630686
target_partitions,
631687
runtime,
632688
cpu_executor,
633-
).await;
634-
635-
match result {
636-
Ok(stream_ptr) => {
637-
with_jni_env(|env| {
638-
set_action_listener_ok_global(env, &listener_ref, stream_ptr);
639-
});
640-
}
641-
Err(e) => {
642-
with_jni_env(|env| {
643-
log_error!("Query execution failed: {}", e);
644-
set_action_listener_error_global(env, &listener_ref, &e);
645-
});
646-
}
647-
}
648-
});
689+
),
690+
|env, listener_ref, stream_pointer| set_action_listener_ok_global(env, listener_ref, stream_pointer),
691+
);
649692
}
650693

651694
#[no_mangle]
@@ -680,22 +723,13 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_fetchSegm
680723
let shard_view = unsafe { &*(shard_view_ptr as *const ShardView) };
681724
let files_meta = shard_view.files_metadata();
682725

683-
io_runtime.block_on(async move {
684-
let file_stats = util::fetch_segment_statistics(files_meta).await;
685-
match file_stats {
686-
Ok(map) => {
687-
with_jni_env(|env| {
688-
set_action_listener_ok_global_with_map(env, &listener_ref, &map);
689-
});
690-
}
691-
Err(e) => {
692-
with_jni_env(|env| {
693-
log_error!("Collecting file stats failed: {}", e);
694-
set_action_listener_error_global(env, &listener_ref, &e);
695-
});
696-
}
697-
}
698-
});
726+
spawn_jni_task(
727+
&io_runtime,
728+
"fetchSegmentStats",
729+
listener_ref,
730+
async move { util::fetch_segment_statistics(files_meta).await },
731+
|env, listener_ref, stats_map| set_action_listener_ok_global_with_map(env, listener_ref, &stats_map),
732+
);
699733
}
700734

701735
#[no_mangle]
@@ -732,41 +766,40 @@ pub extern "system" fn Java_org_opensearch_datafusion_jni_NativeBridge_streamNex
732766
let stream_ptr = stream;
733767
let io_runtime = manager.io_runtime.clone();
734768

735-
io_runtime.block_on(async move {
736-
737-
let stream = unsafe { &mut *(stream_ptr as *mut RecordBatchStreamAdapter<CrossRtStream>) };
738-
let result = stream.try_next().await;
739-
740-
// Uncomment for monitoring stream next
741-
// let result = STREAM_NEXT_MONITOR.instrument(async {
742-
// stream.try_next().await
743-
// }).await;
744-
745-
// Use thread-local JNI env - auto-attaches!
746-
with_jni_env(|env| {
769+
// Ensure stream_ptr lifetime is guaranteed beyond the spawn boundary
770+
// (e.g., wrap in Arc<Mutex<...>> or ensure sequential access contract)
771+
spawn_jni_task(
772+
&io_runtime,
773+
"streamNext",
774+
listener_ref,
775+
async move {
776+
let stream = unsafe { &mut *(stream_ptr as *mut RecordBatchStreamAdapter<CrossRtStream>) };
777+
// Poll the stream for the next batch (the monitored variant is commented out below)
778+
let result = stream.try_next().await?;
779+
780+
// Uncomment for monitoring stream next
781+
// let result = STREAM_NEXT_MONITOR.instrument(async {
782+
// stream.try_next().await
783+
// }).await;
747784
match result {
748-
Ok(Some(batch)) => {
785+
Some(batch) => {
749786
log_info!("[RUST streamNext] Batch produced: {} rows, {} columns, schema: {:?}",
750787
batch.num_rows(), batch.num_columns(), batch.schema().fields().iter().map(|f| f.name().as_str()).collect::<Vec<_>>());
751788
// Convert to FFI
752789
let struct_array: StructArray = batch.into();
753790
let array_data = struct_array.into_data();
754791
let ffi_array = FFI_ArrowArray::new(&array_data);
755-
let ffi_array_ptr = Box::into_raw(Box::new(ffi_array));
756-
set_action_listener_ok_global(env, &listener_ref, ffi_array_ptr as jlong);
792+
Ok(Box::into_raw(Box::new(ffi_array)) as jlong)
757793
}
758-
Ok(None) => {
794+
None => {
759795
log_info!("[RUST streamNext] End of stream reached");
760796
// End of stream
761-
set_action_listener_ok_global(env, &listener_ref, 0);
762-
}
763-
Err(err) => {
764-
log_error!("Stream next failed: {}", err);
765-
set_action_listener_error_global(env, &listener_ref, &err);
797+
Ok(0)
766798
}
767799
}
768-
});
769-
});
800+
},
801+
|env, listener_ref, data_pointer| set_action_listener_ok_global(env, listener_ref, data_pointer),
802+
);
770803
// Function returns immediately to java - async rust work continues in background
771804
}
772805

plugins/engine-datafusion/jni/src/runtime_manager.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use crate::executor::DedicatedExecutor;
22
use crate::io::register_io_runtime;
33
use vectorized_exec_spi::log_info;
44
use log::info;
5-
use std::sync::Arc;
5+
use std::sync::{Arc, Mutex};
6+
use std::time::Duration;
67
use datafusion::error::DataFusionError;
7-
use tokio::runtime::{Builder, Runtime};
8+
use tokio::runtime::{Builder, Handle, Runtime};
89

910
#[derive(Debug, Clone)]
1011
pub struct RuntimeConfig {
@@ -71,7 +72,12 @@ impl RuntimeConfig {
7172
}
7273

7374
pub struct RuntimeManager {
74-
pub io_runtime: Arc<Runtime>,
75+
/// Cloneable handle to the IO runtime, used by all JNI entry points to submit
76+
/// tasks without taking ownership of the runtime.
77+
pub io_runtime: Handle,
78+
/// Owned runtime kept behind a Mutex so shutdown() can take() it and call
79+
/// shutdown_timeout(), which blocks until all IO threads have fully stopped.
80+
io_runtime_owned: Mutex<Option<Runtime>>,
7581
pub(crate) cpu_executor: DedicatedExecutor,
7682
}
7783

@@ -88,36 +94,37 @@ impl RuntimeManager {
8894
pub fn with_config(config: RuntimeConfig) -> Self {
8995
log_info!("Creating RuntimeManager with config: {:?}", config);
9096

91-
// IO Runtime
92-
let io_runtime = Arc::new(
93-
Builder::new_multi_thread()
94-
.worker_threads(config.effective_io_threads())
95-
.thread_name("datafusion-io")
96-
.enable_all()
97-
.build()
98-
.expect("Failed to create IO runtime"),
99-
);
97+
// IO Runtime — build first so we can extract a Handle for sharing
98+
let io_runtime_rt = Builder::new_multi_thread()
99+
.worker_threads(config.effective_io_threads())
100+
.thread_name("datafusion-io")
101+
.enable_all()
102+
.build()
103+
.expect("Failed to create IO runtime");
104+
105+
let io_handle = io_runtime_rt.handle().clone();
100106

101-
// Register IO runtime for current thread
102-
register_io_runtime(Some(io_runtime.handle().clone()));
107+
// Register IO runtime for the calling (JNI) thread.
108+
register_io_runtime(Some(io_handle.clone()));
103109

104110
// CPU Executor with its own runtime
105111
let mut cpu_runtime_builder = Builder::new_multi_thread();
106-
let io_handle = io_runtime.handle().clone();
112+
let io_handle_for_cpu = io_handle.clone();
107113

108114
cpu_runtime_builder
109115
.worker_threads(config.effective_cpu_threads())
110116
.thread_name("datafusion-cpu")
111117
.enable_all()
112118
.on_thread_start(move || {
113119
// Register IO runtime for each CPU thread
114-
register_io_runtime(Some(io_handle.clone()));
120+
register_io_runtime(Some(io_handle_for_cpu.clone()));
115121
});
116122

117123
let cpu_executor = DedicatedExecutor::new("datafusion-cpu", cpu_runtime_builder);
118124

119125
Self {
120-
io_runtime,
126+
io_runtime: io_handle,
127+
io_runtime_owned: Mutex::new(Some(io_runtime_rt)),
121128
cpu_executor,
122129
}
123130
}
@@ -149,9 +156,14 @@ impl RuntimeManager {
149156

150157
pub fn shutdown(&self) {
151158
info!("Shutting down RuntimeManager");
159+
// Shut down CPU executor first — waits for all in-flight CPU tasks to finish.
152160
self.cpu_executor.join_blocking();
153-
// TODO: io_runtime spawned threads seem to have issue and are leaking
154-
161+
// Take the owned IO runtime out of the Option and call shutdown_timeout, which
162+
// waits at most 10 seconds for the datafusion-io-* threads to stop.
163+
if let Some(rt) = self.io_runtime_owned.lock().unwrap().take() {
164+
rt.shutdown_timeout(Duration::from_secs(10));
165+
}
166+
info!("RuntimeManager shut down complete");
155167
}
156168
}
157169

plugins/engine-datafusion/src/test/java/org/opensearch/datafusion/DataFusionReaderManagerTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import static org.opensearch.index.engine.Engine.SearcherScope.INTERNAL;
5353

5454
public class DataFusionReaderManagerTests extends OpenSearchTestCase {
55+
5556
private static DataFusionService service;
5657
Supplier<IndexFileDeleter> noOpFileDeleterSupplier;
5758

@@ -468,7 +469,7 @@ private ShardPath createShardPathWithResourceFiles(String indexName, int shardId
468469

469470
private void verifySearchResults(DatafusionSearcher searcher, DatafusionQuery datafusionQuery, Map<String, Long> expectedResults) throws Exception {
470471
Map<String, Object[]> finalRes = new HashMap<>();
471-
searcher.searchAsync(datafusionQuery, service.getRuntimePointer()).whenComplete((streamPointer, error)-> {
472+
searcher.searchAsync(datafusionQuery, service.getRuntimePointer()).whenCompleteAsync((streamPointer, error)-> {
472473
RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
473474
RecordBatchStream stream = new RecordBatchStream(streamPointer, service.getRuntimePointer(), allocator);
474475

0 commit comments

Comments
 (0)