diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py
index f5881793f..958f7f816 100644
--- a/src/datajoint/settings.py
+++ b/src/datajoint/settings.py
@@ -434,10 +434,31 @@ def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool
         protocol = spec.get("protocol", "").lower()
         supported_protocols = ("file", "s3", "gcs", "azure")
         if protocol not in supported_protocols:
-            raise DataJointError(
-                f'Missing or invalid protocol in config.stores["{store}"]. '
-                f"Supported protocols: {', '.join(supported_protocols)}"
+            from .storage_adapter import get_storage_adapter
+
+            adapter = get_storage_adapter(protocol)
+            if adapter is None:
+                raise DataJointError(
+                    f'Unknown protocol "{protocol}" in config.stores["{store}"]. '
+                    f"Built-in: {', '.join(supported_protocols)}. "
+                    f"Install a plugin package for additional protocols."
+                )
+            # Apply common defaults for plugin protocols
+            spec.setdefault("subfolding", None)
+            spec.setdefault("partition_pattern", None)
+            spec.setdefault("token_length", 8)
+            spec.setdefault("hash_prefix", "_hash")
+            spec.setdefault("schema_prefix", "_schema")
+            spec.setdefault("filepath_prefix", None)
+            spec.setdefault("location", "")
+            adapter.validate_spec(spec)
+            self._validate_prefix_separation(
+                store_name=store,
+                hash_prefix=spec.get("hash_prefix"),
+                schema_prefix=spec.get("schema_prefix"),
+                filepath_prefix=spec.get("filepath_prefix"),
             )
+            return spec
 
         # Set protocol-specific defaults
         if protocol == "s3":
diff --git a/src/datajoint/storage.py b/src/datajoint/storage.py
index 86bcd6af0..0a99e4d60 100644
--- a/src/datajoint/storage.py
+++ b/src/datajoint/storage.py
@@ -368,7 +368,12 @@ def _create_filesystem(self) -> fsspec.AbstractFileSystem:
 
             )
         else:
-            raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
+            from .storage_adapter import get_storage_adapter
+
+            adapter = get_storage_adapter(self.protocol)
+            if adapter is None:
+                raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
+            return adapter.create_filesystem(self.spec)
 
     def _full_path(self, path: str | PurePosixPath) -> str:
         """
@@ -398,7 +403,12 @@ def _full_path(self, path: str | PurePosixPath) -> str:
                 return f"{bucket}/{location}/{path}"
             return f"{bucket}/{path}"
         else:
-            # Local filesystem - prepend location if specified
+            from .storage_adapter import get_storage_adapter
+
+            adapter = get_storage_adapter(self.protocol)
+            if adapter is not None:
+                return adapter.full_path(self.spec, path)
+            # File-protocol fallback
             location = self.spec.get("location", "")
             if location:
                 return str(Path(location) / path)
@@ -448,7 +458,11 @@ def get_url(self, path: str | PurePosixPath) -> str:
         elif self.protocol == "azure":
             return f"az://{full_path}"
         else:
-            # Fallback: use protocol prefix
+            from .storage_adapter import get_storage_adapter
+
+            adapter = get_storage_adapter(self.protocol)
+            if adapter is not None:
+                return adapter.get_url(self.spec, full_path)
             return f"{self.protocol}://{full_path}"
 
     def put_file(self, local_path: str | Path, remote_path: str | PurePosixPath, metadata: dict | None = None) -> None:
diff --git a/src/datajoint/storage_adapter.py b/src/datajoint/storage_adapter.py
new file mode 100644
index 000000000..b304586b2
--- /dev/null
+++ b/src/datajoint/storage_adapter.py
@@ -0,0 +1,109 @@
+"""Plugin system for third-party storage protocols.
+
+Third-party packages register adapters via entry points::
+
+    [project.entry-points."datajoint.storage"]
+    myprotocol = "my_package:MyStorageAdapter"
+
+The adapter is auto-discovered when DataJoint encounters the protocol name
+in a store configuration. No explicit import is needed.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any
+import logging
+
+import fsspec
+
+from . import errors
+
+logger = logging.getLogger(__name__)
+
+
+class StorageAdapter(ABC):
+    """Base class for storage protocol adapters.
+
+    Subclass this and declare an entry point to add a new storage protocol
+    to DataJoint. At minimum, implement ``create_filesystem`` and set
+    ``protocol``, ``required_keys``, and ``allowed_keys``.
+    """
+
+    protocol: str
+    required_keys: tuple[str, ...] = ()
+    allowed_keys: tuple[str, ...] = ()
+
+    @abstractmethod
+    def create_filesystem(self, spec: dict[str, Any]) -> fsspec.AbstractFileSystem:
+        """Return an fsspec filesystem instance for this protocol."""
+        ...
+
+    def validate_spec(self, spec: dict[str, Any]) -> None:
+        """Validate protocol-specific config fields."""
+        missing = [k for k in self.required_keys if k not in spec]
+        if missing:
+            raise errors.DataJointError(f'{self.protocol} store is missing: {", ".join(missing)}')
+        all_allowed = set(self.allowed_keys) | _COMMON_STORE_KEYS
+        invalid = [k for k in spec if k not in all_allowed]
+        if invalid:
+            raise errors.DataJointError(f'Invalid key(s) for {self.protocol}: {", ".join(invalid)}')
+
+    def full_path(self, spec: dict[str, Any], relpath: str) -> str:
+        """Construct storage path from a relative path."""
+        location = spec.get("location", "")
+        return f"{location}/{relpath}" if location else relpath
+
+    def get_url(self, spec: dict[str, Any], path: str) -> str:
+        """Return a display URL for the stored object."""
+        return f"{self.protocol}://{path}"
+
+
+_COMMON_STORE_KEYS = frozenset(
+    {
+        "protocol",
+        "location",
+        "subfolding",
+        "partition_pattern",
+        "token_length",
+        "hash_prefix",
+        "schema_prefix",
+        "filepath_prefix",
+        "stage",
+    }
+)
+
+_adapter_registry: dict[str, StorageAdapter] = {}
+_adapters_loaded: bool = False
+
+
+def get_storage_adapter(protocol: str) -> StorageAdapter | None:
+    """Look up a registered storage adapter by protocol name."""
+    global _adapters_loaded
+    if not _adapters_loaded:
+        _discover_adapters()
+        _adapters_loaded = True
+    return _adapter_registry.get(protocol)
+
+
+def _discover_adapters() -> None:
+    """Load storage adapters from datajoint.storage entry points."""
+    try:
+        from importlib.metadata import entry_points
+    except ImportError:
+        logger.debug("importlib.metadata not available, skipping adapter discovery")
+        return
+
+    try:
+        eps = entry_points(group="datajoint.storage")
+    except TypeError:
+        eps = entry_points().get("datajoint.storage", [])
+
+    for ep in eps:
+        if ep.name in _adapter_registry:
+            continue
+        try:
+            adapter_cls = ep.load()
+            adapter = adapter_cls()
+            _adapter_registry[adapter.protocol] = adapter
+            logger.debug(f"Loaded storage adapter: {adapter.protocol}")
+        except Exception as e:
+            logger.warning(f"Failed to load storage adapter '{ep.name}': {e}")
diff --git a/tests/unit/test_storage_adapter.py b/tests/unit/test_storage_adapter.py
new file mode 100644
index 000000000..9808289bf
--- /dev/null
+++ b/tests/unit/test_storage_adapter.py
@@ -0,0 +1,215 @@
+"""Tests for the StorageAdapter plugin system."""
+
+import pytest
+
+import datajoint as dj
+from datajoint.errors import DataJointError
+from datajoint.storage import StorageBackend
+from datajoint.storage_adapter import (
+    StorageAdapter,
+    _adapter_registry,
+    _COMMON_STORE_KEYS,
+    get_storage_adapter,
+)
+
+
+class _DummyAdapter(StorageAdapter):
+    """Test adapter for registry tests."""
+
+    protocol = "dummy"
+    required_keys = ("protocol", "endpoint")
+    allowed_keys = ("protocol", "endpoint", "token")
+
+    def create_filesystem(self, spec):
+        return None  # Not testing actual filesystem creation
+
+
+class TestStorageAdapterRegistry:
+    def setup_method(self):
+        _adapter_registry["dummy"] = _DummyAdapter()
+
+    def teardown_method(self):
+        _adapter_registry.pop("dummy", None)
+
+    def test_get_registered_adapter(self):
+        adapter = get_storage_adapter("dummy")
+        assert adapter is not None
+        assert adapter.protocol == "dummy"
+
+    def test_get_unknown_adapter_returns_none(self):
+        adapter = get_storage_adapter("nonexistent_protocol_xyz")
+        assert adapter is None
+
+    def test_adapter_protocol_attribute(self):
+        adapter = get_storage_adapter("dummy")
+        assert isinstance(adapter.protocol, str)
+        assert adapter.protocol == "dummy"
+
+
+class TestStorageAdapterValidation:
+    def setup_method(self):
+        self.adapter = _DummyAdapter()
+
+    def test_valid_spec_passes(self):
+        spec = {"protocol": "dummy", "endpoint": "https://example.com"}
+        self.adapter.validate_spec(spec)
+
+    def test_missing_required_key_raises(self):
+        spec = {"protocol": "dummy"}
+        with pytest.raises(DataJointError, match="missing.*endpoint"):
+            self.adapter.validate_spec(spec)
+
+    def test_invalid_key_raises(self):
+        spec = {"protocol": "dummy", "endpoint": "https://example.com", "bogus": "val"}
+        with pytest.raises(DataJointError, match="Invalid.*bogus"):
+            self.adapter.validate_spec(spec)
+
+    def test_common_store_keys_always_allowed(self):
+        spec = {
+            "protocol": "dummy",
+            "endpoint": "https://example.com",
+            "hash_prefix": "_hash",
+            "subfolding": None,
+            "schema_prefix": "_schema",
+        }
+        self.adapter.validate_spec(spec)
+
+    def test_common_store_keys_content(self):
+        assert "hash_prefix" in _COMMON_STORE_KEYS
+        assert "schema_prefix" in _COMMON_STORE_KEYS
+        assert "subfolding" in _COMMON_STORE_KEYS
+        assert "protocol" in _COMMON_STORE_KEYS
+        assert "location" in _COMMON_STORE_KEYS
+
+
+class TestStorageAdapterFullPath:
+    def setup_method(self):
+        self.adapter = _DummyAdapter()
+
+    def test_full_path_with_location(self):
+        spec = {"location": "data/blobs"}
+        assert self.adapter.full_path(spec, "schema/ab/cd/hash") == "data/blobs/schema/ab/cd/hash"
+
+    def test_full_path_empty_location(self):
+        spec = {"location": ""}
+        assert self.adapter.full_path(spec, "schema/ab/cd/hash") == "schema/ab/cd/hash"
+
+    def test_full_path_no_location_key(self):
+        spec = {}
+        assert self.adapter.full_path(spec, "schema/ab/cd/hash") == "schema/ab/cd/hash"
+
+
+class TestStorageAdapterGetUrl:
+    def setup_method(self):
+        self.adapter = _DummyAdapter()
+
+    def test_default_url_format(self):
+        assert self.adapter.get_url({}, "data/file.dat") == "dummy://data/file.dat"
+
+
+class _FakeFS:
+    """Minimal fake fsspec filesystem for testing."""
+
+    protocol = "dummy"
+
+
+class _FSAdapter(StorageAdapter):
+    """Adapter that returns a fake filesystem."""
+
+    protocol = "testfs"
+    required_keys = ("protocol",)
+    allowed_keys = ("protocol",)
+
+    def create_filesystem(self, spec):
+        return _FakeFS()
+
+    def get_url(self, spec, path):
+        return f"https://test.example.com/{path}"
+
+
+class TestStorageBackendPluginDelegation:
+    """Tests for plugin delegation in StorageBackend methods."""
+
+    def setup_method(self):
+        import datajoint.storage_adapter as sa_mod
+
+        sa_mod._adapter_registry["testfs"] = _FSAdapter()
+
+    def teardown_method(self):
+        import datajoint.storage_adapter as sa_mod
+
+        sa_mod._adapter_registry.pop("testfs", None)
+
+    def test_create_filesystem_delegates_to_adapter(self):
+        backend = StorageBackend.__new__(StorageBackend)
+        backend.spec = {"protocol": "testfs"}
+        backend.protocol = "testfs"
+        backend._fs = None
+        fs = backend._create_filesystem()
+        assert isinstance(fs, _FakeFS)
+
+    def test_full_path_delegates_to_adapter(self):
+        backend = StorageBackend.__new__(StorageBackend)
+        backend.spec = {"protocol": "testfs", "location": "data"}
+        backend.protocol = "testfs"
+        result = backend._full_path("schema/ab/cd/hash123")
+        assert result == "data/schema/ab/cd/hash123"
+
+    def test_full_path_empty_location(self):
+        backend = StorageBackend.__new__(StorageBackend)
+        backend.spec = {"protocol": "testfs", "location": ""}
+        backend.protocol = "testfs"
+        result = backend._full_path("schema/ab/cd/hash123")
+        assert result == "schema/ab/cd/hash123"
+
+    def test_get_url_delegates_to_adapter(self):
+        backend = StorageBackend.__new__(StorageBackend)
+        backend.spec = {"protocol": "testfs", "location": ""}
+        backend.protocol = "testfs"
+        result = backend.get_url("schema/file.dat")
+        assert result == "https://test.example.com/schema/file.dat"
+
+    def test_unsupported_protocol_error(self):
+        backend = StorageBackend.__new__(StorageBackend)
+        backend.spec = {"protocol": "totally_unknown_xyz"}
+        backend.protocol = "totally_unknown_xyz"
+        backend._fs = None
+        with pytest.raises(DataJointError, match="Unsupported storage protocol"):
+            backend._create_filesystem()
+
+
+class TestGetStoreSpecPluginDelegation:
+    """Tests for plugin protocol handling in Config.get_store_spec()."""
+
+    def setup_method(self):
+        import datajoint.storage_adapter as sa_mod
+
+        sa_mod._adapter_registry["dummy"] = _DummyAdapter()
+        self._original_stores = dj.config.stores.copy()
+
+    def teardown_method(self):
+        import datajoint.storage_adapter as sa_mod
+
+        sa_mod._adapter_registry.pop("dummy", None)
+        dj.config.stores = self._original_stores
+
+    def test_plugin_protocol_accepted(self):
+        """Plugin protocol passes validation via adapter."""
+        dj.config.stores["test_store"] = {
+            "protocol": "dummy",
+            "endpoint": "https://example.com",
+            "location": "",
+            "hash_prefix": "_hash",
+            "schema_prefix": "_schema",
+        }
+        spec = dj.config.get_store_spec("test_store")
+        assert spec["protocol"] == "dummy"
+
+    def test_unknown_protocol_error_message(self):
+        """Unknown protocol gives clear error mentioning plugin installation."""
+        dj.config.stores["bad_store"] = {
+            "protocol": "nonexistent_xyz",
+            "location": "",
+        }
+        with pytest.raises(DataJointError, match="Install a plugin"):
+            dj.config.get_store_spec("bad_store")