Skip to content

Commit 582f55f

Browse files
committed
Improved types
1 parent 613b732 commit 582f55f

17 files changed

Lines changed: 508 additions & 349 deletions

src/tdamapper/core.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,11 @@ def _run_clustering(local_ids, X_local, clust):
105105
local_lbls = clust.fit(X_local).labels_
106106
return local_ids, local_lbls
107107

108-
cover.fit(y)
109108
_lbls = Parallel(n_jobs, prefer="threads")(
110109
delayed(_run_clustering)(
111110
local_ids, [X[j] for j in local_ids], clone(clustering)
112111
)
113-
for local_ids in cover.transform(y)
112+
for local_ids in cover.fit_transform(y)
114113
)
115114
itm_lbls: List[List[int]] = [[] for _ in X]
116115
max_lbl = 0
@@ -322,6 +321,18 @@ def fit(self, X: ArrayLike) -> Cover:
322321
:return: self
323322
"""
324323

324+
def fit_transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
325+
"""
326+
Fit the cover algorithm to the data and transform it.
327+
328+
This method should yield a generator of lists, where each list contains
329+
the indices of the points in the dataset that belong to the open set.
330+
331+
:param X: A dataset of n points.
332+
:type X: array-like of shape (n, m) or list-like of length n
333+
:yield: A generator of lists of indices.
334+
"""
335+
325336
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
326337
"""
327338
Transform the data into overlapping open sets.
@@ -384,6 +395,10 @@ def fit(self, X: ArrayLike) -> TrivialCover:
384395
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
385396
yield list(range(len(X)))
386397

398+
def fit_transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
399+
self.fit(X)
400+
return self.transform(X)
401+
387402

388403
class _MapperAlgorithm(EstimatorMixin, ParamsMixin):
389404

0 commit comments

Comments
 (0)