@@ -177,7 +177,6 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
177177# torch.conj sets the conjugation bit, which breaks conversion to other
178178# libraries. See https://github.com/data-apis/array-api-compat/issues/173
179179conj = torch .conj_physical
180-
181180# Two-arg elementwise functions
182181# These require a wrapper to do the correct type promotion on 0-D tensors
183182add = _two_arg (torch .add )
@@ -574,6 +573,11 @@ def count_nonzero(
574573 return result
575574
576575
576+ # "repeat" is torch.repeat_interleave; also the dim argument
577+ def repeat (x : Array , repeats : int | array , / , * , axis : int | None = None ) -> Array :
578+ return torch .repeat_interleave (x , repeats , axis )
579+
580+
577581def where (
578582 condition : Array ,
579583 x1 : Array | bool | int | float | complex ,
@@ -854,6 +858,6 @@ def sign(x: Array, /) -> Array:
854858 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
855859 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
856860 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
857- 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' ]
861+ 'take' , 'take_along_axis' , 'sign' , 'finfo' , 'iinfo' , 'repeat' ]
858862
859863_all_ignore = ['torch' , 'get_xp' ]
0 commit comments