Skip to content

Commit 1b85fdc

Browse files
committed
fixed Buffer.__add__
1 parent 69ea74e commit 1b85fdc

1 file changed

Lines changed: 9 additions & 10 deletions

File tree

src/zarr/core/buffer/gpu.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
cast,
99
)
1010

11-
import numpy as np
12-
import numpy.typing as npt
13-
1411
from zarr.core.buffer import core
1512
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike
1613
from zarr.registry import (
@@ -22,6 +19,8 @@
2219
from collections.abc import Iterable
2320
from typing import Self
2421

22+
import numpy.typing as npt
23+
2524
from zarr.core.common import BytesLike, ChunkCoords
2625

2726
try:
@@ -106,13 +105,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
106105
return cast("npt.NDArray[Any]", cp.asnumpy(self._data))
107106

108107
def __add__(self, other: core.Buffer) -> Self:
109-
other_array = other.as_array_like()
110-
assert other_array.dtype == np.dtype("B")
111-
gpu_other = Buffer(other_array)
112-
gpu_other_array = gpu_other.as_array_like()
113-
return self.__class__(
114-
cp.concatenate((cp.asanyarray(self._data), cp.asanyarray(gpu_other_array)))
115-
)
108+
other_array = cp.asanyarray(other.as_array_like())
109+
left = self._data
110+
if left.dtype != other_array.dtype:
111+
other_array = other_array.view(left.dtype)
112+
113+
buffer = cp.concatenate([left, other_array])
114+
return type(self)(buffer)
116115

117116

118117
class NDBuffer(core.NDBuffer):

0 commit comments

Comments
 (0)