Skip to content

Commit ada5e2c

Browse files
committed
feat: refactor table provider handling and update tests for improved access methods
1 parent 6552d43 commit ada5e2c

File tree

4 files changed

+92
-12
lines changed

4 files changed

+92
-12
lines changed

src/catalog.rs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20-
use crate::table::pyany_to_table_provider;
20+
use crate::table::PyTableProvider;
2121
use crate::utils::{validate_pycapsule, wait_for_future};
2222
use async_trait::async_trait;
2323
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -28,6 +28,7 @@ use datafusion::{
2828
datasource::{TableProvider, TableType},
2929
};
3030
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
31+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
3132
use pyo3::exceptions::PyKeyError;
3233
use pyo3::prelude::*;
3334
use pyo3::types::PyCapsule;
@@ -196,7 +197,29 @@ impl PySchema {
196197
}
197198

198199
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
199-
let provider = pyany_to_table_provider(&table_provider)?;
200+
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
201+
let capsule = table_provider
202+
.getattr("__datafusion_table_provider__")?
203+
.call0()?;
204+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
205+
validate_pycapsule(capsule, "datafusion_table_provider")?;
206+
207+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
208+
let provider: ForeignTableProvider = provider.into();
209+
Arc::new(provider) as Arc<dyn TableProvider + Send>
210+
} else {
211+
match table_provider.extract::<PyTable>() {
212+
Ok(py_table) => py_table.table,
213+
Err(_) => match table_provider.extract::<PyTableProvider>() {
214+
Ok(py_provider) => py_provider.into_inner(),
215+
Err(_) => {
216+
let py = table_provider.py();
217+
let provider = Dataset::new(&table_provider, py)?;
218+
Arc::new(provider) as Arc<dyn TableProvider + Send>
219+
}
220+
},
221+
}
222+
};
200223

201224
let _ = self
202225
.schema
@@ -285,7 +308,34 @@ impl RustWrappedPySchemaProvider {
285308
return Ok(None);
286309
}
287310

288-
pyany_to_table_provider(&py_table).map(Some)
311+
if py_table.hasattr("__datafusion_table_provider__")? {
312+
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
313+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
314+
validate_pycapsule(capsule, "datafusion_table_provider")?;
315+
316+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
317+
let provider: ForeignTableProvider = provider.into();
318+
319+
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider + Send>))
320+
} else {
321+
if let Ok(inner_table) = py_table.getattr("table") {
322+
if let Ok(inner_table) = inner_table.extract::<PyTable>() {
323+
return Ok(Some(inner_table.table));
324+
}
325+
}
326+
327+
if let Ok(py_provider) = py_table.extract::<PyTableProvider>() {
328+
return Ok(Some(py_provider.into_inner()));
329+
}
330+
331+
match py_table.extract::<PyTable>() {
332+
Ok(py_table) => Ok(Some(py_table.table)),
333+
Err(_) => {
334+
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
335+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider + Send>))
336+
}
337+
}
338+
}
289339
})
290340
}
291341
}

src/context.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
37+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3838
use crate::expr::sort_expr::PySortExpr;
3939
use crate::physical_plan::PyExecutionPlan;
4040
use crate::record_batch::PyRecordBatchStream;
4141
use crate::sql::exceptions::py_value_err;
4242
use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
44-
use crate::table::pyany_to_table_provider;
44+
use crate::table::PyTableProvider;
4545
use crate::udaf::PyAggregateUDF;
4646
use crate::udf::PyScalarUDF;
4747
use crate::udtf::PyTableFunction;
@@ -72,6 +72,7 @@ use datafusion::prelude::{
7272
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7373
};
7474
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
75+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7576
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7677
use pyo3::IntoPyObjectExt;
7778
use tokio::task::JoinHandle;
@@ -607,9 +608,26 @@ impl PySessionContext {
607608
name: &str,
608609
table_provider: Bound<'_, PyAny>,
609610
) -> PyDataFusionResult<()> {
610-
let provider = pyany_to_table_provider(&table_provider).map_err(|_| {
611-
PyDataFusionError::Common("Expected a Table or TableProvider.".to_string())
612-
})?;
611+
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
612+
let capsule = table_provider
613+
.getattr("__datafusion_table_provider__")?
614+
.call0()?;
615+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
616+
validate_pycapsule(capsule, "datafusion_table_provider")?;
617+
618+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
619+
let provider: ForeignTableProvider = provider.into();
620+
Arc::new(provider) as Arc<dyn TableProvider + Send>
621+
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
622+
py_table.table()
623+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
624+
py_provider.into_inner()
625+
} else {
626+
return Err(crate::errors::PyDataFusionError::Common(
627+
"Expected a Table or TableProvider.".to_string(),
628+
));
629+
};
630+
613631
self.ctx.register_table(name, provider)?;
614632
Ok(())
615633
}

src/table.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,24 @@ impl PyTableProvider {
4242
Self { provider }
4343
}
4444

45-
/// Return a `PyTable` wrapper around this provider so callers can call
46-
/// `as_table().table()` to get the underlying `Arc<dyn TableProvider + Send>`.
45+
/// Return a `PyTable` wrapper around this provider.
46+
///
47+
/// Historically callers chained `as_table().table()` to access the
48+
/// underlying `Arc<dyn TableProvider + Send>`. Prefer [`as_arc`] or
49+
/// [`into_inner`] for direct access instead.
4750
pub fn as_table(&self) -> PyTable {
4851
PyTable::new(Arc::clone(&self.provider))
4952
}
53+
54+
/// Return a clone of the inner [`TableProvider`].
55+
pub fn as_arc(&self) -> Arc<dyn TableProvider + Send> {
56+
Arc::clone(&self.provider)
57+
}
58+
59+
/// Consume this wrapper and return the inner [`TableProvider`].
60+
pub fn into_inner(self) -> Arc<dyn TableProvider + Send> {
61+
self.provider
62+
}
5063
}
5164

5265
#[pymethods]

tests/dataframe_into_view.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ fn dataframe_into_view_returns_table_provider() {
2626

2727
// Register the view in a new context and ensure it can be queried.
2828
let ctx = SessionContext::new();
29-
ctx.register_table("view", provider.as_table().table())
30-
.unwrap();
29+
ctx.register_table("view", provider.into_inner()).unwrap();
3130

3231
let rt = tokio::runtime::Runtime::new().unwrap();
3332
let batches = rt.block_on(async {

0 commit comments

Comments (0)