|
2 | 2 | Clustering tools based on the Mapper algorithm. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +from typing import List, Optional |
| 8 | + |
5 | 9 | import tdamapper.core |
6 | 10 | from tdamapper._common import EstimatorMixin, ParamsMixin, clone, deprecated |
7 | | -from tdamapper.core import TrivialCover, mapper_connected_components |
| 11 | +from tdamapper.core import ( |
| 12 | + ArrayLike, |
| 13 | + Clustering, |
| 14 | + Cover, |
| 15 | + TrivialCover, |
| 16 | + mapper_connected_components, |
| 17 | +) |
8 | 18 |
|
9 | 19 |
|
10 | 20 | class TrivialClustering(tdamapper.core.TrivialClustering): |
@@ -37,12 +47,19 @@ def __init__(self, *args, **kwargs): |
37 | 47 |
|
38 | 48 | class _MapperClustering(EstimatorMixin, ParamsMixin): |
39 | 49 |
|
40 | | - def __init__(self, cover=None, clustering=None, n_jobs=1): |
| 50 | + labels_: List[int] |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + cover: Optional[Cover] = None, |
| 55 | + clustering: Optional[Clustering] = None, |
| 56 | + n_jobs: int = 1, |
| 57 | + ): |
41 | 58 | self.cover = cover |
42 | 59 | self.clustering = clustering |
43 | 60 | self.n_jobs = n_jobs |
44 | 61 |
|
45 | | - def fit(self, X, y=None): |
| 62 | + def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> _MapperClustering: |
46 | 63 | y = X if y is None else y |
47 | 64 | X, y = self._validate_X_y(X, y) |
48 | 65 | cover = TrivialCover() if self.cover is None else self.cover |
|
0 commit comments