Skip to content

Commit c99f7f7

Browse files
committed
feat: implement extract_table_provider utility for streamlined table registration
1 parent 99ba9d4 commit c99f7f7

File tree

3 files changed

+44
-28
lines changed

3 files changed

+44
-28
lines changed

src/catalog.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::dataframe::PyDataFrame;
1918
use crate::dataset::Dataset;
2019
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2120
use crate::table::PyTableProvider;
2221
use crate::utils::{
23-
table_provider_from_pycapsule, validate_pycapsule, wait_for_future, EXPECTED_PROVIDER_MSG,
22+
extract_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future,
2423
};
2524
use async_trait::async_trait;
2625
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -199,18 +198,9 @@ impl PySchema {
199198
}
200199

201200
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
202-
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
203-
py_table.table
204-
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
205-
py_provider.into_inner()
206-
} else if table_provider.extract::<PyDataFrame>().is_ok() {
207-
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()).into());
208-
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
209-
provider
210-
} else {
211-
let py = table_provider.py();
212-
let provider = Dataset::new(&table_provider, py)?;
213-
Arc::new(provider) as Arc<dyn TableProvider>
201+
let provider = match extract_table_provider(&table_provider) {
202+
Ok(provider) => provider,
203+
Err(err) => return Err(err.into()),
214204
};
215205

216206
let _ = self

src/context.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ use crate::record_batch::PyRecordBatchStream;
4141
use crate::sql::exceptions::py_value_err;
4242
use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
44-
use crate::table::PyTableProvider;
4544
use crate::udaf::PyAggregateUDF;
4645
use crate::udf::PyScalarUDF;
4746
use crate::udtf::PyTableFunction;
4847
use crate::udwf::PyWindowUDF;
4948
use crate::utils::{
50-
get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule, validate_pycapsule,
51-
wait_for_future, EXPECTED_PROVIDER_MSG,
49+
extract_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future,
5250
};
5351
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5452
use datafusion::arrow::pyarrow::PyArrowType;
@@ -610,17 +608,7 @@ impl PySessionContext {
610608
name: &str,
611609
table_provider: Bound<'_, PyAny>,
612610
) -> PyDataFusionResult<()> {
613-
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
614-
py_table.table()
615-
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
616-
py_provider.into_inner()
617-
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
618-
provider
619-
} else {
620-
return Err(crate::errors::PyDataFusionError::Common(
621-
EXPECTED_PROVIDER_MSG.to_string(),
622-
));
623-
};
611+
let provider = extract_table_provider(&table_provider)?;
624612

625613
self.ctx.register_table(name, provider)?;
626614
Ok(())

src/utils.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
// under the License.
1717

1818
use crate::{
19+
catalog::PyTable,
1920
common::data_type::PyScalarValue,
21+
dataframe::PyDataFrame,
22+
dataset::Dataset,
2023
errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult},
24+
table::PyTableProvider,
2125
TokioRuntime,
2226
};
2327
use datafusion::{
@@ -141,6 +145,40 @@ pub(crate) fn table_provider_from_pycapsule(
141145
}
142146
}
143147

148+
pub(crate) fn extract_table_provider(
149+
table_like: &Bound<PyAny>,
150+
) -> PyDataFusionResult<Arc<dyn TableProvider>> {
151+
if let Ok(py_table) = table_like.extract::<PyTable>() {
152+
return Ok(py_table.table());
153+
}
154+
155+
if let Ok(py_provider) = table_like.extract::<PyTableProvider>() {
156+
return Ok(py_provider.into_inner());
157+
}
158+
159+
if table_like.extract::<PyDataFrame>().is_ok() {
160+
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()));
161+
}
162+
163+
match table_provider_from_pycapsule(table_like) {
164+
Ok(Some(provider)) => Ok(provider),
165+
Ok(None) => {
166+
let py = table_like.py();
167+
match Dataset::new(table_like, py) {
168+
Ok(dataset) => Ok(Arc::new(dataset) as Arc<dyn TableProvider>),
169+
Err(err) => {
170+
if err.is_instance_of::<PyValueError>(py) {
171+
Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()))
172+
} else {
173+
Err(err.into())
174+
}
175+
}
176+
}
177+
}
178+
Err(err) => Err(err.into()),
179+
}
180+
}
181+
144182
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
145183
// convert Python object to PyScalarValue to ScalarValue
146184

0 commit comments

Comments
 (0)