Skip to content

Commit 15310df

Browse files
committed
Adding meshgrid
1 parent 7d54fe4 commit 15310df

3 files changed

Lines changed: 98 additions & 4 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,12 +898,16 @@ def validate_inputs(inputs: dict, out=None, reduce=False) -> tuple: # noqa: C90
898898
fast_path = False
899899
if first_input.blocks != out.blocks:
900900
fast_path = False
901+
if 0 in out.chunks: # fast_eval has zero division error for 0 shapes
902+
fast_path = False
901903
# Then, the rest of the operands
902904
for input_ in NDinputs:
903905
if first_input.chunks != input_.chunks:
904906
fast_path = False
905907
if first_input.blocks != input_.blocks:
906908
fast_path = False
909+
if 0 in input_.chunks: # fast_eval has zero division error for 0 shapes
910+
fast_path = False
907911

908912
return first_input.shape, first_input.chunks, first_input.blocks, fast_path
909913

@@ -2269,7 +2273,7 @@ def get_chunk(self, nchunk):
22692273
return out.schunk.get_chunk(nchunk)
22702274

22712275
def update_expr(self, new_op): # noqa: C901
2272-
prev_flag = getattr(blosc2, "_disable_overloaded_equal", False)
2276+
prev_flag = blosc2._disable_overloaded_equal
22732277
# We use a lot of the original NDArray.__eq__ as 'is', so deactivate the overloaded one
22742278
blosc2._disable_overloaded_equal = True
22752279
# One of the two operands are LazyExpr instances

src/blosc2/ndarray.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _check_allowed_dtypes(
385385
):
386386
raise RuntimeError(
387387
"Expected LazyExpr, NDArray, NDField, C2Array, Proxy, np.ndarray or scalar instances"
388-
f" and you provided a '{type(value)}' instance"
388+
+ f" and you provided a '{type(value)}' instance"
389389
)
390390

391391

@@ -5195,7 +5195,7 @@ def broadcast_to(arr, shape):
51955195
return (arr + blosc2.zeros(shape, dtype=arr.dtype)).compute() # return lazyexpr quickly
51965196

51975197

5198-
def meshgrid(arrays: NDArray, indexing: str = "xy") -> Sequence[NDArray]:
5198+
def meshgrid(*arrays: NDArray, indexing: str = "xy") -> Sequence[NDArray]:
51995199
"""
52005200
Returns coordinate matrices from coordinate vectors.
52015201
@@ -5217,4 +5217,32 @@ def meshgrid(arrays: NDArray, indexing: str = "xy") -> Sequence[NDArray]:
52175217
* if matrix indexing ij, then each returned array has shape (N1, N2, N3, ..., Nn).
52185218
* if Cartesian indexing xy, then each returned array has shape (N2, N1, N3, ..., Nn).
52195219
"""
5220-
raise NotImplementedError("Working on meshgrid")
5220+
out = ()
5221+
shape = np.ones(len(arrays))
5222+
first_arr = arrays[0]
5223+
myarrs = ()
5224+
if indexing == "xy" and len(shape) > 1:
5225+
# switch 0th and 1st shapes around
5226+
def mygen(i):
5227+
if i not in (0, 1):
5228+
return (j for j in range(len(arrays)) if j != i)
5229+
else:
5230+
return (j for j in range(len(arrays)) if j != builtins.abs(i - 1))
5231+
else:
5232+
mygen = lambda i: (j for j in range(len(arrays)) if j != i) # noqa : E731
5233+
5234+
for i, a in enumerate(arrays):
5235+
if len(a.shape) != 1 or a.dtype != first_arr.dtype:
5236+
raise ValueError("All arrays must be 1D and of same dtype.")
5237+
shape[i] = a.shape[0]
5238+
myarrs += (blosc2.expand_dims(a, tuple(mygen(i))),) # cheap, creates a view
5239+
5240+
# handle Cartesian indexing
5241+
shape = tuple(shape)
5242+
if indexing == "xy" and len(shape) > 1:
5243+
shape = (shape[1], shape[0]) + shape[2:]
5244+
5245+
# do broadcast
5246+
for a in myarrs:
5247+
out += (broadcast_to(a, shape),)
5248+
return out

tests/ndarray/test_ndarray.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,65 @@ def fancy_strided_output(inputs, output_indices, stride=1):
409409
output_indices = [800, 74, 671, 132, 818]
410410
out = fancy_strided_output(arr, output_indices, stride=16)
411411
assert out.shape == (2, 12, 5, 10, 8, 3)
412+
413+
414+
dtypes = [np.int32, np.float32, np.float64, np.uint8]
415+
416+
# Shapes for broadcast_to
417+
broadcast_shapes = [
418+
((10,), (50,), (4,), (3,)),
419+
((8, 6), (16, 12), (4, 3), (1, 3)),
420+
((2, 6), (2, 30), (3, 2), (1, 1)),
421+
((1, 1, 3), (2, 4, 3), (1, 1, 2), (1, 1, 1)),
422+
]
423+
424+
meshgrid_shapes = [
425+
((10, 20), (3,), (1,)),
426+
((8, 6), (4,), (3,)),
427+
((2, 30), (2,), (1,)),
428+
((20, 4, 3), (4,), (1,)),
429+
]
430+
431+
432+
@pytest.mark.parametrize("dtype", dtypes)
433+
@pytest.mark.parametrize(("src_shape", "dst_shape", "chunks", "blocks"), broadcast_shapes)
434+
def test_broadcast_to(dtype, src_shape, dst_shape, chunks, blocks):
435+
arr_np = np.arange(np.prod(src_shape), dtype=dtype).reshape(src_shape)
436+
arr_b2 = blosc2.asarray(arr_np, chunks=chunks, blocks=blocks)
437+
438+
try:
439+
np_broadcast = np.broadcast_to(arr_np, dst_shape)
440+
np_error = None
441+
except ValueError as e:
442+
np_broadcast = None
443+
np_error = e
444+
445+
if np_error is not None:
446+
with pytest.raises(type(np_error)):
447+
blosc2.broadcast_to(arr_b2, dst_shape)
448+
else:
449+
b2_broadcast = blosc2.broadcast_to(arr_b2, dst_shape)
450+
assert np.array_equal(b2_broadcast[:], np_broadcast)
451+
452+
453+
@pytest.mark.parametrize("dtype", dtypes)
454+
@pytest.mark.parametrize(("shapes", "chunks", "blocks"), meshgrid_shapes)
455+
@pytest.mark.parametrize("indexing", ["xy", "ij"])
456+
def test_meshgrid(dtype, shapes, chunks, blocks, indexing):
457+
arrays_np = [np.arange(np.prod(shape), dtype=dtype).reshape(shape) for shape in shapes]
458+
arrays_b2 = [blosc2.asarray(a, chunks=chunks, blocks=blocks) for a in arrays_np]
459+
try:
460+
np_grids = np.meshgrid(*arrays_np, indexing=indexing)
461+
np_error = None
462+
except ValueError as e:
463+
np_grids = None
464+
np_error = e
465+
466+
if np_error is not None:
467+
with pytest.raises(type(np_error)):
468+
blosc2.meshgrid(*arrays_b2, indexing=indexing)
469+
else:
470+
b2_grids = blosc2.meshgrid(*arrays_b2, indexing=indexing)
471+
assert len(b2_grids) == len(np_grids)
472+
for g_b2, g_np in zip(b2_grids, np_grids, strict=False):
473+
assert np.array_equal(g_b2[:], g_np)

0 commit comments

Comments
 (0)