Skip to content

Commit 613b732

Browse files
committed
Added types
1 parent 0eb73fe commit 613b732

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

src/tdamapper/clustering.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,19 @@
22
Clustering tools based on the Mapper algorithm.
33
"""
44

5+
from __future__ import annotations
6+
7+
from typing import List, Optional
8+
59
import tdamapper.core
610
from tdamapper._common import EstimatorMixin, ParamsMixin, clone, deprecated
7-
from tdamapper.core import TrivialCover, mapper_connected_components
11+
from tdamapper.core import (
12+
ArrayLike,
13+
Clustering,
14+
Cover,
15+
TrivialCover,
16+
mapper_connected_components,
17+
)
818

919

1020
class TrivialClustering(tdamapper.core.TrivialClustering):
@@ -37,12 +47,19 @@ def __init__(self, *args, **kwargs):
3747

3848
class _MapperClustering(EstimatorMixin, ParamsMixin):
3949

40-
def __init__(self, cover=None, clustering=None, n_jobs=1):
50+
labels_: List[int]
51+
52+
def __init__(
53+
self,
54+
cover: Optional[Cover] = None,
55+
clustering: Optional[Clustering] = None,
56+
n_jobs: int = 1,
57+
):
4158
self.cover = cover
4259
self.clustering = clustering
4360
self.n_jobs = n_jobs
4461

45-
def fit(self, X, y=None):
62+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> _MapperClustering:
4663
y = X if y is None else y
4764
X, y = self._validate_X_y(X, y)
4865
cover = TrivialCover() if self.cover is None else self.cover

0 commit comments

Comments
 (0)