Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/blosc2/blosc2_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2356,9 +2356,10 @@ cdef class slice_flatter:
cdef class NDArray:
cdef b2nd_array_t* array

def __init__(self, array):
def __init__(self, array, base=None):
self._dtype = None
self.array = <b2nd_array_t *> PyCapsule_GetPointer(array, <char *> "b2nd_array_t*")
self.base = base # add reference to base if NDArray is a view

@property
def shape(self) -> tuple[int]:
Expand Down Expand Up @@ -2996,5 +2997,7 @@ def expand_dims(arr1: NDArray, axis_mask: list[bool], final_dims: int) -> blosc2
mask_[i] = axis_mask[i]
_check_rc(b2nd_expand_dims(arr1.array, &view, mask_, final_dims),"Error while expanding the arrays")

# create view with reference to arr1 to hold onto
new_base = arr1 if arr1.base is None else arr1.base
return blosc2.NDArray(_schunk=PyCapsule_New(view.sc, <char *> "blosc2_schunk*", NULL),
_array=PyCapsule_New(view, <char *> "b2nd_array_t*", NULL))
_array=PyCapsule_New(view, <char *> "b2nd_array_t*", NULL), _base=new_base)
3 changes: 2 additions & 1 deletion src/blosc2/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,8 @@ def __init__(self, **kwargs):
self._keep_last_read = False
# Where to store the last read data
self._last_read = {}
super().__init__(kwargs["_array"])
base = kwargs.pop("_base", None)
super().__init__(kwargs["_array"], base=base)
# Accessor to fields
self._fields = {}
if self.dtype.fields:
Expand Down
32 changes: 26 additions & 6 deletions tests/ndarray/test_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,36 @@ def test_resize(shape, new_shape, chunks, blocks, fill_value):
)
def test_expand_dims(shape, axis, chunks, blocks, fill_value):
a = blosc2.full(shape, fill_value=fill_value, chunks=chunks, blocks=blocks)

npa = a[:]
b = blosc2.expand_dims(a, axis=axis)
npa = np.expand_dims(a[:], axis)
assert npa.shape == b.shape
np.testing.assert_array_equal(npa, b[:])
npb = np.expand_dims(npa, axis)
assert npb.shape == b.shape
np.testing.assert_array_equal(npb, b[:])

# Repeated expansion
axis = (axis,) if isinstance(axis, int) else axis
axis = axis[0] if (len(axis) + b.ndim) > blosc2.MAX_DIM else axis
b = blosc2.expand_dims(b, axis=axis)
npb = np.expand_dims(npb, axis)
assert npb.shape == b.shape
np.testing.assert_array_equal(npb, b[:])

# Check that handling of views is correct
a = blosc2.expand_dims(a, axis=axis) # could lose ref to original array and thus dealloc data
npa = np.expand_dims(npa, axis)
assert npa.shape == b.shape
np.testing.assert_array_equal(npa, b[:])
assert a[()].shape == npa[()].shape # getitem fails if deallocate has happened

# Now check that garbage collecting works and there will be no memory leaks for views
import sys

arr = np.arange(4)
bloscarr_ = blosc2.asarray(arr)
assert sys.getrefcount(arr) == sys.getrefcount(bloscarr_) == 2

view = np.expand_dims(arr, 0)
bloscview = blosc2.expand_dims(bloscarr_, 0)
assert sys.getrefcount(arr) == sys.getrefcount(bloscarr_) == 3

del view
del bloscview
assert sys.getrefcount(arr) == sys.getrefcount(bloscarr_) == 2
Loading