Skip to content

Commit f16d730

Browse files
committed
Look up the codec implementation
1 parent d558ef8 commit f16d730

3 files changed

Lines changed: 37 additions & 3 deletions

File tree

src/zarr/core/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from zarr.codecs._v2 import V2Codec
3232
from zarr.codecs.bytes import BytesCodec
3333
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
34-
from zarr.codecs.zstd import ZstdCodec
3534
from zarr.core._info import ArrayInfo
3635
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, parse_array_config
3736
from zarr.core.attributes import Attributes
@@ -124,6 +123,7 @@
124123
_parse_array_array_codec,
125124
_parse_array_bytes_codec,
126125
_parse_bytes_bytes_codec,
126+
get_codec_class,
127127
get_pipeline_class,
128128
)
129129
from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path
@@ -4686,9 +4686,9 @@ def default_compressors_v3(dtype: ZDType[Any, Any]) -> tuple[BytesBytesCodec, ..
46864686
"""
46874687
Given a data type, return the default compressors for that data type.
46884688
4689-
This is just a tuple containing ``ZstdCodec``
4689+
This is just a tuple containing an instance of the default "zstd" codec class.
46904690
"""
4691-
return (ZstdCodec(),)
4691+
return (cast(BytesBytesCodec, get_codec_class("zstd")()),)
46924692

46934693

46944694
def default_serializer_v3(dtype: ZDType[Any, Any]) -> ArrayBytesCodec:

tests/test_codecs/test_codecs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
GzipCodec,
1717
ShardingCodec,
1818
TransposeCodec,
19+
ZstdCodec,
1920
)
2021
from zarr.core.buffer import default_buffer_prototype
2122
from zarr.core.indexing import BasicSelection, morton_order_iter
2223
from zarr.core.metadata.v3 import ArrayV3Metadata
24+
from zarr.registry import register_codec
2325
from zarr.storage import StorePath
2426

2527
if TYPE_CHECKING:
@@ -413,3 +415,22 @@ async def test_resize(store: Store) -> None:
413415
assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None
414416
assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is None
415417
assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is None
418+
419+
420+
def test_uses_default_codec() -> None:
421+
class MyZstdCodec(ZstdCodec):
422+
pass
423+
424+
register_codec("zstd", MyZstdCodec)
425+
426+
with zarr.config.set(
427+
{"codecs": {"zstd": f"{MyZstdCodec.__module__}.{MyZstdCodec.__qualname__}"}}
428+
):
429+
a = zarr.create_array(
430+
StorePath(zarr.storage.MemoryStore(), path="mycodec"),
431+
shape=(10, 10),
432+
chunks=(10, 10),
433+
dtype="int32",
434+
)
435+
assert a.metadata.zarr_format == 3
436+
assert isinstance(a.metadata.codecs[-1], MyZstdCodec)

tests/test_codecs/test_nvcomp.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,16 @@ def test_invalid_raises() -> None:
4343

4444
with pytest.raises(TypeError):
4545
NvcompZstdCodec(checksum="False") # type: ignore[arg-type]
46+
47+
48+
@gpu_test
49+
def test_uses_default_codec() -> None:
50+
with zarr.config.enable_gpu():
51+
a = zarr.create_array(
52+
StorePath(zarr.storage.MemoryStore(), path="nvcomp_zstd"),
53+
shape=(10, 10),
54+
chunks=(10, 10),
55+
dtype="int32",
56+
)
57+
assert a.metadata.zarr_format == 3
58+
assert isinstance(a.metadata.codecs[-1], NvcompZstdCodec)

0 commit comments

Comments
 (0)