Skip to content

Commit ac981fc

Browse files
committed
Improved types and docs
1 parent 6f14ae4 commit ac981fc

11 files changed

Lines changed: 538 additions & 195 deletions

File tree

src/tdamapper/_common.py

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,27 @@
88
import io
99
import pstats
1010
import warnings
11-
from typing import Any, Callable, Dict
11+
from typing import Any, Callable, Dict, List, Union
1212

1313
import numpy as np
1414
from numpy.typing import NDArray
1515

1616
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
1717

1818

19+
PointLike = Union[Any, NDArray[np.float64]]
20+
21+
ArrayLike = Union[List[Any], NDArray[np.float64]]
22+
23+
1924
def deprecated(msg: str) -> Callable:
25+
"""
26+
Decorator to mark a function as deprecated.
27+
28+
:param msg: A message to be shown when the function is called.
29+
:return: A decorator that wraps the function and issues a warning when called.
30+
"""
31+
2032
def deprecated_func(func):
2133
def wrapper(*args, **kwargs):
2234
warnings.warn(msg, DeprecationWarning, stacklevel=2)
@@ -28,54 +40,91 @@ def wrapper(*args, **kwargs):
2840

2941

3042
def warn_user(msg: str) -> None:
    """
    Emit a ``UserWarning`` attributed to the caller's call site.

    :param msg: A message to be shown to the user.
    """
    # stacklevel=2 points the warning at the code that called this helper,
    # not at this helper itself.
    warnings.warn(message=msg, category=UserWarning, stacklevel=2)
3249

3350

3451
class EstimatorMixin:
    """
    Mixin providing shared estimator plumbing: sparse-input detection,
    validation of feature/target arrays, and recording of the input
    feature count.

    Intended for use with scikit-learn compatible estimators.
    """

    def _is_sparse(self, x_arr: ArrayLike) -> bool:
        """
        Return True when ``x_arr`` looks like a sparse matrix.

        :param x_arr: The array-like object to inspect.
        :return: True if ``x_arr`` is sparse, False otherwise.
        """
        # Duck-typed check; scipy.sparse.issparse is the heavier
        # alternative, and this avoids importing scipy here.
        return hasattr(x_arr, "toarray")

    def _validate_x_y(
        self,
        x_arr: ArrayLike,
        y_arr: ArrayLike,
    ) -> tuple[NDArray, NDArray]:
        """
        Validate and normalize a feature/target pair into numpy arrays.

        :param x_arr: The input features array.
        :param y_arr: The target values array.
        :return: A tuple of validated numpy arrays.
        :raises ValueError: If the inputs are sparse, empty, complex,
            contain NaNs or infinite values, or if the features are
            1-dimensional.
        """
        if self._is_sparse(x_arr):
            raise ValueError("Sparse data not supported.")

        x_checked = np.asarray(x_arr)
        y_checked = np.asarray(y_arr)

        # Both arrays must carry at least one element; check x first so
        # the error reports the features' shape when both are empty.
        for checked in (x_checked, y_checked):
            if checked.size == 0:
                raise ValueError(
                    f"0 feature(s) (shape={checked.shape}) while a minimum of 1 is "
                    "required."
                )

        if x_checked.ndim == 1:
            raise ValueError("1d-arrays not supported.")

        if np.iscomplexobj(x_checked) or np.iscomplexobj(y_checked):
            raise ValueError("Complex data not supported.")

        # Object-dtype inputs (e.g. ragged or mixed Python scalars) are
        # coerced to float; a failed coercion surfaces as numpy's own error.
        if x_checked.dtype == np.object_:
            x_checked = np.array(x_checked, dtype=float)
        if y_checked.dtype == np.object_:
            y_checked = np.array(y_checked, dtype=float)

        all_finite = not (
            np.isnan(x_checked).any()
            or np.isinf(x_checked).any()
            or np.isnan(y_checked).any()
            or np.isinf(y_checked).any()
        )
        if not all_finite:
            raise ValueError("NaNs or infinite values not supported.")

        return x_checked, y_checked

    def _set_n_features_in(self, arr: ArrayLike) -> None:
        """
        Record ``n_features_in_`` from the second axis of ``arr``.

        Silently does nothing when ``arr`` has no ``shape`` attribute
        (e.g. a plain Python list).
        """
        if hasattr(arr, "shape"):
            self.n_features_in_ = arr.shape[1]
79128

80129

81130
class ParamsMixin:
@@ -156,6 +205,14 @@ def clone(obj: Any) -> Any:
156205

157206

158207
def profile(n_lines: int = 10) -> Callable:
208+
"""
209+
Decorator to profile a function using cProfile and print the top `n_lines`
210+
lines of the profiling report.
211+
212+
:param n_lines: The number of lines to print from the profiling report.
213+
:return: A decorator that wraps the function and profiles its execution.
214+
"""
215+
159216
def decorator(func):
160217
def wrapper(*args, **kwargs):
161218
profiler = cProfile.Profile()

src/tdamapper/clustering.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ def __init__(self, *args, **kwargs):
4646

4747

4848
class _MapperClustering(EstimatorMixin, ParamsMixin):
49+
"""
50+
Mapper clustering model that fits the Mapper algorithm to the data.
51+
52+
This class is designed to be used with the Mapper algorithm for clustering
53+
data points based on their features. It allows for customization of the
54+
cover and clustering methods used in the Mapper algorithm.
55+
56+
:param cover: The cover method to use for the Mapper algorithm. If None,
57+
a trivial cover will be used.
58+
:param clustering: The clustering method to use for the Mapper algorithm.
59+
If None, a trivial clustering will be used.
60+
:param n_jobs: The number of jobs to run in parallel. Default is 1.
61+
"""
4962

5063
labels_: List[int]
5164

@@ -59,9 +72,18 @@ def __init__(
5972
self.clustering = clustering
6073
self.n_jobs = n_jobs
6174

62-
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> _MapperClustering:
63-
y = X if y is None else y
64-
X, y = self._validate_X_y(X, y)
75+
def fit(
76+
self, x_arr: ArrayLike, y_arr: Optional[ArrayLike] = None
77+
) -> _MapperClustering:
78+
"""
79+
Fit the Mapper clustering model to the data.
80+
81+
:param x_arr: The input features array.
82+
:param y_arr: The target values array. If None, `x_arr` is used as `y_arr`.
83+
:return: The fitted Mapper clustering model.
84+
"""
85+
y_arr = x_arr if y_arr is None else y_arr
86+
x_arr, y_arr = self._validate_x_y(x_arr, y_arr)
6587
cover = TrivialCover() if self.cover is None else self.cover
6688
cover = clone(cover)
6789
clustering = (
@@ -72,14 +94,14 @@ def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> _MapperClustering:
7294
clustering = clone(clustering)
7395
n_jobs = self.n_jobs
7496
itm_lbls = mapper_connected_components(
75-
X,
76-
y,
97+
x_arr,
98+
y_arr,
7799
cover,
78100
clustering,
79101
n_jobs=n_jobs,
80102
)
81-
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
82-
self._set_n_features_in(X)
103+
self.labels_ = [itm_lbls[i] for i, _ in enumerate(x_arr)]
104+
self._set_n_features_in(x_arr)
83105
return self
84106

85107

0 commit comments

Comments (0)