Skip to content

Commit 75ee8bc

Browse files
lucascolleyev-br
andauthored
ENH: array_namespace: support torch.compile (#413)
* ENH: array_namespace: support `torch.compile` without graph breaks torch.dynamo does not know that module objects are hashable, and graph breaks on compiling a set of module variables. Use lists instead to avoid a graph break. --------- Co-authored-by: Evgeni Burovski <evgeny.burovskiy@gmail.com>
1 parent 9f3a525 commit 75ee8bc

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

array_api_compat/common/_helpers.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def your_function(x, y):
641641
is_pydata_sparse_array
642642
643643
"""
644-
namespaces: set[Namespace] = set()
644+
namespaces: list[Namespace] = []
645645
for x in xs:
646646
xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
647647
if info is _ClsToXPInfo.SCALAR:
@@ -663,7 +663,19 @@ def your_function(x, y):
663663
)
664664
xp = get_ns(api_version=api_version)
665665

666-
namespaces.add(xp)
666+
namespaces.append(xp)
667+
668+
# Use a list of modules to avoid a graph break under torch.compile:
669+
# torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable
670+
# Explanation: Dynamo does not know whether the underlying python object for
671+
# PythonModuleVariable(<module 'array_api_compat.torch' from ...) is hashable
672+
names = set(x.__name__ for x in namespaces)
673+
unique_namespaces = []
674+
for ns in namespaces:
675+
if (name := ns.__name__) in names:
676+
unique_namespaces.append(ns)
677+
names.remove(name)
678+
namespaces = unique_namespaces
667679

668680
try:
669681
(xp,) = namespaces

tests/test_no_dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Array:
1717
# Dummy array namespace that doesn't depend on any array library
1818
def __array_namespace__(self, api_version=None):
1919
class Namespace:
20-
pass
20+
__name__: str = "foobar"
2121
return Namespace()
2222

2323
def _test_dependency(mod):

tests/test_torch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,18 @@ def test_round():
161161
r = xp.round(x, decimals=1, out=o)
162162
assert xp.all(r == o)
163163
assert r is o
164+
165+
166+
def test_dynamo_array_namespace():
167+
"""Check that torch.compiling array_namespace does not incur graph breaks."""
168+
from array_api_compat import array_namespace
169+
170+
def foo(x):
171+
xp = array_namespace(x)
172+
return xp.multiply(x, x)
173+
174+
bar = torch.compile(fullgraph=True)(foo)
175+
176+
x = torch.arange(3)
177+
y = bar(x)
178+
assert xp.all(y == x**2)

0 commit comments

Comments
 (0)