Skip to content

Commit 87a0e30

Browse files
committed
Implement single point for scalar conversion from python objects
1 parent 7ff146e commit 87a0e30

File tree

6 files changed

+117
-85
lines changed

6 files changed

+117
-85
lines changed

src/config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
2222
use pyo3::prelude::*;
2323
use pyo3::types::*;
2424

25+
use crate::common::data_type::PyScalarValue;
2526
use crate::errors::PyDataFusionResult;
26-
use crate::utils::py_obj_to_scalar_value;
2727
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
2828
#[derive(Clone)]
2929
pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {
6565

6666
/// Set a configuration option
6767
pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
68-
let scalar_value = py_obj_to_scalar_value(py, value)?;
68+
let scalar_value: PyScalarValue = value.extract(py)?;
6969
let mut options = self.config.write();
70-
options.set(key, scalar_value.to_string().as_str())?;
70+
options.set(key, scalar_value.0.to_string().as_str())?;
7171
Ok(())
7272
}
7373

src/dataframe.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,15 @@ use pyo3::pybacked::PyBackedStr;
4848
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
4949
use pyo3::PyErr;
5050

51+
use crate::common::data_type::PyScalarValue;
5152
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
5253
use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
5354
use crate::expr::PyExpr;
5455
use crate::physical_plan::PyExecutionPlan;
5556
use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
5657
use crate::sql::logical::PyLogicalPlan;
5758
use crate::table::{PyTable, TempViewTable};
58-
use crate::utils::{
59-
is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
60-
};
59+
use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
6160

6261
/// File-level static CStr for the Arrow array stream capsule name.
6362
static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
11911190
columns: Option<Vec<PyBackedStr>>,
11921191
py: Python,
11931192
) -> PyDataFusionResult<Self> {
1194-
let scalar_value = py_obj_to_scalar_value(py, value)?;
1193+
let scalar_value: PyScalarValue = value.extract(py)?;
11951194

11961195
let cols = match columns {
11971196
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
11981197
None => Vec::new(), // Empty vector means fill null for all columns
11991198
};
12001199

1201-
let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
1200+
let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
12021201
Ok(Self::new(df))
12031202
}
12041203
}

src/pyarrow_util.rs

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,127 @@
1717

1818
//! Conversions between PyArrow and DataFusion types
1919
20-
use arrow::array::{Array, ArrayData};
20+
use std::sync::Arc;
21+
22+
use arrow::array::{make_array, Array, ArrayData, ListArray};
23+
use arrow::buffer::OffsetBuffer;
24+
use arrow::datatypes::Field;
2125
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
26+
use datafusion::common::exec_err;
2227
use datafusion::scalar::ScalarValue;
2328
use pyo3::types::{PyAnyMethods, PyList};
2429
use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
2530

2631
use crate::common::data_type::PyScalarValue;
2732
use crate::errors::PyDataFusionError;
2833

34+
fn pyobj_extract_scalar_via_capsule(
35+
value: &Bound<'_, PyAny>,
36+
as_list_array: bool,
37+
) -> PyResult<PyScalarValue> {
38+
let array_data = ArrayData::from_pyarrow_bound(value)?;
39+
let array = make_array(array_data);
40+
41+
if as_list_array {
42+
let field = Arc::new(Field::new_list_field(
43+
array.data_type().clone(),
44+
array.nulls().is_some(),
45+
));
46+
let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
47+
let list_array = ListArray::new(field, offsets, array, None);
48+
Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
49+
} else {
50+
let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
51+
Ok(PyScalarValue(scalar))
52+
}
53+
}
54+
2955
impl FromPyArrow for PyScalarValue {
3056
fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
3157
let py = value.py();
32-
let typ = value.getattr("type")?;
58+
let pyarrow_mod = py.import("pyarrow");
3359

34-
// construct pyarrow array from the python value and pyarrow type
35-
let factory = py.import("pyarrow")?.getattr("array")?;
36-
let args = PyList::new(py, [value])?;
37-
let array = factory.call1((args, typ))?;
60+
// Is it a PyArrow object?
61+
if let Ok(pa) = pyarrow_mod.as_ref() {
62+
let scalar_type = pa.getattr("Scalar")?;
63+
if value.is_instance(&scalar_type)? {
64+
let typ = value.getattr("type")?;
3865

39-
// convert the pyarrow array to rust array using C data interface
40-
let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
41-
let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
66+
// construct pyarrow array from the python value and pyarrow type
67+
let factory = py.import("pyarrow")?.getattr("array")?;
68+
let args = PyList::new(py, [value])?;
69+
let array = factory.call1((args, typ))?;
4270

43-
Ok(PyScalarValue(scalar))
71+
return pyobj_extract_scalar_via_capsule(&array, false);
72+
}
73+
74+
let array_type = pa.getattr("Array")?;
75+
if value.is_instance(&array_type)? {
76+
return pyobj_extract_scalar_via_capsule(value, true);
77+
}
78+
}
79+
80+
// Is it a nanoarrow scalar or array?
81+
if let Ok(na) = py.import("nanoarrow") {
82+
let type_name = value.get_type().repr()?;
83+
if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? {
84+
return pyobj_extract_scalar_via_capsule(value, false);
85+
}
86+
let array_type = na.getattr("Array")?;
87+
if value.is_instance(&array_type)? {
88+
return pyobj_extract_scalar_via_capsule(value, true);
89+
}
90+
}
91+
92+
// Is it an arro3 scalar or array?
93+
if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
94+
let scalar_type = arro3.getattr("Scalar")?;
95+
if value.is_instance(&scalar_type)? {
96+
return pyobj_extract_scalar_via_capsule(value, false);
97+
}
98+
let array_type = arro3.getattr("Array")?;
99+
if value.is_instance(&array_type)? {
100+
return pyobj_extract_scalar_via_capsule(value, true);
101+
}
102+
}
103+
104+
// Does it have a PyCapsule interface but isn't one of our known libraries?
105+
// If so, make our "best guess": check the type name first and, if that is inconclusive,
106+
// return a single value if the length is 1, or a List value otherwise.
107+
if value.hasattr("__arrow_c_array__")? {
108+
let type_name = value.get_type().repr()?;
109+
if type_name.contains("Scalar")? {
110+
return pyobj_extract_scalar_via_capsule(value, false);
111+
}
112+
if type_name.contains("Array")? {
113+
return pyobj_extract_scalar_via_capsule(value, true);
114+
}
115+
116+
let array_data = ArrayData::from_pyarrow_bound(value)?;
117+
let array = make_array(array_data);
118+
if array.len() == 1 {
119+
let scalar =
120+
ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
121+
return Ok(PyScalarValue(scalar));
122+
} else {
123+
let field = Arc::new(Field::new_list_field(
124+
array.data_type().clone(),
125+
array.nulls().is_some(),
126+
));
127+
let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
128+
let list_array = ListArray::new(field, offsets, array, None);
129+
return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
130+
}
131+
}
132+
133+
// Last attempt - try to create a PyArrow scalar from a plain Python object
134+
if let Ok(pa) = pyarrow_mod.as_ref() {
135+
let scalar = pa.call_method1("scalar", (value,))?;
136+
137+
PyScalarValue::from_pyarrow_bound(&scalar)
138+
} else {
139+
exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
140+
}
44141
}
45142
}
46143

src/udaf.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
3232
use crate::common::data_type::PyScalarValue;
3333
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3434
use crate::expr::PyExpr;
35-
use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
35+
use crate::utils::{parse_volatility, validate_pycapsule};
3636

3737
#[derive(Debug)]
3838
struct RustAccumulator {
@@ -52,10 +52,7 @@ impl Accumulator for RustAccumulator {
5252
let mut scalars = Vec::new();
5353
for item in values.try_iter()? {
5454
let item: Bound<'_, PyAny> = item?;
55-
let scalar = match item.extract::<PyScalarValue>() {
56-
Ok(py_scalar) => py_scalar.0,
57-
Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
58-
};
55+
let scalar = item.extract::<PyScalarValue>()?.0;
5956
scalars.push(scalar);
6057
}
6158
Ok(scalars)
@@ -66,10 +63,7 @@ impl Accumulator for RustAccumulator {
6663
fn evaluate(&mut self) -> Result<ScalarValue> {
6764
Python::attach(|py| -> PyResult<ScalarValue> {
6865
let value = self.accum.bind(py).call_method0("evaluate")?;
69-
match value.extract::<PyScalarValue>() {
70-
Ok(py_scalar) => Ok(py_scalar.0),
71-
Err(_) => py_obj_to_scalar_value(py, value.unbind()),
72-
}
66+
value.extract::<PyScalarValue>().map(|v| v.0)
7367
})
7468
.map_err(|e| DataFusionError::Execution(format!("{e}")))
7569
}

src/udwf.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
9494
}
9595

9696
fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
97-
println!("evaluate all called with number of values {}", values.len());
9897
Python::attach(|py| {
9998
let py_values = PyList::new(
10099
py,

src/utils.rs

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@ use std::future::Future;
1919
use std::sync::{Arc, OnceLock};
2020
use std::time::Duration;
2121

22-
use datafusion::arrow::array::{make_array, ArrayData, ListArray};
23-
use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
24-
use datafusion::arrow::datatypes::Field;
25-
use datafusion::arrow::pyarrow::FromPyArrow;
26-
use datafusion::common::ScalarValue;
2722
use datafusion::datasource::TableProvider;
2823
use datafusion::execution::context::SessionContext;
2924
use datafusion::logical_expr::Volatility;
@@ -37,7 +32,6 @@ use tokio::runtime::Runtime;
3732
use tokio::task::JoinHandle;
3833
use tokio::time::sleep;
3934

40-
use crate::common::data_type::PyScalarValue;
4135
use crate::context::PySessionContext;
4236
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
4337
use crate::TokioRuntime;
@@ -203,57 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
203197
}
204198
}
205199

206-
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
207-
// convert Python object to PyScalarValue to ScalarValue
208-
209-
let pa = py.import("pyarrow")?;
210-
let scalar_attr = pa.getattr("Scalar")?;
211-
let scalar_type = scalar_attr.downcast::<PyType>()?;
212-
let array_attr = pa.getattr("Array")?;
213-
let array_type = array_attr.downcast::<PyType>()?;
214-
let chunked_array_attr = pa.getattr("ChunkedArray")?;
215-
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
216-
217-
let obj_ref = obj.bind(py);
218-
219-
if obj_ref.is_instance(scalar_type)? {
220-
let py_scalar = PyScalarValue::extract_bound(obj_ref)
221-
.map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
222-
return Ok(py_scalar.into());
223-
}
224-
225-
if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
226-
let array_obj = if obj_ref.is_instance(chunked_array_type)? {
227-
obj_ref.call_method0("combine_chunks")?.unbind()
228-
} else {
229-
obj_ref.clone().unbind()
230-
};
231-
let array_bound = array_obj.bind(py);
232-
let array_data = ArrayData::from_pyarrow_bound(array_bound)
233-
.map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
234-
let array = make_array(array_data);
235-
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
236-
let list_array = Arc::new(ListArray::new(
237-
Arc::new(Field::new_list_field(array.data_type().clone(), true)),
238-
offsets,
239-
array,
240-
None,
241-
));
242-
243-
return Ok(ScalarValue::List(list_array));
244-
}
245-
246-
// Convert Python object to PyArrow scalar
247-
let scalar = pa.call_method1("scalar", (obj,))?;
248-
249-
// Convert PyArrow scalar to PyScalarValue
250-
let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
251-
.map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
252-
253-
// Convert PyScalarValue to ScalarValue
254-
Ok(py_scalar.into())
255-
}
256-
257200
pub(crate) fn extract_logical_extension_codec(
258201
py: Python,
259202
obj: Option<Bound<PyAny>>,

0 commit comments

Comments
 (0)