Skip to content

Commit 82bc41a

Browse files
committed
feat: add _normalize_table_provider utility for consistent table registration
1 parent 95abaf6 commit 82bc41a

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

python/datafusion/catalog.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import TYPE_CHECKING, Protocol
2424

2525
import datafusion._internal as df_internal
26+
from datafusion.utils import _normalize_table_provider
2627

2728
if TYPE_CHECKING:
2829
import pyarrow as pa
@@ -137,9 +138,8 @@ def register_table(
137138
Objects implementing ``__datafusion_table_provider__`` are also supported
138139
and treated as :class:`TableProvider` instances.
139140
"""
140-
if isinstance(table, Table):
141-
return self._raw_schema.register_table(name, table.table)
142-
return self._raw_schema.register_table(name, table)
141+
provider = _normalize_table_provider(table)
142+
return self._raw_schema.register_table(name, provider)
143143

144144
def deregister_table(self, name: str) -> None:
145145
"""Deregister a table provider from this schema."""

python/datafusion/context.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
3636
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
37+
from datafusion.utils import _normalize_table_provider
3738

3839
from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
3940
from ._internal import SessionConfig as SessionConfigInternal
@@ -766,10 +767,8 @@ def register_table(
766767
implementing ``__datafusion_table_provider__`` to add to the session
767768
context.
768769
"""
769-
if isinstance(table, Table):
770-
self.ctx.register_table(name, table.table)
771-
else:
772-
self.ctx.register_table(name, table)
770+
provider = _normalize_table_provider(table)
771+
self.ctx.register_table(name, provider)
773772

774773
def deregister_table(self, name: str) -> None:
775774
"""Remove a table from the session."""

python/datafusion/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Miscellaneous helper utilities for DataFusion's Python bindings."""
18+
19+
from __future__ import annotations
20+
21+
from typing import TYPE_CHECKING, Any
22+
23+
from datafusion._internal import EXPECTED_PROVIDER_MSG
24+
25+
if TYPE_CHECKING: # pragma: no cover - imported for typing only
26+
from datafusion import TableProvider
27+
from datafusion.catalog import Table
28+
from datafusion.context import TableProviderExportable
29+
30+
31+
def _normalize_table_provider(
32+
table: Table | TableProvider | TableProviderExportable,
33+
) -> Any:
34+
"""Return the underlying provider for supported table inputs.
35+
36+
Args:
37+
table: A :class:`~datafusion.catalog.Table`,
38+
:class:`~datafusion.table_provider.TableProvider`, or object exporting a
39+
DataFusion table provider via ``__datafusion_table_provider__``.
40+
41+
Returns:
42+
The object expected by the Rust bindings for table registration.
43+
44+
Raises:
45+
TypeError: If ``table`` is not a supported table provider input.
46+
"""
47+
48+
from datafusion.catalog import Table as _Table
49+
from datafusion.table_provider import TableProvider as _TableProvider
50+
51+
if isinstance(table, _Table):
52+
return table.table
53+
54+
if isinstance(table, _TableProvider):
55+
return table._table_provider
56+
57+
provider_factory = getattr(table, "__datafusion_table_provider__", None)
58+
if callable(provider_factory):
59+
return table
60+
61+
raise TypeError(EXPECTED_PROVIDER_MSG)

0 commit comments

Comments
 (0)