Skip to content

Commit a59d3d3

Browse files
kaushikcfdinducer
authored andcommitted
implements expand_dims
expand_dims has the potential to hold more metadata than a plain reshape and has cleaner axis-type propagation semantics
1 parent 02492f5 commit a59d3d3

5 files changed

Lines changed: 113 additions & 2 deletions

File tree

pytato/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
einsum,
4040

4141
matmul, roll, transpose, stack, reshape, concatenate,
42+
expand_dims,
4243

4344
maximum, minimum, where,
4445

@@ -95,7 +96,8 @@
9596
"make_dict_of_named_arrays", "make_placeholder", "make_size_param",
9697
"make_data_wrapper", "einsum",
9798

98-
"matmul", "roll", "transpose", "stack", "reshape", "concatenate",
99+
"matmul", "roll", "transpose", "stack", "reshape", "expand_dims",
100+
"concatenate",
99101

100102
"generate_loopy", "generate_jax",
101103

pytato/array.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,4 +2464,44 @@ def squeeze(array: Array) -> Array:
24642464
for i, s_i in enumerate(array.shape))]
24652465

24662466

2467+
def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array:
2468+
"""
2469+
Reshapes *array* by adding 1-long axes at *axis* dimensions of the returned
2470+
array.
2471+
"""
2472+
from pytato.tags import ExpandedDimsReshape
2473+
2474+
if isinstance(axis, int):
2475+
axis = axis,
2476+
2477+
output_ndim = array.ndim + len(axis)
2478+
2479+
normalized_axis: List[int] = []
2480+
2481+
# {{{ sanity checks
2482+
2483+
for ax in axis:
2484+
if not (-output_ndim <= ax < output_ndim):
2485+
raise ValueError(f"Dimension {ax} not present in {output_ndim}-D array.")
2486+
2487+
normalized_axis.append(ax if ax >= 0 else (ax+output_ndim))
2488+
2489+
if len(set(normalized_axis)) != len(normalized_axis):
2490+
raise ValueError(f"repeated axis in '{axis}'.")
2491+
2492+
# }}}
2493+
2494+
new_shape = list(array.shape)
2495+
2496+
for ax in sorted(normalized_axis):
2497+
assert (0 <= ax < output_ndim)
2498+
new_shape.insert(ax, 1)
2499+
2500+
assert len(new_shape) == output_ndim
2501+
2502+
return Reshape(array, tuple(new_shape), "C",
2503+
tags=(_get_default_tags()
2504+
| {ExpandedDimsReshape(tuple(normalized_axis))}),
2505+
axes=_get_default_axes(len(new_shape)))
2506+
24672507
# vim: foldmethod=marker

pytato/tags.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
.. autoclass:: AssumeNonNegative
1212
"""
1313

14-
14+
from typing import Tuple
1515
from pytools.tag import Tag, UniqueTag, tag_dataclass
16+
from dataclasses import dataclass
1617

1718

1819
# {{{ pre-defined tag: ImplementationStrategy
@@ -101,3 +102,25 @@ class AssumeNonNegative(Tag):
101102
:class:`~pytato.target.Target` that all entries of the tagged array are
102103
non-negative.
103104
"""
105+
106+
107+
@dataclass(eq=True, frozen=True, repr=True)
108+
class ExpandedDimsReshape(UniqueTag):
109+
"""
110+
A tag that can be attached to a :class:`~pytato.array.Reshape` to indicate
111+
that the new dimensions created by :func:`pytato.expand_dims`.
112+
113+
:attr new_dims: A :class:`tuple` of the dimensions of the reshaped array
114+
that were added.
115+
116+
.. testsetup::
117+
118+
>>> import pytato as pt
119+
120+
.. doctest::
121+
122+
>>> x = pt.make_placeholder("x", (10, 4), "float64")
123+
>>> pt.expand_dims(x, (0, 2, 4)).tags_of_type(pt.tags.ExpandedDimsReshape)
124+
frozenset({ExpandedDimsReshape(new_dims=(0, 2, 4))})
125+
"""
126+
new_dims: Tuple[int, ...]

test/test_codegen.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,35 @@ def test_no_computation_for_empty_arrays(ctx_factory):
16921692
assert not bprg.program.default_entrypoint.instructions
16931693

16941694

1695+
def test_expand_dims(ctx_factory):
1696+
from numpy.random import default_rng
1697+
1698+
ntests = 50
1699+
rng = default_rng(seed=0)
1700+
ctx = ctx_factory()
1701+
cq = cl.CommandQueue(ctx)
1702+
1703+
for _ in range(ntests):
1704+
in_dim = rng.integers(2, 7)
1705+
n_new_axes = rng.integers(0, 7)
1706+
in_shape = rng.integers(2, 7, in_dim)
1707+
np_input = rng.random(in_shape, "float32")
1708+
pt_input = pt.make_data_wrapper(np_input)
1709+
axis = tuple(int(rng.integers(-(n_new_axes+in_dim), n_new_axes+in_dim))
1710+
for _ in range(n_new_axes))
1711+
1712+
try:
1713+
np_output = np.expand_dims(np_input, axis=axis)
1714+
except ValueError:
1715+
with pytest.raises(ValueError):
1716+
pt.expand_dims(pt_input, axis=axis)
1717+
else:
1718+
_, (pt_output, ) = pt.generate_loopy(pt.expand_dims(pt_input,
1719+
axis=axis))(cq)
1720+
1721+
np.testing.assert_allclose(np_output, pt_output)
1722+
1723+
16951724
if __name__ == "__main__":
16961725
if len(sys.argv) > 1:
16971726
exec(sys.argv[1])

test/test_pytato.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,23 @@ def test_adv_indexing_into_zero_long_axes():
884884
# }}}
885885

886886

887+
def test_expand_dims_input_validate():
888+
a = pt.make_placeholder("x", (10, 4), dtype="float64")
889+
890+
assert pt.expand_dims(a, (0, 2, 4)).shape == (1, 10, 1, 4, 1)
891+
assert pt.expand_dims(a, (-5, -3, -1)).shape == (1, 10, 1, 4, 1)
892+
assert pt.expand_dims(a, (-3)).shape == (1, 10, 4)
893+
894+
with pytest.raises(ValueError):
895+
pt.expand_dims(a, (3, 3))
896+
897+
with pytest.raises(ValueError):
898+
pt.expand_dims(a, (0, 2, 5))
899+
900+
with pytest.raises(ValueError):
901+
pt.expand_dims(a, -4)
902+
903+
887904
if __name__ == "__main__":
888905
if len(sys.argv) > 1:
889906
exec(sys.argv[1])

0 commit comments

Comments
 (0)