Skip to content

Commit 3caee31

Browse files
kushalbakshiclaude
andcommitted
feat: add StorageAdapter ABC and entry-point registry
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8d7a7d6 commit 3caee31

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

src/datajoint/storage_adapter.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Plugin system for third-party storage protocols.

Third-party packages register adapters via entry points::

    [project.entry-points."datajoint.storage"]
    myprotocol = "my_package:MyStorageAdapter"

The adapter is auto-discovered when DataJoint encounters the protocol name
in a store configuration. No explicit import is needed.
"""
11+
12+
from abc import ABC, abstractmethod
13+
from typing import Any
14+
import logging
15+
16+
import fsspec
17+
18+
from . import errors
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class StorageAdapter(ABC):
    """Base class for storage protocol adapters.

    Subclass this and declare an entry point to add a new storage protocol
    to DataJoint. At minimum, implement ``create_filesystem`` and set
    ``protocol``, ``required_keys``, and ``allowed_keys``.
    """

    # Protocol name; used in error messages and as the URL scheme in get_url().
    protocol: str
    # Keys that validate_spec() requires to be present in a store spec.
    required_keys: tuple[str, ...] = ()
    # Protocol-specific keys accepted in addition to the common store keys.
    allowed_keys: tuple[str, ...] = ()

    @abstractmethod
    def create_filesystem(self, spec: dict[str, Any]) -> fsspec.AbstractFileSystem:
        """Return an fsspec filesystem instance for this protocol."""
        ...

    def validate_spec(self, spec: dict[str, Any]) -> None:
        """Validate protocol-specific config fields.

        Parameters
        ----------
        spec : dict
            Store configuration to check.

        Raises
        ------
        DataJointError
            If any key in ``required_keys`` is absent from ``spec``, or if
            ``spec`` contains a key that is neither required, listed in
            ``allowed_keys``, nor one of the common store keys.
        """
        missing = [k for k in self.required_keys if k not in spec]
        if missing:
            raise errors.DataJointError(
                f'{self.protocol} store is missing: {", ".join(missing)}'
            )
        # A required key is always a legal key: without including
        # required_keys here, an adapter that forgot to repeat them in
        # allowed_keys would reject every spec that satisfies the check above.
        all_allowed = (
            set(self.required_keys) | set(self.allowed_keys) | _COMMON_STORE_KEYS
        )
        invalid = [k for k in spec if k not in all_allowed]
        if invalid:
            raise errors.DataJointError(
                f'Invalid key(s) for {self.protocol}: {", ".join(invalid)}'
            )

    def full_path(self, spec: dict[str, Any], relpath: str) -> str:
        """Construct storage path from a relative path.

        Returns ``"<location>/<relpath>"`` when ``spec["location"]`` is
        present and non-empty, otherwise ``relpath`` unchanged. Slashes are
        not normalized.
        """
        location = spec.get("location", "")
        return f"{location}/{relpath}" if location else relpath

    def get_url(self, spec: dict[str, Any], path: str) -> str:
        """Return a display URL for the stored object.

        The default implementation ignores ``spec`` and returns
        ``"<protocol>://<path>"``.
        """
        return f"{self.protocol}://{path}"
62+
63+
64+
# Store-spec keys accepted for every protocol; StorageAdapter.validate_spec()
# unions these with the adapter's own allowed keys.
_COMMON_STORE_KEYS = frozenset(
    {
        "protocol",
        "location",
        "subfolding",
        "partition_pattern",
        "token_length",
        "hash_prefix",
        "schema_prefix",
        "filepath_prefix",
        "stage",
    }
)
75+
76+
# Adapter instances keyed by their protocol name; filled by _discover_adapters().
_adapter_registry: dict[str, StorageAdapter] = {}
# Flipped to True by get_storage_adapter() once discovery has run.
_adapters_loaded: bool = False
78+
79+
80+
def get_storage_adapter(protocol: str) -> StorageAdapter | None:
    """Look up a registered storage adapter by protocol name.

    The first call runs entry-point discovery; later calls only consult the
    registry.

    Parameters
    ----------
    protocol : str
        Protocol name to look up.

    Returns
    -------
    StorageAdapter or None
        The registered adapter, or ``None`` if no adapter is registered
        under ``protocol``.
    """
    global _adapters_loaded
    if not _adapters_loaded:
        _discover_adapters()
        # Set only after discovery returns, so a discovery call that raises
        # is retried on the next lookup.
        _adapters_loaded = True
    return _adapter_registry.get(protocol)
87+
88+
89+
def _discover_adapters() -> None:
    """Load storage adapters from datajoint.storage entry points.

    Each entry point is loaded, instantiated with no arguments, and stored
    in ``_adapter_registry`` under the instance's ``protocol``. The first
    adapter registered for a protocol wins; later ones are skipped with a
    warning. A plugin that fails to load, fails to instantiate, or is not a
    ``StorageAdapter`` is logged as a warning and skipped, so one broken
    plugin does not prevent the others from loading.
    """
    try:
        from importlib.metadata import entry_points
    except ImportError:
        logger.debug("importlib.metadata not available, skipping adapter discovery")
        return

    try:
        eps = entry_points(group="datajoint.storage")
    except TypeError:
        # Fallback for importlib.metadata versions whose entry_points()
        # accepts no selection arguments and returns a mapping of groups.
        eps = entry_points().get("datajoint.storage", [])

    for ep in eps:
        # Cheap pre-check: skip without importing the plugin when its entry
        # point name is already a registered protocol.
        if ep.name in _adapter_registry:
            continue
        try:
            adapter = ep.load()()
            if not isinstance(adapter, StorageAdapter):
                raise TypeError(
                    f"{type(adapter).__name__} is not a StorageAdapter subclass"
                )
            protocol = adapter.protocol
        except Exception as e:
            # Plugin boundary: third-party code may raise anything on import
            # or construction; report it and move on.
            logger.warning("Failed to load storage adapter '%s': %s", ep.name, e)
            continue
        # The registry is keyed by the adapter's protocol, which need not
        # equal the entry point name, so check again with the real key
        # rather than silently replacing an earlier registration.
        if protocol in _adapter_registry:
            logger.warning(
                "Ignoring storage adapter '%s': protocol '%s' is already registered",
                ep.name,
                protocol,
            )
            continue
        _adapter_registry[protocol] = adapter
        logger.debug("Loaded storage adapter: %s", protocol)

0 commit comments

Comments
 (0)