Skip to content

Commit f02114a

Browse files
committed
feat(utils): add spawn_stream utility for async execution with Python signal handling
1 parent e3e3b20 commit f02114a

3 files changed

Lines changed: 21 additions & 19 deletions

File tree

src/context.rs

Lines changed: 2 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -45,7 +45,7 @@ use crate::udaf::PyAggregateUDF;
4545
use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
48-
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
48+
use crate::utils::{get_global_ctx, spawn_stream, validate_pycapsule, wait_for_future};
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
@@ -74,7 +74,6 @@ use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvid
7474
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7575
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7676
use pyo3::IntoPyObjectExt;
77-
use tokio::task::JoinHandle;
7877

7978
/// Configuration options for a SessionContext
8079
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
@@ -1132,12 +1131,8 @@ impl PySessionContext {
11321131
py: Python,
11331132
) -> PyDataFusionResult<PyRecordBatchStream> {
11341133
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1135-
// create a Tokio runtime to run the async code
1136-
let rt = &get_tokio_runtime().0;
11371134
let plan = plan.plan.clone();
1138-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1139-
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1140-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
1135+
let stream = spawn_stream(py, async move { plan.execute(part, Arc::new(ctx)) })?;
11411136
Ok(PyRecordBatchStream::new(stream))
11421137
}
11431138
}

src/dataframe.rs

Lines changed: 4 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -51,7 +51,8 @@ use crate::physical_plan::PyExecutionPlan;
5151
use crate::record_batch::PyRecordBatchStream;
5252
use crate::sql::logical::PyLogicalPlan;
5353
use crate::utils::{
54-
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
54+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, validate_pycapsule,
55+
wait_for_future,
5556
};
5657
use crate::{
5758
errors::PyDataFusionResult,
@@ -922,9 +923,7 @@ impl PyDataFrame {
922923
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
923924
let rt = &get_tokio_runtime().0;
924925
let df = self.df.as_ref().clone();
925-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
926-
rt.spawn(async move { df.execute_stream().await });
927-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
926+
let stream = spawn_stream(py, async move { df.execute_stream().await })?;
928927

929928
let mut schema: Schema = self.df.schema().to_owned().into();
930929
let mut projection: Option<SchemaRef> = None;
@@ -955,12 +954,8 @@ impl PyDataFrame {
955954
}
956955

957956
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
958-
// create a Tokio runtime to run the async code
959-
let rt = &get_tokio_runtime().0;
960957
let df = self.df.as_ref().clone();
961-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
962-
rt.spawn(async move { df.execute_stream().await });
963-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
958+
let stream = spawn_stream(py, async move { df.execute_stream().await })?;
964959
Ok(PyRecordBatchStream::new(stream))
965960
}
966961

src/utils.rs

Lines changed: 15 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -17,16 +17,17 @@
1717

1818
use crate::{
1919
common::data_type::PyScalarValue,
20-
errors::{PyDataFusionError, PyDataFusionResult},
20+
errors::{to_datafusion_err, PyDataFusionError, PyDataFusionResult},
2121
TokioRuntime,
2222
};
2323
use datafusion::{
24-
common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
24+
common::ScalarValue, execution::context::SessionContext, execution::SendableRecordBatchStream,
25+
logical_expr::Volatility,
2526
};
2627
use pyo3::prelude::*;
2728
use pyo3::{exceptions::PyValueError, types::PyCapsule};
2829
use std::{future::Future, sync::OnceLock, time::Duration};
29-
use tokio::{runtime::Runtime, time::sleep};
30+
use tokio::{runtime::Runtime, task::JoinHandle, time::sleep};
3031
/// Utility to get the Tokio Runtime from Python
3132
#[inline]
3233
pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -84,6 +85,17 @@ where
8485
})
8586
}
8687

88+
/// Spawn a [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
89+
/// while respecting Python signal handling.
90+
pub(crate) fn spawn_stream<F>(py: Python, fut: F) -> PyDataFusionResult<SendableRecordBatchStream>
91+
where
92+
F: Future<Output = datafusion::common::Result<SendableRecordBatchStream>> + Send + 'static,
93+
{
94+
let rt = &get_tokio_runtime().0;
95+
let handle: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> = rt.spawn(fut);
96+
wait_for_future(py, async { handle.await.map_err(to_datafusion_err) })???
97+
}
98+
8799
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
88100
Ok(match value {
89101
"immutable" => Volatility::Immutable,

0 commit comments

Comments (0)