Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Tests

on:
pull_request:
branches: [main]
push:
branches: [main]
workflow_dispatch:

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install dependencies
run: uv sync --dev

- name: Run tests
run: uv run pytest tests -v
137 changes: 137 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pipenv
Pipfile.lock

# PEP 582
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "dj-zarr-codecs"
version = "0.1.0"
description = "DataJoint codecs for Zarr array storage"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
{name = "Davis Bennett", email = "davis.v.bennett@gmail.com"},
Expand All @@ -24,14 +24,15 @@ classifiers = [
dependencies = [
"datajoint>=2.0.0a22",
"zarr>=2.0",
"numpy>=1.20",
"numpy>=2.0",
]

[project.optional-dependencies]
[dependency-groups]
dev = [
"pytest>=7.0",
"pytest-cov>=4.0",
"ruff>=0.1.0",
"testcontainers[mysql]>=4.14.0",
]

[project.urls]
Expand All @@ -47,12 +48,14 @@ zarr = "dj_zarr_codecs:ZarrCodec"
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.uv.sources]
datajoint = {git = "https://github.com/datajoint/datajoint-python.git"}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will need to be changed when 2.0 lands on pypi (maybe it's already there)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, coming soon but not yet.


[tool.setuptools.packages.find]
where = ["src"]

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-v --cov=dj_zarr_codecs --cov-report=term-missing"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed these because they generated a lot of noise


[tool.ruff]
line-length = 100
Expand Down
4 changes: 2 additions & 2 deletions src/dj_zarr_codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""DataJoint codecs for Zarr array storage."""

from .codecs import ZarrCodec
from .codecs import ZarrArrayCodec

__version__ = "0.1.0"
__all__ = ["ZarrCodec"]
__all__ = ["ZarrArrayCodec"]
118 changes: 43 additions & 75 deletions src/dj_zarr_codecs/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@

import numpy as np
import zarr
import datajoint as dj
from datajoint import DataJointError
from datajoint.builtin_codecs import SchemaCodec

try:
import datajoint as dj
from datajoint import DataJointError
from datajoint.builtin_codecs import SchemaCodec
except ImportError as e:
raise ImportError(
"datajoint>=2.0.0a22 is required. Install with: pip install 'datajoint>=2.0.0a22'"
) from e


class ZarrCodec(SchemaCodec):
class ZarrArrayCodec(SchemaCodec):
"""
Store numpy arrays in Zarr format with schema-addressed paths.

Expand Down Expand Up @@ -78,7 +71,6 @@ class Recording(dj.Manual):
"""

name = "zarr"
CODEC_VERSION = "1.0" # Data format version for backward compatibility

def validate(self, value: Any) -> None:
"""
Expand All @@ -94,16 +86,16 @@ def validate(self, value: Any) -> None:
DataJointError
If value is not a numpy array or has object dtype.
"""
if not isinstance(value, np.ndarray):
raise DataJointError(
f"<zarr> requires numpy.ndarray, got {type(value).__name__}"
if not isinstance(value, np.ndarray | zarr.Array):
raise TypeError(
f"<zarr> requires a Numpy array or Zarr array, got {type(value).__name__}"
Comment on lines +89 to +91

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this raises a TypeError instead of a DataJointError, because the DataJointError does not convey to callers that the problem is related to values being the wrong type.

)
if value.dtype == object:
raise DataJointError("<zarr> does not support object dtype arrays")

def encode(
self,
value: np.ndarray,
value: np.ndarray | zarr.Array,
*,
key: dict | None = None,
store_name: str | None = None,
Expand All @@ -130,40 +122,35 @@ def encode(
DataJointError
If encoding fails.
"""
try:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is no longer wrapped in a try... except block (a bad LLM habit)

# Extract context from key
schema, table, field, primary_key = self._extract_context(key)

# Build schema-addressed path
path, _token = self._build_path(
schema, table, field, primary_key, ext=".zarr", store_name=store_name
)

# Get storage backend
backend = self._get_backend(store_name)

# Get fsspec mapper for direct Zarr write
store_map = backend.get_fsmap(path)

# Write array to Zarr format
zarr.save_array(store_map, value)

# Store version metadata in Zarr attributes
z = zarr.open(store_map, mode="r+")
z.attrs["codec_version"] = self.CODEC_VERSION
z.attrs["codec_name"] = self.name

# Return metadata for database storage
return {
"path": path,
"store": store_name,
"codec_version": self.CODEC_VERSION,
"shape": list(value.shape),
"dtype": str(value.dtype),
}

except Exception as e:
raise DataJointError(f"Failed to encode Zarr array: {e}") from e
# import here to avoid circular import
from dj_zarr_codecs import __version__
# Extract context from key
schema, table, field, primary_key = self._extract_context(key)

# Build schema-addressed path
path, _token = self._build_path(
schema, table, field, primary_key, ext=".zarr", store_name=store_name
)

# Get storage backend
backend = self._get_backend(store_name)

# Get fsspec mapper for direct Zarr write
store_map = backend.get_fsmap(path)

zarr.create_array(store=store_map, data=value, write_data=True)

# Return metadata for database storage (stored as JSON column)
return {
"path": path,
"store": store_name,
"shape": list(value.shape),
"dtype": str(value.dtype),
"provenance": {
"datajoint-python": dj.__version__,
"dj-zarr-codecs": __version__,
},
Comment on lines +149 to +152

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rows in this table have a provenance column that stores JSON documents conveying the state of the software that inserted the column.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

excellent.

}

def decode(self, stored: dict, *, key: dict | None = None) -> zarr.Array:
"""
Expand All @@ -187,30 +174,11 @@ def decode(self, stored: dict, *, key: dict | None = None) -> zarr.Array:
DataJointError
If decoding fails.
"""
try:
# Get storage backend
backend = self._get_backend(stored.get("store"))
# Get storage backend
backend = self._get_backend(stored.get("store"))

# Get fsspec mapper for Zarr path
store_map = backend.get_fsmap(stored["path"])

# Open Zarr array (read-only)
z = zarr.open(store_map, mode="r")

# Check codec version for backward compatibility
# Priority: Zarr attrs > DB metadata > default "1.0"
version = z.attrs.get(
"codec_version", stored.get("codec_version", "1.0")
)
Comment on lines -200 to -204

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

version guards are removed because it isn't clear what "1.0" vs "2.0" would mean here

# Get fsspec mapper for Zarr path
store_map = backend.get_fsmap(stored["path"])

# All v1.x versions are compatible
if version.startswith("1."):
return z
else:
raise DataJointError(
f"Unsupported zarr codec version: {version}. "
f"Upgrade dj-zarr-codecs or migrate data."
)

except Exception as e:
raise DataJointError(f"Failed to decode Zarr array: {e}") from e
# Open Zarr array (read-only)
return zarr.open(store_map, mode="r")
Empty file added tests/__init__.py
Empty file.
Loading