Skip to content

Commit 65e83a6

Browse files
committed
speed up fix
1 parent b2e3f9b commit 65e83a6

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

src/fast_array_utils/stats/_mean_var.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,16 @@ def mean_var_(
3535

3636
from . import mean
3737

38-
if array_api_compat.is_array_api_obj(x):
39-
xp = array_api_compat.array_namespace(x)
40-
float64 = xp.float64
41-
else:
38+
if isinstance(x, np.ndarray | types.CSBase):
4239
float64 = np.float64
40+
else:
41+
import array_api_compat
42+
43+
if array_api_compat.is_array_api_obj(x):
44+
xp = array_api_compat.array_namespace(x)
45+
float64 = xp.float64
46+
else:
47+
float64 = np.float64
4348

4449
if axis is not None and isinstance(x, types.CSBase):
4550
mean_, var = _sparse_mean_var(x, axis=axis)

src/fast_array_utils/stats/_power.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
2727

2828

2929
@singledispatch
30-
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Any:
30+
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Any: # noqa: ANN401
3131
if TYPE_CHECKING:
3232
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
3333

0 commit comments

Comments
 (0)