Skip to content
16 changes: 11 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,17 @@
"ArrayLike": "numpy.typing.ArrayLike",
"DTypeLike": "numpy.typing.DTypeLike",
"NDArray": "numpy.typing.NDArray",
"CSBase": "scipy.sparse.spmatrix",
"CupyArray": "cupy.ndarray",
"CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
"DaskArray": "dask.array.Array",
"H5Dataset": "h5py.Dataset",
**{
k: v
for k_plain, v in {
"CSBase": "scipy.sparse.spmatrix",
"CupyArray": "cupy.ndarray",
"CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
"DaskArray": "dask.array.Array",
"H5Dataset": "h5py.Dataset",
}.items()
for k in (k_plain, f"types.{k_plain}")
},
}
# If that doesn’t work, ignore them
nitpick_ignore = {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ lint.ignore = [
]
lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
lint.per-file-ignores."src/fast_array_utils/types.py" = [ "N806" ] # We have variables that are classes here
lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
lint.per-file-ignores."tests/**/test_*.py" = [
"D100", # tests need no module docstrings
Expand Down
4 changes: 1 addition & 3 deletions src/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from __future__ import annotations

from . import _patches, conv, stats, types
from . import conv, stats, types


__all__ = ["conv", "stats", "types"]

_patches.patch_dask()
94 changes: 94 additions & 0 deletions src/fast_array_utils/_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cache
from types import UnionType
from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload


if TYPE_CHECKING:
from collections.abc import Callable

P = ParamSpec("P")
R = TypeVar("R")


__all__ = ["import_by_qualname", "lazy_singledispatch"]


def import_by_qualname(qualname: str) -> object:
    """Import a module and return an object found by attribute lookup inside it.

    Parameters
    ----------
    qualname
        Reference of the form ``"module.path:object.path"``.
        The part before the colon is imported as a module,
        the part after it is resolved attribute by attribute (split on ``.``).

    Returns
    -------
    The object the reference points to.

    Raises
    ------
    ValueError
        If ``qualname`` does not contain exactly one ``:``.
    ImportError
        If the module cannot be imported,
        or one of the attributes in the object path does not exist.
    """
    from importlib import import_module

    try:
        mod_path, obj_path = qualname.split(":")
    except ValueError:
        msg = f"qualname must have the form 'module.path:object.path', got {qualname!r}"
        raise ValueError(msg) from None

    mod = import_module(mod_path)

    # Apply the dask patches whenever something is imported from the dask package.
    if mod_path == "dask" or mod_path.startswith("dask."):
        from ._patches import patch_dask

        patch_dask()

    # Walk the attribute path starting from the module.
    obj: object = mod
    for name in obj_path.split("."):
        try:
            obj = getattr(obj, name)
        except AttributeError as e:
            # Both paths are plain strings: interpolate them as-is
            # (joining a string with "." would put a dot between every character).
            msg = f"Could not import {obj_path} from {mod_path}"
            raise ImportError(msg) from e
    return obj


@dataclass
class lazy_singledispatch(Generic[P, R]): # noqa: N801
fallback: Callable[P, R]

_lazy: dict[tuple[str, str], Callable[..., R]] = field(init=False, default_factory=dict)
_eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
fn = self.dispatch(type(args[0])) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11470
return fn(*args, **kwargs)

def __hash__(self) -> int:
return hash(self.fallback)

@cache # noqa: B019
def dispatch(self, typ: type) -> Callable[P, R]:
for cls_reg, fn in self._eager.items():
if issubclass(typ, cls_reg):
return fn
for (import_qualname, host_mod_name), fn in self._lazy.items():
for cls in typ.mro():
if cls.__module__.startswith(host_mod_name): # can be deeper
cls_reg = cast(type, import_by_qualname(import_qualname))
if issubclass(typ, cls_reg):
return fn
return self.fallback

@overload
def register(
self, qualname_or_type: str, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...
@overload
def register(
self, qualname_or_type: type | UnionType, /, host_mod_name: None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...

def register(
self, qualname_or_type: str | type | UnionType, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]:
def decorator(fn: Callable[..., R]) -> lazy_singledispatch[P, R]:
match qualname_or_type, host_mod_name:
case str(), _:
hmn = qualname_or_type.split(":")[0] if host_mod_name is None else host_mod_name
self._lazy[(qualname_or_type, hmn)] = fn
case type() | UnionType(), None:
self._eager[qualname_or_type] = fn
case _:
msg = f"name_or_type {qualname_or_type!r} must be a str, type, or UnionType"
raise TypeError(msg)
return self

return decorator
44 changes: 25 additions & 19 deletions src/fast_array_utils/conv/_asarray.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING

import numpy as np

from ..types import CSBase, CupyArray, CupySparseMatrix, DaskArray, H5Dataset, OutOfCoreDataset
from .._import import lazy_singledispatch
from ..types import OutOfCoreDataset


if TYPE_CHECKING:
from typing import Any

from numpy.typing import ArrayLike, NDArray

from .. import types

__all__ = ["OutOfCoreDataset", "asarray"]

__all__ = ["asarray"]


# fallback’s arg0 type has to include types of registered functions
@singledispatch
def asarray(x: ArrayLike | CSBase | OutOfCoreDataset[Any]) -> NDArray[Any]:
@lazy_singledispatch
def asarray(
x: ArrayLike
| types.CSBase
| types.DaskArray
| types.OutOfCoreDataset[Any]
| types.H5Dataset
| types.ZarrArray
| types.CupyArray
| types.CupySparseMatrix,
) -> NDArray[Any]:
"""Convert x to a numpy array.

Parameters
Expand All @@ -36,33 +47,28 @@ def asarray(x: ArrayLike | CSBase | OutOfCoreDataset[Any]) -> NDArray[Any]:
return np.asarray(x)


@asarray.register(CSBase)
def _(x: CSBase) -> NDArray[Any]:
@asarray.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(x: types.CSBase) -> NDArray[Any]:
from .scipy import to_dense

return to_dense(x)


@asarray.register(DaskArray)
def _(x: DaskArray) -> NDArray[Any]:
@asarray.register("dask.array:Array")
def _(x: types.DaskArray) -> NDArray[Any]:
return asarray(x.compute()) # type: ignore[no-untyped-call]


@asarray.register(OutOfCoreDataset)
def _(x: OutOfCoreDataset[CSBase | NDArray[Any]]) -> NDArray[Any]:
def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
return asarray(x.to_memory())


@asarray.register(H5Dataset)
def _(x: H5Dataset) -> NDArray[Any]:
return x[...] # type: ignore[no-any-return]


@asarray.register(CupyArray)
def _(x: CupyArray) -> NDArray[Any]:
@asarray.register("cupy:ndarray")
def _(x: types.CupyArray) -> NDArray[Any]:
return x.get() # type: ignore[no-any-return]


@asarray.register(CupySparseMatrix)
def _(x: CupySparseMatrix) -> NDArray[Any]:
@asarray.register("cupyx.scipy.sparse:spmatrix")
def _(x: types.CupySparseMatrix) -> NDArray[Any]:
return x.toarray().get() # type: ignore[no-any-return]
51 changes: 30 additions & 21 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import partial, singledispatch
from functools import partial
from typing import TYPE_CHECKING, overload

import numpy as np

from ..types import CSBase, CSMatrix, DaskArray
from .._import import lazy_singledispatch


if TYPE_CHECKING:
from typing import Any, Literal

from numpy.typing import ArrayLike, DTypeLike, NDArray

from .. import types


@overload
def sum(x: ArrayLike, *, axis: None = None, dtype: DTypeLike | None = None) -> np.number[Any]: ...
def sum(
x: ArrayLike, /, *, axis: None = None, dtype: DTypeLike | None = None
) -> np.number[Any]: ...
@overload
def sum(x: ArrayLike, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[Any]: ...
def sum(
x: ArrayLike, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
) -> NDArray[Any]: ...
@overload
def sum(
x: DaskArray, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None
) -> DaskArray: ...
x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None
) -> types.DaskArray: ...


def sum(
x: ArrayLike, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any] | DaskArray:
x: ArrayLike | types.DaskArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
"""Sum over both or one axis.

Returns
Expand All @@ -43,32 +53,30 @@ def sum(
return _sum(x, axis=axis, dtype=dtype)


@singledispatch
@lazy_singledispatch
def _sum(
x: ArrayLike | CSBase | DaskArray,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | DaskArray:
assert not isinstance(x, CSBase | DaskArray)
x: ArrayLike, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]


@_sum.register(CSBase)
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
x: CSBase, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any]:
import scipy.sparse as sp

from ..types import CSMatrix

if isinstance(x, CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]


@_sum.register(DaskArray)
@_sum.register("dask.array:Array")
def _(
x: DaskArray, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> DaskArray:
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray:
if TYPE_CHECKING:
from dask.array.reductions import reduction
else:
Expand All @@ -79,7 +87,8 @@ def _(
raise TypeError(msg)

def sum_drop_keepdims(
a: NDArray[Any] | CSBase,
a: NDArray[Any] | types.CSBase,
/,
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
Expand Down
Loading