Skip to content

Commit 9684213

Browse files
committed
Addressed the comments
1 parent 2a1924d commit 9684213

3 files changed

Lines changed: 13 additions & 30 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,6 @@ doc = [
5050
"sphinx>=9.0.1",
5151
"sphinx-autofixture>=0.4.1",
5252
]
53-
# for update-mypy-hook
54-
mypy = [
55-
"fast-array-utils[full]",
56-
"scipy-stubs",
57-
# TODO: replace sphinx with this: { include-group = "doc" },
58-
"sphinx",
59-
"types-docutils",
60-
{ include-group = "test" },
61-
]
6253
test-min = [
6354
"coverage[toml]",
6455
"fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense
@@ -75,7 +66,7 @@ envs.docs.dependency-groups = [ "doc" ]
7566
envs.docs.scripts.build = "sphinx-build -M html docs docs/_build"
7667
envs.docs.scripts.clean = "git clean -fdX docs"
7768
envs.docs.scripts.open = "python -m webbrowser -t docs/_build/html/index.html"
78-
envs.hatch-test.default-args = []
69+
envs.hatch-test.default-args = [ ]
7970
envs.hatch-test.dependency-groups = [ "test-min" ]
8071
# TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released
8172
envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
@@ -100,7 +91,7 @@ metadata.hooks.docstring-description = {}
10091
metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst"
10192
metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ]
10293
version.source = "vcs"
103-
version.raw-options = { local_scheme = "no-local-version" } # be able to publish dev version
94+
version.raw-options = { local_scheme = "no-local-version" } # be able to publish dev version
10495

10596
[tool.uv]
10697
override-dependencies = [ "sphinx>=9.0.1" ]
@@ -141,7 +132,7 @@ lint.per-file-ignores."typings/**/*.pyi" = [ "A002", "F403", "F405", "N801" ]
141132
lint.allowed-confusables = [ "×", "" ]
142133
lint.flake8-bugbear.extend-immutable-calls = [ "testing.fast_array_utils.Flags" ]
143134
lint.flake8-copyright.notice-rgx = "SPDX-License-Identifier: MPL-2\\.0"
144-
lint.flake8-type-checking.exempt-modules = []
135+
lint.flake8-type-checking.exempt-modules = [ ]
145136
lint.flake8-type-checking.strict = true
146137
lint.isort.known-first-party = [ "fast_array_utils" ]
147138
lint.isort.lines-after-imports = 2

src/fast_array_utils/stats/_is_constant.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from functools import partial, singledispatch
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING
66

77
import numba
88
import numpy as np
@@ -15,8 +15,6 @@
1515

1616
from numpy.typing import NDArray
1717

18-
# checking if all values in an array are the same
19-
2018

2119
@singledispatch
2220
def is_constant_(
@@ -26,36 +24,30 @@ def is_constant_(
2624
axis: Literal[0, 1] | None = None,
2725
) -> Any: # noqa: ANN401
2826

27+
# Catch types that lack __array_namespace like PyTorch
2928
import array_api_compat
3029

3130
if array_api_compat.is_array_api_obj(a):
32-
xp = array_api_compat.array_namespace(a)
33-
match axis:
34-
case None:
35-
return bool((a == xp.reshape(a, (-1,))[0]).all())
36-
case 0:
37-
return is_constant_(a.T, axis=1) # reusing axis = 1
38-
case 1:
39-
b = xp.broadcast_to(a[:, 0:1], a.shape)
40-
return (a == b).all(axis=1)
31+
return _is_constant_ndarray(a, axis=axis)
4132
raise NotImplementedError
4233

4334

44-
@is_constant_.register(np.ndarray | types.CupyArray)
35+
@is_constant_.register(np.ndarray | types.CupyArray | types.HasArrayNamespace)
4536
def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool] | types.CupyArray:
4637
# Should eventually support nd, not now.
38+
4739
match axis:
4840
case None:
49-
return bool((a == a.flat[0]).all())
41+
return bool((a == a.reshape(-1)[0]).all())
5042
case 0:
5143
return _is_constant_rows(a.T)
5244
case 1:
5345
return _is_constant_rows(a)
5446

5547

5648
def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray:
57-
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
58-
return cast("NDArray[np.bool]", (a == b).all(axis=1))
49+
# broadcasts without needing np.broadcast_to
50+
return (a == a[:, 0:1]).all(axis=1)
5951

6052

6153
@is_constant_.register(types.CSBase)

src/fast_array_utils/stats/_mean_var.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def mean_var_(
4545
if axis is not None and isinstance(x, types.CSBase):
4646
mean_, var = _sparse_mean_var(x, axis=axis)
4747
else:
48-
mean_ = mean(x, axis=axis, dtype=float64)
49-
mean_sq = mean(power(x, 2, dtype=float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=float64)
48+
mean_ = mean(x, axis=axis, dtype=xp.float64)
49+
mean_sq = mean(power(x, 2, dtype=xp.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=xp.float64)
5050
var = mean_sq - mean_**2
5151
if correction: # R convention == 1 (unbiased estimator)
5252
n = np.prod(x.shape) if axis is None else x.shape[axis]

0 commit comments

Comments
 (0)