Skip to content

Commit 7719e48

Browse files
committed
BUG: add torch.repeat
1 parent 3e5fdc0 commit 7719e48

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
177177
# torch.conj sets the conjugation bit, which breaks conversion to other
178178
# libraries. See https://github.com/data-apis/array-api-compat/issues/173
179179
conj = torch.conj_physical
180-
181180
# Two-arg elementwise functions
182181
# These require a wrapper to do the correct type promotion on 0-D tensors
183182
add = _two_arg(torch.add)
@@ -574,6 +573,11 @@ def count_nonzero(
574573
return result
575574

576575

576+
# "repeat" is torch.repeat_interleave; also the dim argument
577+
def repeat(x: Array, repeats: int | array, /, *, axis: int | None = None) -> Array:
578+
return torch.repeat_interleave(x, repeats, axis)
579+
580+
577581
def where(
578582
condition: Array,
579583
x1: Array | bool | int | float | complex,
@@ -854,6 +858,6 @@ def sign(x: Array, /) -> Array:
854858
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
855859
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
856860
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
857-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo']
861+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
858862

859863
_all_ignore = ['torch', 'get_xp']

torch-xfails.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype
120120
array_api_tests/test_data_type_functions.py::test_iinfo_dtype
121121

122122
# 2023.12 support
123-
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
123+
124+
# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers
124125
array_api_tests/test_manipulation_functions.py::test_repeat
125-
array_api_tests/test_signatures.py::test_func_signature[repeat]
126126
# Argument 'device' missing from signature
127127
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
128128
# Argument 'max_version' missing from signature

0 commit comments

Comments
 (0)