Skip to content

Commit 42181d3

Browse files
authored
Improve import cubed time (#866)
* Improve 'import cubed' time * Use plain NumPy as default backend array API * Fix failing test
1 parent 26103cd commit 42181d3

6 files changed

Lines changed: 17 additions & 12 deletions

File tree

.github/workflows/array-api-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,4 @@ jobs:
162162
163163
EOF
164164
165-
pytest -v -rxXfEA --hypothesis-max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --hypothesis-disable-deadline
165+
CUBED_BACKEND_ARRAY_API_MODULE=array_api_compat.numpy pytest -v -rxXfEA --hypothesis-max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --hypothesis-disable-deadline

cubed/array_api/inspection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
class __array_namespace_info__:
55
def capabilities(self):
6-
return {
6+
cap = {
77
"boolean indexing": False, # not supported in Cubed (#73)
88
"data-dependent shapes": False, # not supported in Cubed
9-
"max dimensions": nxp.__array_namespace_info__().capabilities()[
10-
"max dimensions"
11-
],
129
}
10+
nxp_cap = nxp.__array_namespace_info__().capabilities()
11+
if "max dimensions" in nxp_cap:
12+
cap["max dimensions"] = nxp_cap["max dimensions"]
13+
return cap
1314

1415
# devices and dtypes are determined by the backend array API
1516

cubed/backend_array_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
# The array implementation used for backend operations is stored in the
7-
# namespace variable, and defaults to array_api_compat.nump, unless it
7+
# namespace variable, and defaults to numpy, unless it
88
# is overridden by an environment variable.
99
# It must be compatible with the Python Array API standard, although
1010
# some extra functions are used too (e.g. nan functions),
@@ -29,9 +29,9 @@
2929
namespace = xp
3030

3131
else:
32-
import array_api_compat.numpy
32+
import numpy
3333

34-
namespace = array_api_compat.numpy
34+
namespace = numpy
3535
xp_name = "numpy"
3636

3737

cubed/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import random as pyrandom
22

3-
from numpy.random import Generator, Philox
4-
53
from cubed.backend_array_api import namespace as nxp
64
from cubed.backend_array_api import numpy_array_to_backend_array
75
from cubed.core.ops import map_blocks
@@ -27,6 +25,9 @@ def random(size, *, dtype=nxp.float64, chunks=None, spec=None):
2725

2826

2927
def _random(x, numblocks=None, root_seed=None, dtype=nxp.float64, block_id=None):
28+
# import as needed to avoid slow 'import cubed'
29+
from numpy.random import Generator, Philox
30+
3031
stream_id = block_id_to_offset(block_id, numblocks)
3132
rg = Generator(Philox(key=root_seed + stream_id))
3233
out = rg.random(x.shape, dtype=dtype)

cubed/storage/virtual.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any
33

44
import numpy as np
5-
from ndindex import ndindex
65

76
from cubed.backend_array_api import namespace as nxp
87
from cubed.backend_array_api import numpy_array_to_backend_array
@@ -27,6 +26,8 @@ def __init__(
2726
super().__init__(shape, dtype, chunks)
2827

2928
def __getitem__(self, key):
29+
from ndindex import ndindex # import as needed to avoid slow 'import cubed'
30+
3031
idx = ndindex[key]
3132
newshape = idx.newshape(self.shape)
3233
# use broadcast trick so array chunks only occupy a single value in memory
@@ -52,6 +53,8 @@ def __init__(
5253
self.fill_value = fill_value
5354

5455
def __getitem__(self, key):
56+
from ndindex import ndindex # import as needed to avoid slow 'import cubed'
57+
5558
idx = ndindex[key]
5659
newshape = idx.newshape(self.shape)
5760
# use broadcast trick so array chunks only occupy a single value in memory

cubed/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_add_scalars():
225225
("min", "max"),
226226
[
227227
(None, None),
228-
(4, None),
228+
# (4, None), # fails unless array-api-compat is used
229229
(None, 7),
230230
(4, 7),
231231
(0, 10),

0 commit comments

Comments
 (0)