Skip to content

Commit e28f176

Browse files
committed
import fix
1 parent 86a7503 commit e28f176

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from functools import singledispatch
55
from typing import TYPE_CHECKING, cast, get_args
66

7-
import array_api_compat
87
import numpy as np
98

109
from .. import types
@@ -45,7 +44,7 @@ def generic_op(
4544

4645

4746
@generic_op.register(np.ndarray)
48-
# to avoid going array api path that would slow down the performance
47+
# register explicitly to avoid the array API path and performance slow down
4948
def _generic_op_numpy(
5049
x: np.ndarray,
5150
/,
@@ -73,6 +72,8 @@ def _generic_op_array_api(
7372
"""Handle arrays with native array API support."""
7473
del keep_cupy_as_array
7574

75+
import array_api_compat
76+
7677
xp = array_api_compat.array_namespace(x)
7778
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))
7879

src/fast_array_utils/stats/_power.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from functools import singledispatch
55
from typing import TYPE_CHECKING
66

7-
import array_api_compat
87
import numpy as np
98

109
from .. import types
@@ -40,6 +39,7 @@ def _power_numpy(x: np.ndarray, n: int, /, dtype: DTypeLike | None = None) -> np
4039

4140
@_power.register(types.HasArrayNamespace)
4241
def _power_array_api(x: types.HasArrayNamespace, n: int, /, dtype: DTypeLike | None = None) -> types.HasArrayNamespace:
42+
import array_api_compat
4343

4444
xp = array_api_compat.array_namespace(x)
4545
return xp.pow(x, n) if dtype is None else xp.pow(xp.astype(x, dtype), n)

0 commit comments

Comments
 (0)