feat: register python data sources #6936
Conversation
- Skip throwaway construction when constructor requires args
- Wrap any init failure into a clear ValueError
- Hoist _RegisteredDataSource to a lazy module-level class
- Document auto-instantiation in register_data_source docstring

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Greptile Summary
This PR adds a session-scoped Python data source registry.
Confidence Score: 4/5
Safe to merge; the new code is well-tested, additive, and isolated to the registry module and Session class.
The core registration and read path works correctly and is well-tested. The two findings are both style-level: inline imports in data_sources.py violate the team's import-placement rule, and naming a module function `list` shadows the builtin. Neither affects correctness.
daft/data_sources.py — inline imports and the `list` naming collision are worth cleaning up before the API stabilises.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant ds as daft.data_sources
participant sess as Session
participant reg as _data_sources dict
participant src as DataSource class
participant wrap as _RegisteredDataSource
User->>ds: "register(MySource, name=my_source)"
ds->>sess: current_session().register_data_source(MySource)
sess->>sess: _data_source_registration_name(MySource)
sess->>reg: "_data_sources[my_source] = MySource"
User->>ds: "read_source(my_source, **options)"
ds->>sess: "_session().read_source(my_source, **options)"
sess->>reg: get_data_source(my_source) returns MySource class
sess->>src: "MySource(**options) returns instance"
sess->>wrap: _RegisteredDataSource(my_source, instance)
wrap->>wrap: .read() via ScanOperatorHandle
wrap-->>User: DataFrame
Reviews (1): Last reviewed commit: "refactor(session): polish data source re..." | Re-trigger Greptile |
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from daft.session import current_session | ||
|
|
||
| if TYPE_CHECKING: | ||
| from daft.dataframe import DataFrame | ||
| from daft.io.source import DataSource | ||
|
|
||
|
|
||
| def register(data_source: type[DataSource], *, name: str | None = None, replace: bool = False) -> None: | ||
| """Register a Python ``DataSource`` class on the current session.""" | ||
| current_session().register_data_source(data_source, name=name, replace=replace) | ||
|
|
||
|
|
||
| def unregister(name: str) -> None: | ||
| """Remove a registered Python ``DataSource`` from the current session.""" | ||
| current_session().unregister_data_source(name) | ||
|
|
||
|
|
||
| def get(name: str) -> type[DataSource]: | ||
| """Return a registered Python ``DataSource`` class.""" | ||
| return current_session().get_data_source(name) | ||
|
|
||
|
|
||
| def list_sources() -> list[str]: | ||
| """List registered Python ``DataSource`` names.""" | ||
| return current_session().list_data_sources() | ||
|
|
||
|
|
||
| def read(name: str, **options: Any) -> DataFrame: | ||
| """Read a registered Python ``DataSource`` by name.""" | ||
| return current_session().read_source(name, **options) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "get", | ||
| "list_sources", | ||
| "read", | ||
| "register", | ||
| "unregister", | ||
| ] |
There was a problem hiding this comment.
We can delete this file, only use the session methods.
| def register_data_source( | ||
| self, | ||
| data_source: type[DataSource], | ||
| *, | ||
| name: str | None = None, | ||
| replace: bool = False, | ||
| ) -> None: | ||
| """Register a Python ``DataSource`` class with this session. | ||
|
|
||
| If ``name`` is omitted and the class declares ``name`` as an instance | ||
| ``@property``, the class is instantiated with no arguments to read it. | ||
| Pass ``name=`` explicitly to skip that construction (e.g. when the | ||
| constructor has side effects or requires arguments). | ||
| """ | ||
| from daft.io.source import DataSource | ||
|
|
||
| if not isinstance(data_source, type) or not issubclass(data_source, DataSource): | ||
| raise TypeError(f"Expected a DataSource class, got {data_source!r}") | ||
|
|
||
| source_name = _data_source_registration_name(data_source, name) | ||
| if not replace and source_name in self._data_sources: | ||
| raise ValueError(f"DataSource {source_name!r} is already registered") | ||
| self._data_sources[source_name] = data_source | ||
|
|
||
| def unregister_data_source(self, name: str) -> None: | ||
| """Remove a registered Python ``DataSource`` from this session.""" | ||
| try: | ||
| del self._data_sources[name] | ||
| except KeyError as e: | ||
| raise ValueError(f"DataSource {name!r} is not registered") from e | ||
|
|
||
| def get_data_source(self, name: str) -> type[DataSource]: | ||
| """Return a registered Python ``DataSource`` class.""" | ||
| try: | ||
| return self._data_sources[name] | ||
| except KeyError as e: | ||
| raise ValueError(f"DataSource {name!r} is not registered") from e | ||
|
|
||
| def list_data_sources(self) -> list[str]: | ||
| """List registered Python ``DataSource`` names.""" | ||
| return sorted(self._data_sources) |
There was a problem hiding this comment.
Please match the existing conventions of attach/detach
| def _registered_data_source_cls() -> type[DataSource]: | ||
| global _RegisteredDataSourceCls | ||
| if _RegisteredDataSourceCls is not None: | ||
| return _RegisteredDataSourceCls | ||
|
|
||
| from daft.io.source import DataSource | ||
|
|
||
| class _RegisteredDataSource(DataSource): | ||
| def __init__(self, registered_name: str, wrapped: DataSource) -> None: | ||
| self._registered_name = registered_name | ||
| self._wrapped = wrapped | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| return self._registered_name | ||
|
|
||
| @property | ||
| def schema(self) -> Schema: | ||
| return self._wrapped.schema | ||
|
|
||
| def get_partition_fields(self) -> list[PartitionField]: | ||
| return self._wrapped.get_partition_fields() | ||
|
|
||
| async def get_tasks(self, pushdowns: Pushdowns) -> AsyncIterator[DataSourceTask]: | ||
| async for task in self._wrapped.get_tasks(pushdowns): | ||
| yield task | ||
|
|
||
| _RegisteredDataSourceCls = _RegisteredDataSource | ||
| return _RegisteredDataSourceCls | ||
|
|
||
|
|
||
| def _read_registered_data_source(name: str, inner: DataSource) -> DataFrame: | ||
| registered_data_source_cls: Any = _registered_data_source_cls() | ||
| return registered_data_source_cls(name, inner).read() | ||
|
|
There was a problem hiding this comment.
All of this should be in the rust session, and we shouldn't have any registration code in python
Summary
`daft.data_sources.register(...)` and `daft.read_source(...)`
Validation
.venv/bin/ruff check daft/session.py daft/__init__.py daft/data_sources.py tests/io/test_data_source_registry.py
DAFT_RUNNER=native make test EXTRA_ARGS="-q tests/io/test_data_source_registry.py"