Skip to content

Commit 7423480

Browse files
committed
refactor: streamline table provider extraction logic and improve error handling
1 parent e9060ab commit 7423480

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

src/catalog.rs

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::dataframe::PyDataFrame;
1918
use crate::dataset::Dataset;
2019
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2120
use crate::table::PyTableProvider;
22-
use crate::utils::{table_provider_from_pycapsule, validate_pycapsule, wait_for_future};
21+
use crate::utils::{
22+
extract_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future,
23+
};
2324
use async_trait::async_trait;
2425
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2526
use datafusion::common::DataFusionError;
@@ -197,17 +198,7 @@ impl PySchema {
197198
}
198199

199200
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
200-
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
201-
py_table.table
202-
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
203-
py_provider.into_inner()
204-
} else if table_provider.extract::<PyDataFrame>().is_ok() {
205-
return Err(PyDataFusionError::Common(
206-
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\"."
207-
.to_string(),
208-
)
209-
.into());
210-
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
201+
let provider = if let Some(provider) = extract_table_provider(&table_provider)? {
211202
provider
212203
} else {
213204
let py = table_provider.py();

src/context.rs

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,13 @@ use crate::record_batch::PyRecordBatchStream;
4141
use crate::sql::exceptions::py_value_err;
4242
use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
44-
use crate::table::PyTableProvider;
4544
use crate::udaf::PyAggregateUDF;
4645
use crate::udf::PyScalarUDF;
4746
use crate::udtf::PyTableFunction;
4847
use crate::udwf::PyWindowUDF;
4948
use crate::utils::{
50-
get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule, validate_pycapsule,
51-
wait_for_future,
49+
extract_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future,
50+
TABLE_OR_PROVIDER_EXPECTED_MESSAGE,
5251
};
5352
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5453
use datafusion::arrow::pyarrow::PyArrowType;
@@ -610,17 +609,13 @@ impl PySessionContext {
610609
name: &str,
611610
table_provider: Bound<'_, PyAny>,
612611
) -> PyDataFusionResult<()> {
613-
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
614-
py_table.table()
615-
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
616-
py_provider.into_inner()
617-
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
618-
provider
619-
} else {
620-
return Err(crate::errors::PyDataFusionError::Common(
621-
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\".".to_string(),
622-
));
623-
};
612+
let provider = extract_table_provider(&table_provider)
613+
.map_err(crate::errors::PyDataFusionError::from)?
614+
.ok_or_else(|| {
615+
crate::errors::PyDataFusionError::Common(
616+
TABLE_OR_PROVIDER_EXPECTED_MESSAGE.to_string(),
617+
)
618+
})?;
624619

625620
self.ctx.register_table(name, provider)?;
626621
Ok(())

src/utils.rs

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,11 @@
1616
// under the License.
1717

1818
use crate::{
19+
catalog::PyTable,
1920
common::data_type::PyScalarValue,
21+
dataframe::PyDataFrame,
2022
errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult},
23+
table::PyTableProvider,
2124
TokioRuntime,
2225
};
2326
use datafusion::{
@@ -33,6 +36,9 @@ use std::{
3336
time::Duration,
3437
};
3538
use tokio::{runtime::Runtime, time::sleep};
39+
40+
pub(crate) const TABLE_OR_PROVIDER_EXPECTED_MESSAGE: &str =
41+
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\".";
3642
/// Utility to get the Tokio Runtime from Python
3743
#[inline]
3844
pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -138,6 +144,30 @@ pub(crate) fn table_provider_from_pycapsule(
138144
}
139145
}
140146

147+
pub(crate) fn extract_table_provider(
148+
py_obj: &Bound<PyAny>,
149+
) -> PyResult<Option<Arc<dyn TableProvider>>> {
150+
if let Ok(py_table) = py_obj.extract::<PyTable>() {
151+
return Ok(Some(py_table.table()));
152+
}
153+
154+
if let Ok(py_provider) = py_obj.extract::<PyTableProvider>() {
155+
return Ok(Some(py_provider.into_inner()));
156+
}
157+
158+
if py_obj.extract::<PyDataFrame>().is_ok() {
159+
return Err(
160+
PyDataFusionError::Common(TABLE_OR_PROVIDER_EXPECTED_MESSAGE.to_string()).into(),
161+
);
162+
}
163+
164+
if let Some(provider) = table_provider_from_pycapsule(py_obj)? {
165+
return Ok(Some(provider));
166+
}
167+
168+
Ok(None)
169+
}
170+
141171
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
142172
// convert Python object to PyScalarValue to ScalarValue
143173

0 commit comments

Comments
 (0)