1616
1717from tdamapper ._common import ParamsMixin , warn_user
1818from tdamapper .core import proximity_net
19- from tdamapper .protocols import Array , Metric
19+ from tdamapper .protocols import ArrayRead , Metric
2020from tdamapper .utils .metrics import MetricLiteral , chebyshev , get_metric
2121from tdamapper .utils .vptree import PivotingStrategy , VPTree , VPTreeKind
2222
2323T = TypeVar ("T" )
24+ T_contra = TypeVar ("T_contra" , contravariant = True )
2425S = TypeVar ("S" )
26+ S_contra = TypeVar ("S_contra" , contravariant = True )
2527
2628
27- class _Pullback (Generic [S , T ]):
29+ class _Pullback (Generic [S_contra , T_contra ]):
2830
29- def __init__ (self , fun : Callable [[S ], T ], dist : Metric [T ]):
31+ def __init__ (
32+ self , fun : Callable [[S_contra ], T_contra ], dist : Metric [T_contra ]
33+ ) -> None :
3034 self .fun = fun
3135 self .dist = dist
3236
33- def __call__ (self , x : S , y : S ) -> float :
37+ def __call__ (self , x : S_contra , y : S_contra ) -> float :
3438 return self .dist (self .fun (x ), self .fun (y ))
3539
3640
37- def _snd (x : tuple [Any , ...]) -> Any :
41+ def _snd (x : tuple [T , ...]) -> T :
3842 return x [1 ]
3943
4044
41- class BallCover (ParamsMixin , Generic [T ]):
45+ class BallCover (ParamsMixin , Generic [T_contra ]):
4246 """
4347 Cover algorithm based on `ball proximity function`, which covers data with
4448 open balls of fixed radius.
@@ -67,8 +71,8 @@ class BallCover(ParamsMixin, Generic[T]):
6771 """
6872
6973 _radius : float
70- _data : list [tuple [int , T ]]
71- _vptree : VPTree [tuple [int , T ]]
74+ _data : list [tuple [int , T_contra ]]
75+ _vptree : VPTree [tuple [int , T_contra ]]
7276
7377 def __init__ (
7478 self ,
@@ -88,7 +92,7 @@ def __init__(
8892 self .leaf_radius = leaf_radius
8993 self .pivoting = pivoting
9094
91- def fit (self , X : Array [ T ]) -> BallCover [T ]:
95+ def fit (self , X : ArrayRead [ T_contra ]) -> BallCover [T_contra ]:
9296 """
9397 Train internal parameters.
9498
@@ -112,7 +116,7 @@ def fit(self, X: Array[T]) -> BallCover[T]:
112116 )
113117 return self
114118
115- def search (self , x : T ) -> list [int ]:
119+ def search (self , x : T_contra ) -> list [int ]:
116120 """
117121 Return a list of neighbors for the query point.
118122
@@ -130,7 +134,7 @@ def search(self, x: T) -> list[int]:
130134 )
131135 return [x for (x , _ ) in neighs ]
132136
133- def apply (self , X : Array [ T ]) -> Iterator [list [int ]]:
137+ def apply (self , X : ArrayRead [ T_contra ]) -> Iterator [list [int ]]:
134138 """
135139 Covers the dataset using proximity-net.
136140
@@ -144,7 +148,7 @@ def apply(self, X: Array[T]) -> Iterator[list[int]]:
144148 return proximity_net (self , X )
145149
146150
147- class KNNCover (ParamsMixin , Generic [T ]):
151+ class KNNCover (ParamsMixin , Generic [T_contra ]):
148152 """
149153 Cover algorithm based on `KNN proximity function`, which covers data using
150154 k-nearest neighbors (KNN).
@@ -173,8 +177,8 @@ class KNNCover(ParamsMixin, Generic[T]):
173177 """
174178
175179 _neighbors : int
176- _data : list [tuple [int , T ]]
177- _vptree : VPTree [tuple [int , T ]]
180+ _data : list [tuple [int , T_contra ]]
181+ _vptree : VPTree [tuple [int , T_contra ]]
178182
179183 def __init__ (
180184 self ,
@@ -194,7 +198,7 @@ def __init__(
194198 self .leaf_radius = leaf_radius
195199 self .pivoting = pivoting
196200
197- def fit (self , X : Array [ T ]) -> KNNCover [T ]:
201+ def fit (self , X : ArrayRead [ T_contra ]) -> KNNCover [T_contra ]:
198202 """
199203 Train internal parameters.
200204
@@ -218,7 +222,7 @@ def fit(self, X: Array[T]) -> KNNCover[T]:
218222 )
219223 return self
220224
221- def search (self , x : T ) -> list [int ]:
225+ def search (self , x : T_contra ) -> list [int ]:
222226 """
223227 Return a list of neighbors for the query point.
224228
@@ -233,7 +237,7 @@ def search(self, x: T) -> list[int]:
233237 neighs = self ._vptree .knn_search ((- 1 , x ), self ._neighbors )
234238 return [x for (x , _ ) in neighs ]
235239
236- def apply (self , X : Array [ T ]) -> Iterator [list [int ]]:
240+ def apply (self , X : ArrayRead [ T_contra ]) -> Iterator [list [int ]]:
237241 """
238242 Covers the dataset using proximity-net.
239243
@@ -309,7 +313,7 @@ def _get_bounds(
309313 _delta [(_delta >= - eps ) & (_delta <= eps )] = self ._n_intervals
310314 return _min , _max , _delta
311315
312- def fit (self , X : Array [NDArray [np .float_ ]]) -> BaseCubicalCover :
316+ def fit (self , X : ArrayRead [NDArray [np .float_ ]]) -> BaseCubicalCover :
313317 """
314318 Train internal parameters.
315319
@@ -408,7 +412,7 @@ def __init__(
408412 pivoting = pivoting ,
409413 )
410414
411- def apply (self , X : Array [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
415+ def apply (self , X : ArrayRead [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
412416 """
413417 Covers the dataset using proximity-net.
414418
@@ -471,7 +475,7 @@ def __init__(
471475 )
472476
473477 def _landmarks (
474- self , X : Array [NDArray [np .float_ ]]
478+ self , X : ArrayRead [NDArray [np .float_ ]]
475479 ) -> dict [tuple [float ], NDArray [np .float_ ]]:
476480 lmrks = {}
477481 for x in X :
@@ -480,7 +484,7 @@ def _landmarks(
480484 lmrks [lmrk ] = x
481485 return lmrks
482486
483- def apply (self , X : Array [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
487+ def apply (self , X : ArrayRead [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
484488 """
485489 Covers the dataset using landmarks.
486490
@@ -578,7 +582,7 @@ def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCove
578582 "The only possible values for algorithm are 'standard' and 'proximity'."
579583 )
580584
581- def fit (self , X : Array [NDArray [np .float_ ]]) -> CubicalCover :
585+ def fit (self , X : ArrayRead [NDArray [np .float_ ]]) -> CubicalCover :
582586 """
583587 Train internal parameters.
584588
@@ -604,7 +608,7 @@ def search(self, x: NDArray[np.float_]) -> list[int]:
604608 """
605609 return self ._cubical_cover .search (x )
606610
607- def apply (self , X : Array [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
611+ def apply (self , X : ArrayRead [NDArray [np .float_ ]]) -> Iterator [list [int ]]:
608612 """
609613 Covers the dataset using hypercubes.
610614
0 commit comments