forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_torch.py
More file actions
119 lines (89 loc) · 3.79 KB
/
test_torch.py
File metadata and controls
119 lines (89 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools
import pytest
try:
import torch
except ImportError:
pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found")
from array_api_compat import torch as xp
class TestResultType:
def test_empty(self):
with pytest.raises(ValueError):
xp.result_type()
def test_one_arg(self):
for x in [1, 1.0, 1j, '...', None]:
with pytest.raises((ValueError, AttributeError)):
xp.result_type(x)
for x in [xp.float32, xp.int64, torch.complex64]:
assert xp.result_type(x) == x
for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
assert xp.result_type(x) == x.dtype
def test_two_args(self):
# Only include here things "unspecified" in the spec
# scalar, tensor or tensor,tensor
for x, y in [
(1., 1j),
(1j, xp.arange(3)),
(True, xp.asarray(3.)),
(xp.ones(3) == 1, 1j*xp.ones(3)),
]:
assert xp.result_type(x, y) == torch.result_type(x, y)
# dtype, scalar
for x, y in [
(1j, xp.int64),
(True, xp.float64),
]:
assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
# dtype, dtype
for x, y in [
(xp.bool, xp.complex64)
]:
xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
assert xp.result_type(x, y) == torch.result_type(xt, yt)
def test_multi_arg(self):
torch.set_default_dtype(torch.float32)
args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
assert xp.result_type(*args) == torch.float16
args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
assert xp.result_type(*args) == xp.complex64
args = [1, 2, 3j, xp.float64, 4, 5, 6]
assert xp.result_type(*args) == xp.complex128
args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
assert xp.result_type(*args) == xp.complex128
i64 = xp.ones(1, dtype=xp.int64)
f16 = xp.ones(1, dtype=xp.float16)
for i in itertools.permutations([i64, f16, 1.0, 1.0]):
assert xp.result_type(*i) == xp.float16, f"{i}"
with pytest.raises(ValueError):
xp.result_type(1, 2, 3, 4)
@pytest.mark.parametrize("default_dt", ['float32', 'float64'])
@pytest.mark.parametrize("dtype_a",
(xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
)
@pytest.mark.parametrize("dtype_b",
(xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
)
def test_gh_273(self, default_dt, dtype_a, dtype_b):
# Regression test for https://github.com/data-apis/array-api-compat/issues/273
try:
prev_default = torch.get_default_dtype()
default_dtype = getattr(torch, default_dt)
torch.set_default_dtype(default_dtype)
a = xp.asarray([2, 1], dtype=dtype_a)
b = xp.asarray([1, -1], dtype=dtype_b)
dtype_1 = xp.result_type(a, b, 1.0)
dtype_2 = xp.result_type(b, a, 1.0)
assert dtype_1 == dtype_2
finally:
torch.set_default_dtype(prev_default)
def test_meshgrid():
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
x, y = xp.asarray([1, 2]), xp.asarray([4])
X, Y = xp.meshgrid(x, y)
# output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
assert X.shape == X_xy.shape
assert xp.all(X == X_xy)
assert Y.shape == Y_xy.shape
assert xp.all(Y == Y_xy)