Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,31 @@ def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool
protocol = spec.get("protocol", "").lower()
supported_protocols = ("file", "s3", "gcs", "azure")
if protocol not in supported_protocols:
raise DataJointError(
f'Missing or invalid protocol in config.stores["{store}"]. '
f"Supported protocols: {', '.join(supported_protocols)}"
from .storage_adapter import get_storage_adapter

adapter = get_storage_adapter(protocol)
if adapter is None:
raise DataJointError(
f'Unknown protocol "{protocol}" in config.stores["{store}"]. '
f"Built-in: {', '.join(supported_protocols)}. "
f"Install a plugin package for additional protocols."
)
# Apply common defaults for plugin protocols
spec.setdefault("subfolding", None)
spec.setdefault("partition_pattern", None)
spec.setdefault("token_length", 8)
spec.setdefault("hash_prefix", "_hash")
spec.setdefault("schema_prefix", "_schema")
spec.setdefault("filepath_prefix", None)
spec.setdefault("location", "")
adapter.validate_spec(spec)
self._validate_prefix_separation(
store_name=store,
hash_prefix=spec.get("hash_prefix"),
schema_prefix=spec.get("schema_prefix"),
filepath_prefix=spec.get("filepath_prefix"),
)
return spec

# Set protocol-specific defaults
if protocol == "s3":
Expand Down
20 changes: 17 additions & 3 deletions src/datajoint/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,12 @@ def _create_filesystem(self) -> fsspec.AbstractFileSystem:
)

else:
raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
from .storage_adapter import get_storage_adapter

adapter = get_storage_adapter(self.protocol)
if adapter is None:
raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
return adapter.create_filesystem(self.spec)

def _full_path(self, path: str | PurePosixPath) -> str:
"""
Expand Down Expand Up @@ -398,7 +403,12 @@ def _full_path(self, path: str | PurePosixPath) -> str:
return f"{bucket}/{location}/{path}"
return f"{bucket}/{path}"
else:
# Local filesystem - prepend location if specified
from .storage_adapter import get_storage_adapter

adapter = get_storage_adapter(self.protocol)
if adapter is not None:
return adapter.full_path(self.spec, path)
# File-protocol fallback
location = self.spec.get("location", "")
if location:
return str(Path(location) / path)
Expand Down Expand Up @@ -448,7 +458,11 @@ def get_url(self, path: str | PurePosixPath) -> str:
elif self.protocol == "azure":
return f"az://{full_path}"
else:
# Fallback: use protocol prefix
from .storage_adapter import get_storage_adapter

adapter = get_storage_adapter(self.protocol)
if adapter is not None:
return adapter.get_url(self.spec, full_path)
return f"{self.protocol}://{full_path}"

def put_file(self, local_path: str | Path, remote_path: str | PurePosixPath, metadata: dict | None = None) -> None:
Expand Down
109 changes: 109 additions & 0 deletions src/datajoint/storage_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Plugin system for third-party storage protocols.

Third-party packages register adapters via entry points::

[project.entry-points."datajoint.storage"]
myprotocol = "my_package:MyStorageAdapter"

The adapter is auto-discovered when DataJoint encounters the protocol name
in a store configuration. No explicit import is needed.
"""

from abc import ABC, abstractmethod
from typing import Any
import logging

import fsspec

from . import errors

logger = logging.getLogger(__name__)


class StorageAdapter(ABC):
    """Base class for storage protocol adapters.

    Third-party protocols plug into DataJoint by subclassing this and
    declaring a ``datajoint.storage`` entry point. A subclass must set
    ``protocol``, ``required_keys``, and ``allowed_keys`` and implement
    ``create_filesystem``; the remaining hooks have sensible defaults.
    """

    protocol: str
    required_keys: tuple[str, ...] = ()
    allowed_keys: tuple[str, ...] = ()

    @abstractmethod
    def create_filesystem(self, spec: dict[str, Any]) -> fsspec.AbstractFileSystem:
        """Return an fsspec filesystem instance for this protocol."""
        ...

    def validate_spec(self, spec: dict[str, Any]) -> None:
        """Check *spec* for missing required keys and unrecognized keys.

        Raises :class:`~.errors.DataJointError` on the first violation found.
        """
        missing = [key for key in self.required_keys if key not in spec]
        if missing:
            raise errors.DataJointError(f'{self.protocol} store is missing: {", ".join(missing)}')
        # Keys shared by every store plus whatever this protocol declares.
        permitted = _COMMON_STORE_KEYS.union(self.allowed_keys)
        unexpected = [key for key in spec if key not in permitted]
        if unexpected:
            raise errors.DataJointError(f'Invalid key(s) for {self.protocol}: {", ".join(unexpected)}')

    def full_path(self, spec: dict[str, Any], relpath: str) -> str:
        """Construct storage path from a relative path."""
        location = spec.get("location", "")
        if not location:
            return relpath
        return f"{location}/{relpath}"

    def get_url(self, spec: dict[str, Any], path: str) -> str:
        """Return a display URL for the stored object."""
        return self.protocol + "://" + path


# Config keys accepted for every store regardless of protocol; an adapter's
# own allowed_keys extend this set (see StorageAdapter.validate_spec).
_COMMON_STORE_KEYS = frozenset(
    {
        "protocol",
        "location",
        "subfolding",
        "partition_pattern",
        "token_length",
        "hash_prefix",
        "schema_prefix",
        "filepath_prefix",
        "stage",
    }
)

# Process-wide registry mapping protocol name -> adapter instance.
_adapter_registry: dict[str, StorageAdapter] = {}
# Guards one-time entry-point discovery; set by get_storage_adapter().
_adapters_loaded: bool = False


def get_storage_adapter(protocol: str) -> StorageAdapter | None:
    """Return the adapter registered under *protocol*, or ``None`` if unknown.

    Entry-point discovery is performed lazily on the first lookup and
    cached for the lifetime of the process.
    """
    global _adapters_loaded
    if not _adapters_loaded:
        # Mark loaded only after discovery succeeds, so a raising discovery
        # would be retried on the next lookup.
        _discover_adapters()
        _adapters_loaded = True
    adapter = _adapter_registry.get(protocol)
    return adapter


def _discover_adapters() -> None:
    """Populate the adapter registry from ``datajoint.storage`` entry points.

    Called at most once per process (see ``get_storage_adapter``). A plugin
    that fails to load is skipped with a warning so one broken package
    cannot disable storage discovery for everyone else.
    """
    try:
        from importlib.metadata import entry_points
    except ImportError:  # pragma: no cover - stdlib on Python 3.8+
        logger.debug("importlib.metadata not available, skipping adapter discovery")
        return

    try:
        # Python 3.10+: entry_points() accepts a group filter directly.
        eps = entry_points(group="datajoint.storage")
    except TypeError:
        # Python 3.8/3.9: entry_points() returns a dict keyed by group.
        eps = entry_points().get("datajoint.storage", [])

    for ep in eps:
        try:
            adapter_cls = ep.load()
            adapter = adapter_cls()
        except Exception as e:
            # Lazy %-args so the message is only formatted when emitted.
            logger.warning("Failed to load storage adapter '%s': %s", ep.name, e)
            continue
        # The registry is keyed by adapter.protocol, not ep.name, so the
        # duplicate check must use the same key; first registration wins and
        # a clashing plugin cannot silently replace an earlier adapter.
        if adapter.protocol in _adapter_registry:
            logger.warning(
                "Storage adapter '%s' duplicates protocol '%s'; keeping the first",
                ep.name,
                adapter.protocol,
            )
            continue
        _adapter_registry[adapter.protocol] = adapter
        logger.debug("Loaded storage adapter: %s", adapter.protocol)
Loading
Loading