-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathtest_torch.py
More file actions
178 lines (130 loc) · 5.46 KB
/
test_torch.py
File metadata and controls
178 lines (130 loc) · 5.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools
import pytest
try:
import torch
except ImportError:
pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found")
from array_api_compat import torch as xp
class TestResultType:
    """Tests for xp.result_type behavior left unspecified by the Array API spec."""

    def test_empty(self):
        # result_type() with no arguments is an error.
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        # A lone Python scalar (or arbitrary object) has no defined result type.
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        # A single dtype is returned unchanged.
        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        # A single array resolves to its own dtype.
        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        # Only include here things "unspecified" in the spec
        # scalar, tensor or tensor,tensor
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # dtype, scalar
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype
        for x, y in [
            (xp.bool, xp.complex64)
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        torch.set_default_dtype(torch.float32)

        args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
        assert xp.result_type(*args) == torch.float16

        args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
        assert xp.result_type(*args) == xp.complex64

        args = [1, 2, 3j, xp.float64, 4, 5, 6]
        assert xp.result_type(*args) == xp.complex128

        args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
        assert xp.result_type(*args) == xp.complex128

        # The promotion result must not depend on the argument order.
        i64 = xp.ones(1, dtype=xp.int64)
        f16 = xp.ones(1, dtype=xp.float16)
        for i in itertools.permutations([i64, f16, 1.0, 1.0]):
            assert xp.result_type(*i) == xp.float16, f"{i}"

        # Python scalars only: no array or dtype to anchor promotion.
        with pytest.raises(ValueError):
            xp.result_type(1, 2, 3, 4)

    @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
    @pytest.mark.parametrize("dtype_a",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    @pytest.mark.parametrize("dtype_b",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    def test_gh_273(self, default_dt, dtype_a, dtype_b):
        # Regression test for https://github.com/data-apis/array-api-compat/issues/273
        # Save the previous default *before* entering the try block, so the
        # finally clause cannot hit a NameError if anything in the body fails
        # early (the original code assigned prev_default inside the try).
        prev_default = torch.get_default_dtype()
        try:
            default_dtype = getattr(torch, default_dt)
            torch.set_default_dtype(default_dtype)

            a = xp.asarray([2, 1], dtype=dtype_a)
            b = xp.asarray([1, -1], dtype=dtype_b)
            dtype_1 = xp.result_type(a, b, 1.0)
            dtype_2 = xp.result_type(b, a, 1.0)
            assert dtype_1 == dtype_2
        finally:
            # Always restore the process-wide default dtype.
            torch.set_default_dtype(prev_default)
def test_clip_vmap():
    """Clipping through the compat wrapper must survive torch.vmap."""
    # https://github.com/data-apis/array-api-compat/issues/350
    def _clip(t):
        return xp.clip(t, min=0, max=30)

    data = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
    expected = _clip(data)
    vmapped_clip = torch.vmap(_clip)
    assert xp.all(vmapped_clip(data) == expected)
def test_meshgrid():
    """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
    x = xp.asarray([1, 2])
    y = xp.asarray([4])

    # Default must match torch.meshgrid(x, y, indexing='xy');
    # torch's own default, indexing='ij', would give different outputs.
    expected_xy = (xp.asarray([[1, 2]]), xp.asarray([[4, 4]]))
    for got, want in zip(xp.meshgrid(x, y), expected_xy):
        assert got.shape == want.shape
        assert xp.all(got == want)

    # An explicit indexing='ij' is passed through to torch unchanged.
    expected_ij = (xp.asarray([[1], [2]]), xp.asarray([[4], [4]]))
    for got, want in zip(xp.meshgrid(x, y, indexing='ij'), expected_ij):
        assert got.shape == want.shape
        assert xp.all(got == want)
def test_argsort_stable():
    """Verify that argsort defaults to a stable sort."""
    # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper
    # enforces the stable=True default.
    # cf https://github.com/data-apis/array-api-compat/pull/356 and
    # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329
    n = 50  # should be > 16
    all_ties = xp.zeros(n)
    # With every element equal, a stable sort must return the identity order.
    assert xp.all(xp.argsort(all_ties) == xp.arange(n))
def test_round():
    """Verify the out= argument of xp.round with complex inputs."""
    values = torch.as_tensor([1.23456786] * 3) + 3.456789j
    out = torch.empty(3, dtype=torch.complex64)
    result = xp.round(values, decimals=1, out=out)
    # The rounded values must be written into `out`, and the very same
    # tensor object must be handed back.
    assert xp.all(result == out)
    assert result is out
def test_dynamo_array_namespace():
    """Check that torch.compiling array_namespace does not incur graph breaks."""
    from array_api_compat import array_namespace

    def square(t):
        ns = array_namespace(t)
        return ns.multiply(t, t)

    # fullgraph=True makes any graph break a hard error.
    compiled_square = torch.compile(fullgraph=True)(square)
    arg = torch.arange(3)
    assert xp.all(compiled_square(arg) == arg**2)