Skip to content

Commit 530e856

Browse files
authored
TST: add a test of torch.compile-ing array_namespace (#414)
1 parent 21bc5fa commit 530e856

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/test_torch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,18 @@ def test_round():
161161
r = xp.round(x, decimals=1, out=o)
162162
assert xp.all(r == o)
163163
assert r is o
164+
165+
166+
def test_dynamo_array_namespace():
167+
"""Check that torch.compiling array_namespace does not incur graph breaks."""
168+
from array_api_compat import array_namespace
169+
170+
def foo(x):
171+
xp = array_namespace(x)
172+
return xp.multiply(x, x)
173+
174+
bar = torch.compile(fullgraph=True)(foo)
175+
176+
x = torch.arange(3)
177+
y = bar(x)
178+
assert xp.all(y == x**2)

0 commit comments

Comments
 (0)