Skip to content

Commit 31c01ef

Browse files
committed
Added type hints to core
1 parent 4877d91 commit 31c01ef

9 files changed

Lines changed: 482 additions & 301 deletions

File tree

src/tdamapper/core.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
handlers=[logging.StreamHandler()],
5656
)
5757

58+
PointLike = Union[Any, NDArray[np.float64]]
59+
5860
ArrayLike = Union[List[Any], NDArray[np.float64]]
5961

6062

@@ -103,6 +105,7 @@ def _run_clustering(local_ids, X_local, clust):
103105
local_lbls = clust.fit(X_local).labels_
104106
return local_ids, local_lbls
105107

108+
cover.fit(y)
106109
_lbls = Parallel(n_jobs, prefer="threads")(
107110
delayed(_run_clustering)(
108111
local_ids, [X[j] for j in local_ids], clone(clustering)
@@ -268,6 +271,19 @@ def aggregate_graph(X: ArrayLike, graph: nx.Graph, agg: Callable) -> Dict:
268271
return agg_values
269272

270273

274+
class SpatialSearch(Protocol):
275+
"""
276+
Abstract interface for spatial search algorithms.
277+
278+
Subclasses should override the methods of this class to implement more
279+
meaningful spatial search algorithms.
280+
"""
281+
282+
def fit(self, X: ArrayLike) -> SpatialSearch: ...
283+
284+
def search(self, x: PointLike) -> List[int]: ...
285+
286+
271287
class Cover(Protocol):
272288
"""
273289
Abstract interface for cover algorithms.
@@ -281,32 +297,14 @@ def fit(self, X: ArrayLike) -> Cover: ...
281297
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]: ...
282298

283299

284-
class ProximityNetCover:
285-
286-
def fit(self, X: ArrayLike) -> Cover:
287-
raise NotImplementedError()
288-
289-
def search(self, x: Any) -> List[int]:
290-
raise NotImplementedError()
291-
292-
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
293-
covered_ids = set()
294-
for i, xi in enumerate(X):
295-
if i not in covered_ids:
296-
neigh_ids = self.search(xi)
297-
covered_ids.update(neigh_ids)
298-
if neigh_ids:
299-
yield neigh_ids
300-
301-
302300
class Clustering(Protocol):
303301

304302
labels_: List[int]
305303

306304
def fit(self, X: ArrayLike, y: Any = None) -> Clustering: ...
307305

308306

309-
class TrivialCover:
307+
class TrivialCover(ParamsMixin):
310308
"""
311309
Cover algorithm that covers data with a single subset containing the whole
312310
dataset.

0 commit comments

Comments
 (0)