Skip to content

Commit 474c969

Browse files
committed
fmt
1 parent f6463cc commit 474c969

5 files changed

Lines changed: 43 additions & 58 deletions

File tree

src/fast_array_utils/conv/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def to_dense(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K",
3636
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> types.CupyArray: ...
3737
@overload
3838
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...
39+
40+
3941
@overload
4042
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> A: ...
4143
@overload

src/fast_array_utils/conv/_to_dense.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K
4646
return np.asarray(x, order=order)
4747

4848

49-
@to_dense_.register(types.HasArrayNamespace)
50-
def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray:
51-
if to_cpu_memory:
52-
return np.asarray(x, order=order)
53-
return x
54-
55-
5649
@to_dense_.register(types.DaskArray)
5750
def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> NDArray[Any] | types.DaskArray:
5851
from . import to_dense
@@ -83,6 +76,13 @@ def _to_dense_cupy(x: GpuArray, /, *, order: Literal["K", "A", "C", "F"] = "K",
8376
return x.get(order="A") if to_cpu_memory else x
8477

8578

79+
@to_dense_.register(types.HasArrayNamespace)
80+
def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray:
81+
if to_cpu_memory:
82+
return np.asarray(x, order=order)
83+
return x
84+
85+
8686
def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.CSDataset, /, *, order: Literal["K", "A", "C", "F"]) -> Literal["C", "F"]:
8787
if TYPE_CHECKING:
8888
from scipy.sparse._base import _spbase

src/fast_array_utils/stats/_mean.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def mean_(
2323
*,
2424
axis: Literal[0, 1] | None = None,
2525
dtype: DTypeLike | None = None,
26-
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
26+
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
2727
total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type]
2828
n = np.prod(x.shape) if axis is None else x.shape[axis]
2929
return total / n # type: ignore[no-any-return]

src/fast_array_utils/stats/_mean_var.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,15 @@ def mean_var_(
3232
| tuple[np.float64, np.float64]
3333
| tuple[types.DaskArray, types.DaskArray]
3434
):
35-
3635
from . import mean
3736

38-
if isinstance(x, np.ndarray | types.CSBase):
37+
if isinstance(x, np.ndarray | types.CSBase) or not isinstance(x, types.HasArrayNamespace):
3938
xp = np
40-
elif isinstance(x, types.HasArrayNamespace):
39+
else:
4140
import array_api_compat
4241

4342
xp = array_api_compat.array_namespace(x)
44-
else:
45-
xp = np
43+
4644
if axis is not None and isinstance(x, types.CSBase):
4745
mean_, var = _sparse_mean_var(x, axis=axis)
4846
else:

tests/test_jax.py

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,39 +33,18 @@ def jax_arr() -> jax.Array:
3333

3434

3535
@pytest.mark.parametrize("axis", [None, 0, 1])
36-
def test_sum(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
36+
@pytest.mark.parametrize("func", ["sum", "min", "max", "mean"])
37+
def test_simple_stat(jax_arr: jax.Array, func: Literal["sum", "min", "max", "mean"], axis: Literal[0, 1] | None) -> None:
3738
import jax.numpy as jnp
3839

39-
result = stats.sum(jax_arr, axis=axis)
40-
expected = jnp.sum(jax_arr, axis=axis)
41-
assert jnp.array_equal(result, expected)
40+
result = getattr(stats, func)(jax_arr, axis=axis)
41+
expected = getattr(jnp, func)(jax_arr, axis=axis)
4242

43-
44-
@pytest.mark.parametrize("axis", [None, 0, 1])
45-
def test_min(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
46-
import jax.numpy as jnp
47-
48-
result = stats.min(jax_arr, axis=axis)
49-
expected = jnp.min(jax_arr, axis=axis)
50-
assert jnp.array_equal(result, expected)
51-
52-
53-
@pytest.mark.parametrize("axis", [None, 0, 1])
54-
def test_max(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
55-
import jax.numpy as jnp
56-
57-
result = stats.max(jax_arr, axis=axis)
58-
expected = jnp.max(jax_arr, axis=axis)
59-
assert jnp.array_equal(result, expected)
60-
61-
62-
@pytest.mark.parametrize("axis", [None, 0, 1])
63-
def test_mean(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
64-
import jax.numpy as jnp
65-
66-
result = stats.mean(jax_arr, axis=axis)
67-
expected = jnp.mean(jax_arr, axis=axis)
68-
assert jnp.allclose(result, expected)
43+
assert type(result) is type(expected)
44+
if func == "mean":
45+
assert jnp.allclose(result, expected)
46+
else:
47+
assert jnp.array_equal(result, expected)
6948

7049

7150
@pytest.mark.parametrize("axis", [None, 0, 1])
@@ -86,37 +65,43 @@ def test_is_constant(axis: Literal[0, 1] | None) -> None:
8665
result = stats.is_constant(x, axis=axis)
8766

8867
if axis is None:
89-
assert bool(result) is False
68+
assert not result
9069
elif axis == 0:
9170
expected = jnp.array([True, True, False, False])
71+
assert type(result) is type(expected)
9272
assert jnp.array_equal(result, expected)
9373
else:
9474
expected = jnp.array([False, False, True, True, False, True])
75+
assert type(result) is type(expected)
9576
assert jnp.array_equal(result, expected)
9677

9778

9879
@pytest.mark.parametrize("axis", [None, 0, 1])
99-
def test_mean_var(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
80+
def test_mean_var(subtests: pytest.Subtests, jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
10081
import jax.numpy as jnp
10182

10283
mean, var = stats.mean_var(jax_arr, axis=axis, correction=1)
10384

104-
mean_expected = jnp.mean(jax_arr, axis=axis)
105-
n = jax_arr.size if axis is None else jax_arr.shape[axis]
106-
var_expected = jnp.var(jax_arr, axis=axis) * n / (n - 1)
85+
for name, result in dict(mean=mean, var=var).items():
86+
if name == "mean":
87+
expected = jnp.mean(jax_arr, axis=axis)
88+
else:
89+
n = jax_arr.size if axis is None else jax_arr.shape[axis]
90+
expected = jnp.var(jax_arr, axis=axis) * n / (n - 1)
10791

108-
assert jnp.allclose(mean, mean_expected)
109-
assert jnp.allclose(var, var_expected)
92+
with subtests.test(name):
93+
assert type(result) is type(expected)
94+
assert jnp.allclose(result, expected)
11095

11196

112-
def test_to_dense(jax_arr: jax.Array) -> None:
97+
@pytest.mark.parametrize("to_cpu_memory", [True, False], ids=["to_cpu_memory", "not_to_cpu_memory"])
98+
def test_to_dense(*, jax_arr: jax.Array, to_cpu_memory: bool) -> None:
11399
import jax.numpy as jnp
114100

115-
result = to_dense(jax_arr)
116-
assert jnp.array_equal(result, jax_arr)
117-
101+
result = to_dense(jax_arr, to_cpu_memory=to_cpu_memory)
118102

119-
def test_to_dense_to_cpu(jax_arr: jax.Array) -> None:
120-
result = to_dense(jax_arr, to_cpu_memory=True)
121-
assert isinstance(result, np.ndarray)
122-
np.testing.assert_array_equal(result, np.asarray(jax_arr))
103+
if to_cpu_memory:
104+
assert isinstance(result, np.ndarray)
105+
else:
106+
assert isinstance(result, jax.Array)
107+
assert jnp.array_equal(result, jax_arr)

0 commit comments

Comments
 (0)