Skip to content

Commit d5b14af

Browse files
committed
cupy 14 compat
1 parent fda4350 commit d5b14af

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

src/zarr/codecs/gpu.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def to_dict(self) -> dict[str, JSON]:
7474
def _zstd_codec(self) -> nvcomp.Codec:
7575
device = cp.cuda.Device() # Select the current default device
7676
stream = cp.cuda.get_current_stream() # Use the current default stream
77+
# Note: this returns an array with dtype=np.dtype("int8")
7778
return nvcomp.Codec(
7879
algorithm="Zstd",
7980
bitstream_kind=nvcomp.BitstreamKind.RAW,
@@ -95,12 +96,18 @@ def _convert_from_nvcomp_arrays(
9596
arrays: Iterable[nvcomp.Array],
9697
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
9798
) -> Iterable[Buffer | None]:
98-
return [
99-
spec.prototype.buffer.from_array_like(cp.array(a, dtype=np.dtype("B"), copy=False))
100-
if a
101-
else None
102-
for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True)
103-
]
99+
result: list[Buffer | None] = []
100+
101+
for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True):
102+
if a is None:
103+
result.append(None)
104+
else:
105+
a2 = cp.array(a, dtype=a.dtype, copy=False)
106+
if a2.dtype != np.dtype("B"):
107+
a2 = a2.view(dtype=np.dtype("B"))
108+
result.append(spec.prototype.buffer.from_array_like(a2))
109+
110+
return result
104111

105112
async def decode(
106113
self,

0 commit comments

Comments
 (0)