Skip to content

Commit 3e5f3f5

Browse files
authored
Bug fix for region slices that have start or stop = None (#857)
1 parent 3e8994a commit 3e5f3f5

2 files changed

Lines changed: 30 additions & 3 deletions

File tree

cubed/core/ops.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,16 @@ def _store_array(
212212
shape = target.shape
213213
chunks = target.chunks
214214
for i, (sl, cs) in enumerate(zip(region, chunks)):
215-
if sl.start % cs != 0 or (sl.stop % cs != 0 and sl.stop != shape[i]):
215+
if (sl.start is not None and sl.start % cs != 0) or (
216+
sl.stop is not None and sl.stop % cs != 0 and sl.stop != shape[i]
217+
):
216218
raise ValueError(
217219
f"Region {region} does not align with target chunks {chunks}"
218220
)
219-
block_offsets = [sl.start // cs for sl, cs in zip(region, chunks)]
221+
block_offsets = [
222+
(0 if sl.start is None else sl.start // cs)
223+
for sl, cs in zip(region, chunks)
224+
]
220225

221226
def key_function(out_key):
222227
out_coords = out_key[1:]

cubed/tests/test_core.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,13 @@ def test_to_zarr_array(tmp_path, spec, executor):
190190

191191
def test_to_zarr_region(tmp_path, spec, executor):
192192
a = xp.asarray(
193-
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
193+
[
194+
[1, 2, 3, 4],
195+
[5, 6, 7, 8],
196+
[9, 10, 11, 12],
197+
[13, 14, 15, 16],
198+
[17, 18, 19, 20],
199+
],
194200
chunks=(2, 2),
195201
spec=spec,
196202
)
@@ -272,6 +278,22 @@ def test_to_zarr_region(tmp_path, spec, executor):
272278
),
273279
)
274280

281+
region = (slice(None), slice(4, 5))
282+
cubed.to_zarr(a[:, 0:1], z, region=region, executor=executor)
283+
res = open_storage_array(store, mode="r")
284+
assert_array_equal(
285+
res[:],
286+
np.array(
287+
[
288+
[1, 2, 0, 0, 1],
289+
[5, 6, 0, 0, 5],
290+
[1, 2, 1, 2, 9],
291+
[5, 6, 5, 6, 13],
292+
[0, 0, 9, 10, 17],
293+
]
294+
),
295+
)
296+
275297

276298
def test_to_zarr_region_fails(tmp_path):
277299
a = xp.ones((4, 4), chunks=(2, 2))

0 commit comments

Comments
 (0)