Skip to content

Commit fddb383

Browse files
committed
fix: refactor table provider registration logic for improved clarity
1 parent 8ca3b3a commit fddb383

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

src/catalog.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,20 +199,16 @@ impl PySchema {
199199
}
200200

201201
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
202-
let provider = if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
202+
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
203+
py_table.table
204+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
205+
py_provider.into_inner()
206+
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
203207
provider
204208
} else {
205-
match table_provider.extract::<PyTable>() {
206-
Ok(py_table) => py_table.table,
207-
Err(_) => match table_provider.extract::<PyTableProvider>() {
208-
Ok(py_provider) => py_provider.into_inner(),
209-
Err(_) => {
210-
let py = table_provider.py();
211-
let provider = Dataset::new(&table_provider, py)?;
212-
Arc::new(provider) as Arc<dyn TableProvider + Send>
213-
}
214-
},
215-
}
209+
let py = table_provider.py();
210+
let provider = Dataset::new(&table_provider, py)?;
211+
Arc::new(provider) as Arc<dyn TableProvider + Send>
216212
};
217213

218214
let _ = self

src/context.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,12 +610,12 @@ impl PySessionContext {
610610
name: &str,
611611
table_provider: Bound<'_, PyAny>,
612612
) -> PyDataFusionResult<()> {
613-
let provider = if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
614-
provider
615-
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
613+
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
616614
py_table.table()
617615
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
618616
py_provider.into_inner()
617+
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
618+
provider
619619
} else {
620620
return Err(crate::errors::PyDataFusionError::Common(
621621
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\".".to_string(),

0 commit comments

Comments
 (0)