Skip to content

Commit 1f34958

Browse files
committed
WIP: add axis tuple support to torch.expand_dims
1 parent d4dae32 commit 1f34958

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,23 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
690690
return torch.triu(x, k)
691691

692692
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
693-
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
694-
return torch.unsqueeze(x, axis)
693+
def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array:
694+
if isinstance(axis, int):
695+
return torch.unsqueeze(x, axis)
696+
else:
697+
y_ndim = x.ndim + len(axis)
698+
699+
# normalize
700+
n_axis = tuple(ax + y_ndim if ax < 0 else ax for ax in axis)
701+
if (len(n_axis) != len(set(n_axis)) or
702+
_builtin_any(ax < 0 or ax >= y_ndim for ax in n_axis)
703+
):
704+
raise ValueError()
705+
706+
shape_it = iter(x.shape)
707+
shape = [1 if ax in n_axis else next(shape_it) for ax in range(y_ndim)]
695708

709+
return torch.reshape(x, shape)
696710

697711
def astype(
698712
x: Array,

0 commit comments

Comments
 (0)