-
Notifications
You must be signed in to change notification settings - Fork 1
feat: add tests and refactor codec #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
815f16d
20231fb
ccf829d
f6be988
5e4aa33
cf58710
d7d7eda
aa5f2f3
862f208
176c12a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| name: Tests | ||
|
|
||
| on: | ||
| pull_request: | ||
| branches: [main] | ||
| push: | ||
| branches: [main] | ||
| workflow_dispatch: | ||
|
|
||
| jobs: | ||
| test: | ||
| runs-on: ubuntu-latest | ||
| strategy: | ||
| matrix: | ||
| python-version: ["3.11", "3.12", "3.13"] | ||
|
|
||
| steps: | ||
| - uses: actions/checkout@v4 | ||
|
|
||
| - name: Install uv | ||
| uses: astral-sh/setup-uv@v4 | ||
|
|
||
| - name: Set up Python ${{ matrix.python-version }} | ||
| run: uv python install ${{ matrix.python-version }} | ||
|
|
||
| - name: Install dependencies | ||
| run: uv sync --dev | ||
|
|
||
| - name: Run tests | ||
| run: uv run pytest tests -v |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| # C extensions | ||
| *.so | ||
|
|
||
| # Distribution / packaging | ||
| .Python | ||
| build/ | ||
| develop-eggs/ | ||
| dist/ | ||
| downloads/ | ||
| eggs/ | ||
| .eggs/ | ||
| lib/ | ||
| lib64/ | ||
| parts/ | ||
| sdist/ | ||
| var/ | ||
| wheels/ | ||
| share/python-wheels/ | ||
| *.egg-info/ | ||
| .installed.cfg | ||
| *.egg | ||
| MANIFEST | ||
|
|
||
| # PyInstaller | ||
| *.manifest | ||
| *.spec | ||
|
|
||
| # Installer logs | ||
| pip-log.txt | ||
| pip-delete-this-directory.txt | ||
|
|
||
| # Unit test / coverage reports | ||
| htmlcov/ | ||
| .tox/ | ||
| .nox/ | ||
| .coverage | ||
| .coverage.* | ||
| .cache | ||
| nosetests.xml | ||
| coverage.xml | ||
| *.cover | ||
| *.py,cover | ||
| .hypothesis/ | ||
| .pytest_cache/ | ||
|
|
||
| # Translations | ||
| *.mo | ||
| *.pot | ||
|
|
||
| # Django stuff: | ||
| *.log | ||
| local_settings.py | ||
| db.sqlite3 | ||
| db.sqlite3-journal | ||
|
|
||
| # Flask stuff: | ||
| instance/ | ||
| .webassets-cache | ||
|
|
||
| # Scrapy stuff: | ||
| .scrapy | ||
|
|
||
| # Sphinx documentation | ||
| docs/_build/ | ||
|
|
||
| # PyBuilder | ||
| .pybuilder/ | ||
| target/ | ||
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
|
|
||
| # IPython | ||
| profile_default/ | ||
| ipython_config.py | ||
|
|
||
| # pipenv | ||
| Pipfile.lock | ||
|
|
||
| # PEP 582 | ||
| __pypackages__/ | ||
|
|
||
| # Celery stuff | ||
| celerybeat-schedule | ||
| celerybeat.pid | ||
|
|
||
| # SageMath parsed files | ||
| *.sage.py | ||
|
|
||
| # Environments | ||
| .env | ||
| .venv | ||
| env/ | ||
| venv/ | ||
| ENV/ | ||
| env.bak/ | ||
| venv.bak/ | ||
|
|
||
| # Spyder project settings | ||
| .spyderproject | ||
| .spyproject | ||
|
|
||
| # Rope project settings | ||
| .ropeproject | ||
|
|
||
| # mkdocs documentation | ||
| /site | ||
|
|
||
| # mypy | ||
| .mypy_cache/ | ||
| .dmypy.json | ||
| dmypy.json | ||
|
|
||
| # Pyre type checker | ||
| .pyre/ | ||
|
|
||
| # pytype static type analyzer | ||
| .pytype/ | ||
|
|
||
| # Cython debug symbols | ||
| cython_debug/ | ||
|
|
||
| # IDE | ||
| .idea/ | ||
| .vscode/ | ||
| *.swp | ||
| *.swo | ||
| *~ | ||
|
|
||
| # OS | ||
| .DS_Store | ||
| Thumbs.db |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 3.11 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ name = "dj-zarr-codecs" | |
| version = "0.1.0" | ||
| description = "DataJoint codecs for Zarr array storage" | ||
| readme = "README.md" | ||
| requires-python = ">=3.10" | ||
| requires-python = ">=3.11" | ||
| license = {text = "MIT"} | ||
| authors = [ | ||
| {name = "Davis Bennett", email = "davis.v.bennett@gmail.com"}, | ||
|
|
@@ -24,14 +24,15 @@ classifiers = [ | |
| dependencies = [ | ||
| "datajoint>=2.0.0a22", | ||
| "zarr>=2.0", | ||
| "numpy>=1.20", | ||
| "numpy>=2.0", | ||
| ] | ||
|
|
||
| [project.optional-dependencies] | ||
| [dependency-groups] | ||
| dev = [ | ||
| "pytest>=7.0", | ||
| "pytest-cov>=4.0", | ||
| "ruff>=0.1.0", | ||
| "testcontainers[mysql]>=4.14.0", | ||
| ] | ||
|
|
||
| [project.urls] | ||
|
|
@@ -47,12 +48,14 @@ zarr = "dj_zarr_codecs:ZarrCodec" | |
| requires = ["setuptools>=61.0"] | ||
| build-backend = "setuptools.build_meta" | ||
|
|
||
| [tool.uv.sources] | ||
| datajoint = {git = "https://github.com/datajoint/datajoint-python.git"} | ||
|
|
||
| [tool.setuptools.packages.find] | ||
| where = ["src"] | ||
|
|
||
| [tool.pytest.ini_options] | ||
| testpaths = ["tests"] | ||
| addopts = "-v --cov=dj_zarr_codecs --cov-report=term-missing" | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed these because they generated a lot of noise |
||
|
|
||
| [tool.ruff] | ||
| line-length = 100 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """DataJoint codecs for Zarr array storage.""" | ||
|
|
||
| from .codecs import ZarrCodec | ||
| from .codecs import ZarrArrayCodec | ||
|
|
||
| __version__ = "0.1.0" | ||
| __all__ = ["ZarrCodec"] | ||
| __all__ = ["ZarrArrayCodec"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,18 +6,11 @@ | |
|
|
||
| import numpy as np | ||
| import zarr | ||
| import datajoint as dj | ||
| from datajoint import DataJointError | ||
| from datajoint.builtin_codecs import SchemaCodec | ||
|
|
||
| try: | ||
| import datajoint as dj | ||
| from datajoint import DataJointError | ||
| from datajoint.builtin_codecs import SchemaCodec | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "datajoint>=2.0.0a22 is required. Install with: pip install 'datajoint>=2.0.0a22'" | ||
| ) from e | ||
|
|
||
|
|
||
| class ZarrCodec(SchemaCodec): | ||
| class ZarrArrayCodec(SchemaCodec): | ||
| """ | ||
| Store numpy arrays in Zarr format with schema-addressed paths. | ||
|
|
||
|
|
@@ -78,7 +71,6 @@ class Recording(dj.Manual): | |
| """ | ||
|
|
||
| name = "zarr" | ||
| CODEC_VERSION = "1.0" # Data format version for backward compatibility | ||
|
|
||
| def validate(self, value: Any) -> None: | ||
| """ | ||
|
|
@@ -94,16 +86,16 @@ def validate(self, value: Any) -> None: | |
| DataJointError | ||
| If value is not a numpy array or has object dtype. | ||
| """ | ||
| if not isinstance(value, np.ndarray): | ||
| raise DataJointError( | ||
| f"<zarr> requires numpy.ndarray, got {type(value).__name__}" | ||
| if not isinstance(value, np.ndarray | zarr.Array): | ||
| raise TypeError( | ||
| f"<zarr> requires a Numpy array or Zarr array, got {type(value).__name__}" | ||
|
Comment on lines
+89
to
+91
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this raises a |
||
| ) | ||
| if value.dtype == object: | ||
| raise DataJointError("<zarr> does not support object dtype arrays") | ||
|
|
||
| def encode( | ||
| self, | ||
| value: np.ndarray, | ||
| value: np.ndarray | zarr.Array, | ||
| *, | ||
| key: dict | None = None, | ||
| store_name: str | None = None, | ||
|
|
@@ -130,40 +122,35 @@ def encode( | |
| DataJointError | ||
| If encoding fails. | ||
| """ | ||
| try: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is no longer wrapped in a try... except block (a bad LLM habit) |
||
| # Extract context from key | ||
| schema, table, field, primary_key = self._extract_context(key) | ||
|
|
||
| # Build schema-addressed path | ||
| path, _token = self._build_path( | ||
| schema, table, field, primary_key, ext=".zarr", store_name=store_name | ||
| ) | ||
|
|
||
| # Get storage backend | ||
| backend = self._get_backend(store_name) | ||
|
|
||
| # Get fsspec mapper for direct Zarr write | ||
| store_map = backend.get_fsmap(path) | ||
|
|
||
| # Write array to Zarr format | ||
| zarr.save_array(store_map, value) | ||
|
|
||
| # Store version metadata in Zarr attributes | ||
| z = zarr.open(store_map, mode="r+") | ||
| z.attrs["codec_version"] = self.CODEC_VERSION | ||
| z.attrs["codec_name"] = self.name | ||
|
|
||
| # Return metadata for database storage | ||
| return { | ||
| "path": path, | ||
| "store": store_name, | ||
| "codec_version": self.CODEC_VERSION, | ||
| "shape": list(value.shape), | ||
| "dtype": str(value.dtype), | ||
| } | ||
|
|
||
| except Exception as e: | ||
| raise DataJointError(f"Failed to encode Zarr array: {e}") from e | ||
| # import here to avoid circular import | ||
| from dj_zarr_codecs import __version__ | ||
| # Extract context from key | ||
| schema, table, field, primary_key = self._extract_context(key) | ||
|
|
||
| # Build schema-addressed path | ||
| path, _token = self._build_path( | ||
| schema, table, field, primary_key, ext=".zarr", store_name=store_name | ||
| ) | ||
|
|
||
| # Get storage backend | ||
| backend = self._get_backend(store_name) | ||
|
|
||
| # Get fsspec mapper for direct Zarr write | ||
| store_map = backend.get_fsmap(path) | ||
|
|
||
| zarr.create_array(store=store_map, data=value, write_data=True) | ||
|
|
||
| # Return metadata for database storage (stored as JSON column) | ||
| return { | ||
| "path": path, | ||
| "store": store_name, | ||
| "shape": list(value.shape), | ||
| "dtype": str(value.dtype), | ||
| "provenance": { | ||
| "datajoint-python": dj.__version__, | ||
| "dj-zarr-codecs": __version__, | ||
| }, | ||
|
Comment on lines
+149
to
+152
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rows in this table have a provenance column that stores JSON documents conveying the state of the software that inserted the column.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. excellent. |
||
| } | ||
|
|
||
| def decode(self, stored: dict, *, key: dict | None = None) -> zarr.Array: | ||
| """ | ||
|
|
@@ -187,30 +174,11 @@ def decode(self, stored: dict, *, key: dict | None = None) -> zarr.Array: | |
| DataJointError | ||
| If decoding fails. | ||
| """ | ||
| try: | ||
| # Get storage backend | ||
| backend = self._get_backend(stored.get("store")) | ||
| # Get storage backend | ||
| backend = self._get_backend(stored.get("store")) | ||
|
|
||
| # Get fsspec mapper for Zarr path | ||
| store_map = backend.get_fsmap(stored["path"]) | ||
|
|
||
| # Open Zarr array (read-only) | ||
| z = zarr.open(store_map, mode="r") | ||
|
|
||
| # Check codec version for backward compatibility | ||
| # Priority: Zarr attrs > DB metadata > default "1.0" | ||
| version = z.attrs.get( | ||
| "codec_version", stored.get("codec_version", "1.0") | ||
| ) | ||
|
Comment on lines
-200
to
-204
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. version guards are removed because it isn't clear what "1.0" vs "2.0" would mean here |
||
| # Get fsspec mapper for Zarr path | ||
| store_map = backend.get_fsmap(stored["path"]) | ||
|
|
||
| # All v1.x versions are compatible | ||
| if version.startswith("1."): | ||
| return z | ||
| else: | ||
| raise DataJointError( | ||
| f"Unsupported zarr codec version: {version}. " | ||
| f"Upgrade dj-zarr-codecs or migrate data." | ||
| ) | ||
|
|
||
| except Exception as e: | ||
| raise DataJointError(f"Failed to decode Zarr array: {e}") from e | ||
| # Open Zarr array (read-only) | ||
| return zarr.open(store_map, mode="r") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this will need to be changed when 2.0 lands on pypi (maybe it's already there)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, coming soon but not yet.