Skip to content

Commit dcdca33

Browse files
committed
refactor: replace hardcoded error message with EXPECTED_PROVIDER_MSG constant
1 parent d3950c5 commit dcdca33

File tree

7 files changed

+33
-22
lines changed

7 files changed

+33
-22
lines changed

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from . import functions, object_store, substrait, unparser
3434

3535
# The following imports are okay to remain as opaque to the user.
36-
from ._internal import Config
36+
from ._internal import Config, EXPECTED_PROVIDER_MSG
3737
from .catalog import Catalog, Database, Table
3838
from .col import col, column
3939
from .common import (
@@ -77,6 +77,7 @@
7777
"DFSchema",
7878
"DataFrame",
7979
"Database",
80+
"EXPECTED_PROVIDER_MSG",
8081
"ExecutionPlan",
8182
"Expr",
8283
"LogicalPlan",

python/datafusion/table_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any
2323

2424
import datafusion._internal as df_internal
25+
from datafusion._internal import EXPECTED_PROVIDER_MSG
2526

2627
_InternalTableProvider = df_internal.TableProvider
2728

@@ -37,8 +38,7 @@ def __init__(self, table_provider: _InternalTableProvider) -> None:
3738
table_provider = table_provider._table_provider
3839

3940
if not isinstance(table_provider, _InternalTableProvider):
40-
msg = "Expected a datafusion._internal.TableProvider instance."
41-
raise TypeError(msg)
41+
raise TypeError(EXPECTED_PROVIDER_MSG)
4242

4343
self._table_provider = table_provider
4444

python/tests/test_context.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import pytest
2424
from datafusion import (
2525
DataFrame,
26+
EXPECTED_PROVIDER_MSG,
2627
RuntimeEnvBuilder,
2728
SessionConfig,
2829
SessionContext,
@@ -383,10 +384,7 @@ def test_register_table_with_dataframe_errors(ctx):
383384
with pytest.raises(Exception) as exc_info:
384385
ctx.register_table("bad", df)
385386

386-
assert str(exc_info.value) == (
387-
"Expected a Table or TableProvider. Convert DataFrames with "
388-
'"DataFrame.into_view()" or "TableProvider.from_dataframe()".'
389-
)
387+
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
390388

391389

392390
def test_register_dataset(ctx):

src/catalog.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::dataframe::PyDataFrame;
1819
use crate::dataset::Dataset;
1920
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2021
use crate::table::PyTableProvider;
2122
use crate::utils::{
22-
extract_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future,
23+
table_provider_from_pycapsule, validate_pycapsule, wait_for_future, EXPECTED_PROVIDER_MSG,
2324
};
2425
use async_trait::async_trait;
2526
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -198,7 +199,13 @@ impl PySchema {
198199
}
199200

200201
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
201-
let provider = if let Some(provider) = extract_table_provider(&table_provider)? {
202+
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
203+
py_table.table
204+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
205+
py_provider.into_inner()
206+
} else if table_provider.extract::<PyDataFrame>().is_ok() {
207+
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()).into());
208+
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
202209
provider
203210
} else {
204211
let py = table_provider.py();

src/context.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ use crate::record_batch::PyRecordBatchStream;
4141
use crate::sql::exceptions::py_value_err;
4242
use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
44+
use crate::table::PyTableProvider;
4445
use crate::udaf::PyAggregateUDF;
4546
use crate::udf::PyScalarUDF;
4647
use crate::udtf::PyTableFunction;
4748
use crate::udwf::PyWindowUDF;
4849
use crate::utils::{
49-
extract_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future,
50-
TABLE_OR_PROVIDER_EXPECTED_MESSAGE,
50+
get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule, validate_pycapsule,
51+
wait_for_future, EXPECTED_PROVIDER_MSG,
5152
};
5253
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5354
use datafusion::arrow::pyarrow::PyArrowType;
@@ -609,13 +610,17 @@ impl PySessionContext {
609610
name: &str,
610611
table_provider: Bound<'_, PyAny>,
611612
) -> PyDataFusionResult<()> {
612-
let provider = extract_table_provider(&table_provider)
613-
.map_err(crate::errors::PyDataFusionError::from)?
614-
.ok_or_else(|| {
615-
crate::errors::PyDataFusionError::Common(
616-
TABLE_OR_PROVIDER_EXPECTED_MESSAGE.to_string(),
617-
)
618-
})?;
613+
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
614+
py_table.table()
615+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
616+
py_provider.into_inner()
617+
} else if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
618+
provider
619+
} else {
620+
return Err(crate::errors::PyDataFusionError::Common(
621+
EXPECTED_PROVIDER_MSG.to_string(),
622+
));
623+
};
619624

620625
self.ctx.register_table(name, provider)?;
621626
Ok(())

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
8181
// Initialize logging
8282
pyo3_log::init();
8383

84+
m.add("EXPECTED_PROVIDER_MSG", crate::utils::EXPECTED_PROVIDER_MSG)?;
85+
8486
// Register the python classes
8587
m.add_class::<context::PyRuntimeEnvBuilder>()?;
8688
m.add_class::<context::PySessionConfig>()?;

src/utils.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use std::{
3737
};
3838
use tokio::{runtime::Runtime, time::sleep};
3939

40-
pub(crate) const TABLE_OR_PROVIDER_EXPECTED_MESSAGE: &str =
40+
pub(crate) const EXPECTED_PROVIDER_MSG: &str =
4141
"Expected a Table or TableProvider. Convert DataFrames with \"DataFrame.into_view()\" or \"TableProvider.from_dataframe()\".";
4242
/// Utility to get the Tokio Runtime from Python
4343
#[inline]
@@ -156,9 +156,7 @@ pub(crate) fn extract_table_provider(
156156
}
157157

158158
if py_obj.extract::<PyDataFrame>().is_ok() {
159-
return Err(
160-
PyDataFusionError::Common(TABLE_OR_PROVIDER_EXPECTED_MESSAGE.to_string()).into(),
161-
);
159+
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()).into());
162160
}
163161

164162
if let Some(provider) = table_provider_from_pycapsule(py_obj)? {

0 commit comments

Comments
 (0)