Skip to content

Commit 4138c05

Browse files
committed
BUG: torch.arange: workaround for missing dtype implementations
1 parent 288942c commit 4138c05

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,14 @@ def arange(start: float,
616616
dtype = torch.int64
617617
else:
618618
dtype = torch.float32
619-
return torch.empty(0, dtype=dtype, device=device, **kwargs)
620-
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
619+
try:
620+
return torch.empty(0, dtype=dtype, device=device, **kwargs)
621+
except NotImplementedError:
622+
return torch.empty(0, device=device, **kwargs).to(dtype)
623+
try:
624+
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
625+
except NotImplementedError:
626+
return torch.arange(start, stop, step, device=device, **kwargs).to(dtype)
621627

622628
# torch.eye does not accept None as a default for the second argument and
623629
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)

0 commit comments

Comments
 (0)