Skip to content

Commit be036b0

Browse files
committed
Improved types and docs
1 parent 582f55f commit be036b0

11 files changed

Lines changed: 397 additions & 133 deletions

File tree

src/tdamapper/_common.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,6 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
9898
Get all public parameters of the object as a dictionary.
9999
100100
:param deep: A flag for returning also nested parameters.
101-
:type deep: bool, optional.
102101
"""
103102
params = {}
104103
for k, v in self.__dict__.items():

src/tdamapper/core.py

Lines changed: 15 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -85,20 +85,13 @@ def mapper_labels(
8585
than those at position j.
8686
8787
:param X: A dataset of n points.
88-
:type X: array-like of shape (n, m) or list-like of length n
8988
:param y: Lens values for the n points of the dataset.
90-
:type y: array-like of shape (n, k) or list-like of length n
9189
:param cover: A cover algorithm.
92-
:type cover: A class compatible with :class:`tdamapper.core.Cover`
9390
:param clustering: A clustering algorithm.
94-
:type clustering: An estimator compatible with scikit-learn's clustering
95-
interface, typically from :mod:`sklearn.cluster`.
9691
:param n_jobs: The maximum number of parallel clustering jobs. This
9792
parameter is passed to the constructor of :class:`joblib.Parallel`.
9893
Defaults to 1.
99-
:type n_jobs: int
10094
:return: A list of node labels for each point in the dataset.
101-
:rtype: list[list[int]]
10295
"""
10396

10497
def _run_clustering(local_ids, X_local, clust):
@@ -146,22 +139,15 @@ def mapper_connected_components(
146139
:func:`networkx.connected_components` on it.
147140
148141
:param X: A dataset of n points.
149-
:type X: array-like of shape (n, m) or list-like of length n
150142
:param y: Lens values for the n points of the dataset.
151-
:type y: array-like of shape (n, k) or list-like of length n
152143
:param cover: A cover algorithm.
153-
:type cover: A class compatible with :class:`tdamapper.core.Cover`
154144
:param clustering: The clustering algorithm to apply to each subset of the
155145
dataset.
156-
:type clustering: An estimator compatible with scikit-learn's clustering
157-
interface, typically from :mod:`sklearn.cluster`.
158146
:param n_jobs: The maximum number of parallel clustering jobs. This
159147
parameter is passed to the constructor of :class:`joblib.Parallel`.
160148
Defaults to 1.
161-
:type n_jobs: int
162149
:return: A list of labels. The label at position i identifies the connected
163150
component of the point at position i in the dataset.
164-
:rtype: list[int]
165151
"""
166152
itm_lbls = mapper_labels(X, y, cover, clustering, n_jobs=n_jobs)
167153
label_values = set()
@@ -202,21 +188,14 @@ def mapper_graph(
202188
contained in the cluster.
203189
204190
:param X: A dataset of n points.
205-
:type X: array-like of shape (n, m) or list-like of length n
206191
:param y: Lens values for the n points of the dataset.
207-
:type y: array-like of shape (n, k) or list-like of length n
208192
:param cover: A cover algorithm.
209-
:type cover: A class compatible with :class:`tdamapper.core.Cover`
210193
:param clustering: The clustering algorithm to apply to each subset of the
211194
dataset.
212-
:type clustering: An estimator compatible with scikit-learn's clustering
213-
interface, typically from :mod:`sklearn.cluster`.
214195
:param n_jobs: The maximum number of parallel clustering jobs. This
215196
parameter is passed to the constructor of :class:`joblib.Parallel`.
216197
Defaults to 1.
217-
:type n_jobs: int
218198
:return: The Mapper graph.
219-
:rtype: :class:`networkx.Graph`
220199
"""
221200
itm_lbls = mapper_labels(X, y, cover, clustering, n_jobs=n_jobs)
222201
graph = nx.Graph()
@@ -253,13 +232,9 @@ def aggregate_graph(X: ArrayLike, graph: nx.Graph, agg: Callable) -> Dict:
253232
and the values are the aggregation values.
254233
255234
:param X: A dataset of n points.
256-
:type X: array-like of shape (n, m) or list-like of length n
257235
:param graph: The graph to apply the aggregation function to.
258-
:type graph: :class:`networkx.Graph`.
259236
:param agg: The aggregation function to use.
260-
:type agg: Callable.
261237
:return: A dictionary of node-aggregation pairs.
262-
:rtype: dict
263238
"""
264239
agg_values = {}
265240
nodes = graph.nodes()
@@ -286,7 +261,6 @@ def fit(self, X: ArrayLike) -> SpatialSearch:
286261
Fit the spatial search algorithm to the data.
287262
288263
:param X: A dataset of n points.
289-
:type X: array-like of shape (n, m) or list-like of length n
290264
:return: self
291265
"""
292266

@@ -295,9 +269,7 @@ def search(self, x: PointLike) -> List[int]:
295269
Search for the nearest neighbors of a point.
296270
297271
:param x: A point to search for.
298-
:type x: A point-like object, such as a list or a numpy array.
299272
:return: A list of indices of the nearest neighbors of the point.
300-
:rtype: list[int]
301273
"""
302274

303275

@@ -317,7 +289,6 @@ def fit(self, X: ArrayLike) -> Cover:
317289
Fit the cover algorithm to the data.
318290
319291
:param X: A dataset of n points.
320-
:type X: array-like of shape (n, m) or list-like of length n
321292
:return: self
322293
"""
323294

@@ -329,7 +300,6 @@ def fit_transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
329300
the indices of the points in the dataset that belong to the open set.
330301
331302
:param X: A dataset of n points.
332-
:type X: array-like of shape (n, m) or list-like of length n
333303
:yield: A generator of lists of indices.
334304
"""
335305

@@ -341,7 +311,6 @@ def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
341311
the indices of the points in the dataset that belong to the open set.
342312
343313
:param X: A dataset of n points.
344-
:type X: array-like of shape (n, m) or list-like of length n
345314
:yield: A generator of lists of indices.
346315
"""
347316

@@ -374,7 +343,6 @@ def fit(self, X: ArrayLike, y: Any = None) -> Clustering:
374343
Fit the clustering algorithm to the data.
375344
376345
:param X: A dataset of n points.
377-
:type X: array-like of shape (n, m) or list-like of length n
378346
:param y: Ignored.
379347
:return: self
380348
"""
@@ -402,6 +370,13 @@ def fit_transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
402370

403371
class _MapperAlgorithm(EstimatorMixin, ParamsMixin):
404372

373+
_cover: Cover
374+
_clustering: Clustering
375+
_verbose: bool
376+
_failsafe: bool
377+
_n_jobs: int
378+
graph_: nx.Graph
379+
405380
def __init__(
406381
self,
407382
cover: Optional[Cover] = None,
@@ -472,13 +447,14 @@ class FailSafeClustering(ParamsMixin):
472447
returned. This can be useful for robustness and debugging purposes.
473448
474449
:param clustering: A clustering algorithm to delegate to.
475-
:type clustering: An estimator compatible with scikit-learn's clustering
476-
interface, typically from :mod:`sklearn.cluster`.
477450
:param verbose: A flag to log clustering exceptions. Set to True to
478451
enable logging, or False to suppress it. Defaults to True.
479-
:type verbose: bool, optional.
480452
"""
481453

454+
_clustering: Clustering
455+
_verbose: bool
456+
labels_: List[int]
457+
482458
def __init__(self, clustering: Optional[Clustering] = None, verbose: bool = True):
483459
self.clustering = clustering
484460
self.verbose = verbose
@@ -488,7 +464,7 @@ def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None):
488464
TrivialClustering() if self.clustering is None else self.clustering
489465
)
490466
self._verbose = self.verbose
491-
self.labels_ = None
467+
self.labels_ = []
492468
try:
493469
self._clustering.fit(X, y)
494470
self.labels_ = self._clustering.labels_
@@ -509,15 +485,16 @@ class TrivialClustering(ParamsMixin):
509485
construction of the Mapper graph.
510486
"""
511487

488+
labels_: List[int]
489+
512490
def __init__(self):
513491
pass
514492

515-
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> TrivialClustering:
493+
def fit(self, X: ArrayLike, _: Optional[ArrayLike] = None) -> TrivialClustering:
516494
"""
517495
Fit the clustering algorithm to the data.
518496
519497
:param X: A dataset of n points.
520-
:type X: array-like of shape (n, m) or list-like of length n
521498
:param y: Ignored.
522499
:return: self
523500
"""

src/tdamapper/cover.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def _cover(self, X: ArrayLike) -> Generator[Tuple[int, List[int]], None, None]:
5858
landmark, and the list of ids contains the indices of its neighbors.
5959
6060
:param X: A dataset of n points.
61-
:type X: array-like of shape (n, m) or list-like of length n
6261
:return: A generator of pairs of indices and lists of ids.
63-
:rtype: generator of tuples (int, List[int])
6462
:yield: A tuple containing the index of the point and a list of ids of
6563
its covered neighbors. If the index is -1, it indicates a
6664
landmark, and the list contains the indices of its neighbors.

src/tdamapper/heap.py

Lines changed: 28 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,10 @@
44
and iterate over the elements in the heap.
55
"""
66

7+
from __future__ import annotations
8+
9+
from typing import Generic, Optional, Tuple, TypeVar
10+
711

812
def _left(i: int) -> int:
913
return 2 * i + 1
@@ -17,7 +21,11 @@ def _parent(i: int) -> int:
1721
return max(0, (i - 1) // 2)
1822

1923

20-
class _HeapNode:
24+
K = TypeVar("K")
25+
T = TypeVar("T")
26+
27+
28+
class _HeapNode(Generic[K, T]):
2129
"""
2230
A node in the max-heap, storing a key and a value.
2331
@@ -28,36 +36,34 @@ class _HeapNode:
2836
2937
:param key: The key used for ordering in the heap.
3038
:param value: The value associated with the key.
31-
:type key: Any
32-
:type value: Any
3339
"""
3440

35-
def __init__(self, key, value):
41+
def __init__(self, key: K, value: T):
3642
self._key = key
3743
self._value = value
3844

39-
def get(self):
45+
def get(self) -> Tuple[K, T]:
4046
"""
4147
Returns the key and value of the node.
4248
:return: A tuple containing the key and value.
4349
:rtype: tuple
4450
"""
4551
return self._key, self._value
4652

47-
def __lt__(self, other):
53+
def __lt__(self, other: _HeapNode[K, T]) -> bool:
4854
return self._key < other._key
4955

50-
def __le__(self, other):
56+
def __le__(self, other: _HeapNode[K, T]) -> bool:
5157
return self._key <= other._key
5258

53-
def __gt__(self, other):
59+
def __gt__(self, other: _HeapNode[K, T]) -> bool:
5460
return self._key > other._key
5561

56-
def __ge__(self, other):
62+
def __ge__(self, other: _HeapNode[K, T]) -> bool:
5763
return self._key >= other._key
5864

5965

60-
class MaxHeap:
66+
class MaxHeap(Generic[K, T]):
6167
"""
6268
A max-heap data structure that allows for efficient retrieval of the maximum element.
6369
@@ -69,58 +75,54 @@ def __init__(self):
6975
self._heap = []
7076
self._iter = None
7177

72-
def __iter__(self):
78+
def __iter__(self) -> MaxHeap[K, T]:
7379
self._iter = iter(self._heap)
7480
return self
7581

76-
def __next__(self):
82+
def __next__(self) -> Tuple[K, T]:
7783
node = next(self._iter)
7884
return node.get()
7985

80-
def __len__(self):
86+
def __len__(self) -> int:
8187
return len(self._heap)
8288

83-
def top(self):
89+
def top(self) -> Tuple[Optional[K], Optional[T]]:
8490
"""
8591
Returns the maximum element in the heap without removing it.
8692
8793
:return: A tuple containing the key and value of the maximum element, or (None, None) if the heap is empty.
88-
:rtype: tuple
8994
"""
9095
if not self._heap:
9196
return (None, None)
9297
return self._heap[0].get()
9398

94-
def pop(self):
99+
def pop(self) -> Optional[Tuple[K, T]]:
95100
"""
96101
Removes and returns the maximum element from the heap.
97102
98103
:return: A tuple containing the key and value of the maximum element, or None if the heap is empty.
99104
:rtype: tuple
100105
"""
101106
if not self._heap:
102-
return
107+
return None
103108
max_val = self._heap[0]
104109
self._heap[0] = self._heap[-1]
105110
self._heap.pop()
106111
self._bubble_down()
107112
return max_val.get()
108113

109-
def add(self, key, val):
114+
def add(self, key: K, val: T) -> None:
110115
"""
111116
Adds a new element to the heap with the specified key and value.
112117
113118
:param key: The key used for ordering in the heap.
114119
:param val: The value associated with the key.
115-
:type key: Any
116-
:type val: Any
117120
:return: None
118-
:rtype: None
119121
"""
120122
self._heap.append(_HeapNode(key, val))
121123
self._bubble_up()
122124

123-
def _get_local_max(self, i):
125+
def _get_local_max(self, i: int) -> int:
124126
heap_len = len(self._heap)
125127
left = _left(i)
126128
right = _right(i)
@@ -137,7 +139,7 @@ def _get_local_max(self, i):
137139
return max_child
138140
return i
139141

140-
def _fix_down(self, i):
142+
def _fix_down(self, i: int) -> int:
141143
local_max = self._get_local_max(i)
142144
if i < local_max:
143145
self._heap[i], self._heap[local_max] = (
@@ -147,22 +149,22 @@ def _fix_down(self, i):
147149
return local_max
148150
return i
149151

150-
def _fix_up(self, i):
152+
def _fix_up(self, i: int) -> int:
151153
parent = _parent(i)
152154
if self._heap[parent] < self._heap[i]:
153155
self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
154156
return parent
155157
return i
156158

157-
def _bubble_down(self):
159+
def _bubble_down(self) -> None:
158160
current = 0
159161
done = False
160162
while not done:
161163
local_max = self._fix_down(current)
162164
done = current == local_max
163165
current = local_max
164166

165-
def _bubble_up(self):
167+
def _bubble_up(self) -> None:
166168
current = len(self._heap) - 1
167169
done = False
168170
while not done:

0 commit comments

Comments
 (0)