Skip to content

Commit 3933121

Browse files
committed
torch.meshgrid: make it return tuple, not list
1 parent c746621 commit 3933121

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,10 +881,11 @@ def sign(x: Array, /) -> Array:
881881
return out
882882

883883

884-
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
885-
# enforce the default of 'xy'
886-
# TODO: is the return type a list or a tuple
887-
return list(torch.meshgrid(*arrays, indexing=indexing))
884+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
885+
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
886+
# will be required to pass the indexing argument."
887+
# Thus always pass it explicitly.
888+
return torch.meshgrid(*arrays, indexing=indexing)
888889

889890

890891
__all__ = ['asarray', 'result_type', 'can_cast',

0 commit comments

Comments
 (0)