Skip to content

Commit 94c1706

Browse files
committed
torch.meshgrid: make it return tuple, not list
1 parent ddf9953 commit 94c1706

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -897,10 +897,11 @@ def sign(x: Array, /) -> Array:
897897
return out
898898

899899

900-
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
901-
# enforce the default of 'xy'
902-
# TODO: is the return type a list or a tuple
903-
return list(torch.meshgrid(*arrays, indexing=indexing))
900+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
901+
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
902+
# will be required to pass the indexing argument."
903+
# Thus always pass it explicitly.
904+
return torch.meshgrid(*arrays, indexing=indexing)
904905

905906

906907
__all__ = ['asarray', 'result_type', 'can_cast',

0 commit comments

Comments
 (0)