Skip to content

Commit d71ae86

Browse files
committed
Improved variance of types
1 parent 61794a4 commit d71ae86

10 files changed

Lines changed: 129 additions & 93 deletions

File tree

src/tdamapper/_common.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from numpy.typing import NDArray
1515

16-
from tdamapper.protocols import Array
16+
from tdamapper.protocols import Array, ArrayRead
1717

1818
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
1919

@@ -35,48 +35,52 @@ def warn_user(msg: str) -> None:
3535

3636
class EstimatorMixin:
3737

38-
def _is_sparse(self, X: Array[Any]) -> bool:
38+
def _is_sparse(self, X: ArrayRead[Any]) -> bool:
3939
# simple alternative use scipy.sparse.issparse
4040
return hasattr(X, "toarray")
4141

4242
def _validate_X_y(
43-
self, X: Array[Any], y: Array[Any]
43+
self, X: ArrayRead[Any], y: ArrayRead[Any]
4444
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
4545
if self._is_sparse(X):
4646
raise ValueError("Sparse data not supported.")
4747

48-
X = np.asarray(X)
49-
y = np.asarray(y)
48+
X_ = np.asarray(X)
49+
y_ = np.asarray(y)
5050

51-
if X.size == 0:
52-
msg = f"0 feature(s) (shape={X.shape}) while a minimum of 1 is " "required."
51+
if X_.size == 0:
52+
msg = (
53+
f"0 feature(s) (shape={X_.shape}) while a minimum of 1 is " "required."
54+
)
5355
raise ValueError(msg)
5456

55-
if y.size == 0:
56-
msg = f"0 feature(s) (shape={y.shape}) while a minimum of 1 is " "required."
57+
if y_.size == 0:
58+
msg = (
59+
f"0 feature(s) (shape={y_.shape}) while a minimum of 1 is " "required."
60+
)
5761
raise ValueError(msg)
5862

59-
if X.ndim == 1:
63+
if X_.ndim == 1:
6064
raise ValueError("1d-arrays not supported.")
6165

62-
if np.iscomplexobj(X) or np.iscomplexobj(y):
66+
if np.iscomplexobj(X_) or np.iscomplexobj(y_):
6367
raise ValueError("Complex data not supported.")
6468

65-
if X.dtype == np.object_:
66-
X = np.array(X, dtype=float)
69+
if X_.dtype == np.object_:
70+
X_ = np.array(X_, dtype=float)
6771

68-
if y.dtype == np.object_:
69-
y = np.array(y, dtype=float)
72+
if y_.dtype == np.object_:
73+
y_ = np.array(y_, dtype=float)
7074

7175
if (
72-
np.isnan(X).any()
73-
or np.isinf(X).any()
74-
or np.isnan(y).any()
75-
or np.isinf(y).any()
76+
np.isnan(X_).any()
77+
or np.isinf(X_).any()
78+
or np.isnan(y_).any()
79+
or np.isinf(y_).any()
7680
):
7781
raise ValueError("NaNs or infinite values not supported.")
7882

79-
return X, y
83+
return X_, y_
8084

8185
def _set_n_features_in(self, X: Array[Any]) -> None:
8286
if hasattr(X, "shape"):

src/tdamapper/core.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from joblib import Parallel, delayed
3838

3939
from tdamapper._common import ParamsMixin, clone
40-
from tdamapper.protocols import Array, Clustering, Cover, SpatialSearch
40+
from tdamapper.protocols import ArrayRead, Clustering, Cover, SpatialSearch
4141
from tdamapper.utils.unionfind import UnionFind
4242

4343
ATTR_IDS = "ids"
@@ -60,8 +60,8 @@
6060

6161

6262
def mapper_labels(
63-
X: Array[S],
64-
y: Array[T],
63+
X: ArrayRead[S],
64+
y: ArrayRead[T],
6565
cover: Cover[T],
6666
clustering: Clustering[S],
6767
n_jobs: int = 1,
@@ -94,7 +94,7 @@ def mapper_labels(
9494
"""
9595

9696
def _run_clustering(
97-
local_ids: list[int], X_local: Array[S], clust: Clustering[S]
97+
local_ids: list[int], X_local: ArrayRead[S], clust: Clustering[S]
9898
) -> tuple[list[int], list[int]]:
9999
local_lbls = clust.fit(X_local).labels_
100100
return local_ids, local_lbls
@@ -119,8 +119,8 @@ def _run_clustering(
119119

120120

121121
def mapper_connected_components(
122-
X: Array[S],
123-
y: Array[T],
122+
X: ArrayRead[S],
123+
y: ArrayRead[T],
124124
cover: Cover[T],
125125
clustering: Clustering[S],
126126
n_jobs: int = 1,
@@ -168,8 +168,8 @@ def mapper_connected_components(
168168

169169

170170
def mapper_graph(
171-
X: Array[S],
172-
y: Array[T],
171+
X: ArrayRead[S],
172+
y: ArrayRead[T],
173173
cover: Cover[T],
174174
clustering: Clustering[S],
175175
n_jobs: int = 1,
@@ -218,7 +218,7 @@ def mapper_graph(
218218

219219

220220
def aggregate_graph(
221-
X: Array[S], graph: nx.Graph, agg: Callable[..., Any]
221+
X: ArrayRead[S], graph: nx.Graph, agg: Callable[..., Any]
222222
) -> dict[int, Any]:
223223
"""
224224
Apply an aggregation function to the nodes of a graph.
@@ -246,7 +246,7 @@ def aggregate_graph(
246246
return agg_values
247247

248248

249-
def proximity_net(search: SpatialSearch[S], X: Array[S]) -> Iterator[list[int]]:
249+
def proximity_net(search: SpatialSearch[S], X: ArrayRead[S]) -> Iterator[list[int]]:
250250
"""
251251
Covers the dataset using proximity-net.
252252
@@ -283,7 +283,7 @@ class TrivialCover(ParamsMixin, Generic[T]):
283283
dataset.
284284
"""
285285

286-
def apply(self, X: Array[T]) -> Iterator[list[int]]:
286+
def apply(self, X: ArrayRead[T]) -> Iterator[list[int]]:
287287
"""
288288
Covers the dataset with a single open set.
289289
@@ -317,7 +317,9 @@ def __init__(
317317
self.clustering = clustering
318318
self.verbose = verbose
319319

320-
def fit(self, X: Array[T], y: Optional[Array[T]] = None) -> FailSafeClustering[T]:
320+
def fit(
321+
self, X: ArrayRead[T], y: Optional[ArrayRead[T]] = None
322+
) -> FailSafeClustering[T]:
321323
self._clustering = (
322324
TrivialClustering() if self.clustering is None else self.clustering
323325
)
@@ -347,7 +349,9 @@ class TrivialClustering(ParamsMixin, Generic[T]):
347349
def __init__(self) -> None:
348350
pass
349351

350-
def fit(self, X: Array[T], _y: Optional[Array[T]] = None) -> TrivialClustering[T]:
352+
def fit(
353+
self, X: ArrayRead[T], _y: Optional[ArrayRead[T]] = None
354+
) -> TrivialClustering[T]:
351355
"""
352356
Fit the clustering algorithm to the data.
353357

src/tdamapper/cover.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,33 @@
1616

1717
from tdamapper._common import ParamsMixin, warn_user
1818
from tdamapper.core import proximity_net
19-
from tdamapper.protocols import Array, Metric
19+
from tdamapper.protocols import ArrayRead, Metric
2020
from tdamapper.utils.metrics import MetricLiteral, chebyshev, get_metric
2121
from tdamapper.utils.vptree import PivotingStrategy, VPTree, VPTreeKind
2222

2323
T = TypeVar("T")
24+
T_contra = TypeVar("T_contra", contravariant=True)
2425
S = TypeVar("S")
26+
S_contra = TypeVar("S_contra", contravariant=True)
2527

2628

27-
class _Pullback(Generic[S, T]):
29+
class _Pullback(Generic[S_contra, T_contra]):
2830

29-
def __init__(self, fun: Callable[[S], T], dist: Metric[T]):
31+
def __init__(
32+
self, fun: Callable[[S_contra], T_contra], dist: Metric[T_contra]
33+
) -> None:
3034
self.fun = fun
3135
self.dist = dist
3236

33-
def __call__(self, x: S, y: S) -> float:
37+
def __call__(self, x: S_contra, y: S_contra) -> float:
3438
return self.dist(self.fun(x), self.fun(y))
3539

3640

37-
def _snd(x: tuple[Any, ...]) -> Any:
41+
def _snd(x: tuple[T, ...]) -> T:
3842
return x[1]
3943

4044

41-
class BallCover(ParamsMixin, Generic[T]):
45+
class BallCover(ParamsMixin, Generic[T_contra]):
4246
"""
4347
Cover algorithm based on `ball proximity function`, which covers data with
4448
open balls of fixed radius.
@@ -67,8 +71,8 @@ class BallCover(ParamsMixin, Generic[T]):
6771
"""
6872

6973
_radius: float
70-
_data: list[tuple[int, T]]
71-
_vptree: VPTree[tuple[int, T]]
74+
_data: list[tuple[int, T_contra]]
75+
_vptree: VPTree[tuple[int, T_contra]]
7276

7377
def __init__(
7478
self,
@@ -88,7 +92,7 @@ def __init__(
8892
self.leaf_radius = leaf_radius
8993
self.pivoting = pivoting
9094

91-
def fit(self, X: Array[T]) -> BallCover[T]:
95+
def fit(self, X: ArrayRead[T_contra]) -> BallCover[T_contra]:
9296
"""
9397
Train internal parameters.
9498
@@ -112,7 +116,7 @@ def fit(self, X: Array[T]) -> BallCover[T]:
112116
)
113117
return self
114118

115-
def search(self, x: T) -> list[int]:
119+
def search(self, x: T_contra) -> list[int]:
116120
"""
117121
Return a list of neighbors for the query point.
118122
@@ -130,7 +134,7 @@ def search(self, x: T) -> list[int]:
130134
)
131135
return [x for (x, _) in neighs]
132136

133-
def apply(self, X: Array[T]) -> Iterator[list[int]]:
137+
def apply(self, X: ArrayRead[T_contra]) -> Iterator[list[int]]:
134138
"""
135139
Covers the dataset using proximity-net.
136140
@@ -144,7 +148,7 @@ def apply(self, X: Array[T]) -> Iterator[list[int]]:
144148
return proximity_net(self, X)
145149

146150

147-
class KNNCover(ParamsMixin, Generic[T]):
151+
class KNNCover(ParamsMixin, Generic[T_contra]):
148152
"""
149153
Cover algorithm based on `KNN proximity function`, which covers data using
150154
k-nearest neighbors (KNN).
@@ -173,8 +177,8 @@ class KNNCover(ParamsMixin, Generic[T]):
173177
"""
174178

175179
_neighbors: int
176-
_data: list[tuple[int, T]]
177-
_vptree: VPTree[tuple[int, T]]
180+
_data: list[tuple[int, T_contra]]
181+
_vptree: VPTree[tuple[int, T_contra]]
178182

179183
def __init__(
180184
self,
@@ -194,7 +198,7 @@ def __init__(
194198
self.leaf_radius = leaf_radius
195199
self.pivoting = pivoting
196200

197-
def fit(self, X: Array[T]) -> KNNCover[T]:
201+
def fit(self, X: ArrayRead[T_contra]) -> KNNCover[T_contra]:
198202
"""
199203
Train internal parameters.
200204
@@ -218,7 +222,7 @@ def fit(self, X: Array[T]) -> KNNCover[T]:
218222
)
219223
return self
220224

221-
def search(self, x: T) -> list[int]:
225+
def search(self, x: T_contra) -> list[int]:
222226
"""
223227
Return a list of neighbors for the query point.
224228
@@ -233,7 +237,7 @@ def search(self, x: T) -> list[int]:
233237
neighs = self._vptree.knn_search((-1, x), self._neighbors)
234238
return [x for (x, _) in neighs]
235239

236-
def apply(self, X: Array[T]) -> Iterator[list[int]]:
240+
def apply(self, X: ArrayRead[T_contra]) -> Iterator[list[int]]:
237241
"""
238242
Covers the dataset using proximity-net.
239243
@@ -309,7 +313,7 @@ def _get_bounds(
309313
_delta[(_delta >= -eps) & (_delta <= eps)] = self._n_intervals
310314
return _min, _max, _delta
311315

312-
def fit(self, X: Array[NDArray[np.float_]]) -> BaseCubicalCover:
316+
def fit(self, X: ArrayRead[NDArray[np.float_]]) -> BaseCubicalCover:
313317
"""
314318
Train internal parameters.
315319
@@ -408,7 +412,7 @@ def __init__(
408412
pivoting=pivoting,
409413
)
410414

411-
def apply(self, X: Array[NDArray[np.float_]]) -> Iterator[list[int]]:
415+
def apply(self, X: ArrayRead[NDArray[np.float_]]) -> Iterator[list[int]]:
412416
"""
413417
Covers the dataset using proximity-net.
414418
@@ -471,7 +475,7 @@ def __init__(
471475
)
472476

473477
def _landmarks(
474-
self, X: Array[NDArray[np.float_]]
478+
self, X: ArrayRead[NDArray[np.float_]]
475479
) -> dict[tuple[float], NDArray[np.float_]]:
476480
lmrks = {}
477481
for x in X:
@@ -480,7 +484,7 @@ def _landmarks(
480484
lmrks[lmrk] = x
481485
return lmrks
482486

483-
def apply(self, X: Array[NDArray[np.float_]]) -> Iterator[list[int]]:
487+
def apply(self, X: ArrayRead[NDArray[np.float_]]) -> Iterator[list[int]]:
484488
"""
485489
Covers the dataset using landmarks.
486490
@@ -578,7 +582,7 @@ def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCove
578582
"The only possible values for algorithm are 'standard' and 'proximity'."
579583
)
580584

581-
def fit(self, X: Array[NDArray[np.float_]]) -> CubicalCover:
585+
def fit(self, X: ArrayRead[NDArray[np.float_]]) -> CubicalCover:
582586
"""
583587
Train internal parameters.
584588
@@ -604,7 +608,7 @@ def search(self, x: NDArray[np.float_]) -> list[int]:
604608
"""
605609
return self._cubical_cover.search(x)
606610

607-
def apply(self, X: Array[NDArray[np.float_]]) -> Iterator[list[int]]:
611+
def apply(self, X: ArrayRead[NDArray[np.float_]]) -> Iterator[list[int]]:
608612
"""
609613
Covers the dataset using hypercubes.
610614

0 commit comments

Comments
 (0)