Skip to content

Commit a31cb31

Browse files
committed
Improved docstrings
1 parent be036b0 commit a31cb31

16 files changed

Lines changed: 497 additions & 82 deletions

File tree

src/tdamapper/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def _init_draw_settings(self):
533533
).classes("w-full")
534534

535535
def _init_footnotes(self):
536-
ui.label(text=("Made in Rome, with ❤️ and ☕️.")).classes(
536+
ui.label(text="Made in Rome, with ❤️ and ☕️.").classes(
537537
"text-caption text-gray-500"
538538
).classes("text-caption text-gray-500")
539539

@@ -543,7 +543,7 @@ def _init_draw_area(self):
543543

544544
def get_mapper_config(self):
545545
return MapperConfig(
546-
lens_type=str(self.lens_type.value) if self.lens_type.value else LENS_PCA,
546+
lens_type=(str(self.lens_type.value) if self.lens_type.value else LENS_PCA),
547547
cover_type=(
548548
str(self.cover_type.value) if self.cover_type.value else COVER_CUBICAL
549549
),

src/tdamapper/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,16 @@
3131
from __future__ import annotations
3232

3333
import logging
34-
from typing import Any, Callable, Dict, Generator, List, Optional, Protocol, Union
34+
from typing import (
35+
Any,
36+
Callable,
37+
Dict,
38+
Generator,
39+
List,
40+
Optional,
41+
Protocol,
42+
Union,
43+
)
3544

3645
import networkx as nx
3746
import numpy as np

src/tdamapper/cover.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919

2020
from tdamapper._common import ParamsMixin
2121
from tdamapper.core import ArrayLike, PointLike, SpatialSearch
22-
from tdamapper.search import BallSearch, CubicalLandmarks, CubicalSearch, KNNSearch
22+
from tdamapper.search import (
23+
BallSearch,
24+
CubicalLandmarks,
25+
CubicalSearch,
26+
KNNSearch,
27+
)
2328

2429

2530
class ProximityNet(ABC, ParamsMixin):
@@ -454,13 +459,11 @@ def _get_cubical_cover(self):
454459
)
455460
if self.algorithm == "proximity":
456461
return ProximityCubicalCover(**params)
457-
elif self.algorithm == "standard":
462+
if self.algorithm == "standard":
458463
return StandardCubicalCover(**params)
459-
else:
460-
raise ValueError(
461-
"The only possible values for algorithm are 'standard' and "
462-
"'proximity'."
463-
)
464+
raise ValueError(
465+
"The only possible values for algorithm are 'standard' and 'proximity'."
466+
)
464467

465468
def fit(self, X: ArrayLike) -> CubicalCover:
466469
"""

src/tdamapper/heap.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ def _fix_down(self, i: int) -> int:
152152
def _fix_up(self, i: int) -> int:
153153
parent = _parent(i)
154154
if self._heap[parent] < self._heap[i]:
155-
self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
155+
self._heap[i], self._heap[parent] = (
156+
self._heap[parent],
157+
self._heap[i],
158+
)
156159
return parent
157160
return i
158161

src/tdamapper/metrics.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,25 @@
3131

3232
import numpy as np
3333

34-
import tdamapper._metrics as _metrics
34+
from tdamapper._metrics import (
35+
chebyshev as _chebyshev,
36+
cosine as _cosine,
37+
euclidean as _euclidean,
38+
manhattan as _manhattan,
39+
minkowski as _minkowski,
40+
)
3541

3642
_MINKOWSKI_P = "p"
3743

3844

3945
class Metric(str, Enum):
46+
"""
47+
Enum representing supported distance metrics.
48+
49+
Each metric corresponds to a specific distance function that can be
50+
used to compute distances between vectors.
51+
"""
52+
4053
EUCLIDEAN = "euclidean"
4154
MANHATTAN = "manhattan"
4255
MINKOWSKI = "minkowski"
@@ -54,7 +67,7 @@ def euclidean(*args, **kwargs) -> Callable:
5467
:return: The Euclidean distance function.
5568
:rtype: callable
5669
"""
57-
return _metrics.euclidean
70+
return _euclidean
5871

5972

6073
def manhattan(*args, **kwargs) -> Callable:
@@ -67,7 +80,7 @@ def manhattan(*args, **kwargs) -> Callable:
6780
:return: The Manhattan distance function.
6881
:rtype: callable
6982
"""
70-
return _metrics.manhattan
83+
return _manhattan
7184

7285

7386
def chebyshev(*args, **kwargs) -> Callable:
@@ -80,7 +93,7 @@ def chebyshev(*args, **kwargs) -> Callable:
8093
:return: The Chebyshev distance function.
8194
:rtype: callable
8295
"""
83-
return _metrics.chebyshev
96+
return _chebyshev
8497

8598

8699
def minkowski(*args, **kwargs) -> Callable:
@@ -105,7 +118,7 @@ def minkowski(*args, **kwargs) -> Callable:
105118
return chebyshev()
106119

107120
def dist(x, y):
108-
return _metrics.minkowski(p, x, y)
121+
return _minkowski(p, x, y)
109122

110123
return dist
111124

@@ -129,7 +142,7 @@ def cosine(*args, **kwargs) -> Callable:
129142
:return: The cosine distance function.
130143
:rtype: callable
131144
"""
132-
return _metrics.cosine
145+
return _cosine
133146

134147

135148
def _get_supported_metrics() -> List[str]:
@@ -152,13 +165,13 @@ def get_metric_function(metric: Metric, *args, **kwargs) -> Callable:
152165
"""
153166
if metric == Metric.EUCLIDEAN:
154167
return euclidean(*args, **kwargs)
155-
elif metric == Metric.MANHATTAN:
168+
if metric == Metric.MANHATTAN:
156169
return manhattan(*args, **kwargs)
157-
elif metric == Metric.MINKOWSKI:
170+
if metric == Metric.MINKOWSKI:
158171
return minkowski(*args, **kwargs)
159-
elif metric == Metric.CHEBYSHEV:
172+
if metric == Metric.CHEBYSHEV:
160173
return chebyshev(*args, **kwargs)
161-
elif metric == Metric.COSINE:
174+
if metric == Metric.COSINE:
162175
return cosine(*args, **kwargs)
163176
raise ValueError(
164177
f"Unsupported metric: {metric}. "

src/tdamapper/plot_plotly.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,53 +86,46 @@ def _to_cmaps(cmap: Optional[Union[str, List[str]]]) -> List[str]:
8686
return [DEFAULT_CMAP]
8787
if isinstance(cmap, str):
8888
return [cmap]
89-
elif isinstance(cmap, list):
89+
if isinstance(cmap, list):
9090
return cmap
91-
else:
92-
raise ValueError(f"Invalid cmap type: {type(cmap)}. Expected str or list[str].")
91+
raise ValueError(f"Invalid cmap type: {type(cmap)}. Expected str or list[str].")
9392

9493

9594
def _to_colors(colors: Union[np.ndarray, List[float]]) -> np.ndarray:
9695
"""Convert colors to a numpy array."""
9796
colors_arr = np.array(colors)
9897
if colors_arr.ndim == 1:
9998
return colors_arr.reshape(-1, 1)
100-
elif colors_arr.ndim == 2:
99+
if colors_arr.ndim == 2:
101100
return colors_arr
102-
else:
103-
raise ValueError(
104-
f"Invalid colors shape: {colors_arr.shape}. Expected 1D or 2D array."
105-
)
101+
raise ValueError(
102+
f"Invalid colors shape: {colors_arr.shape}. Expected 1D or 2D array."
103+
)
106104

107105

108106
def _to_titles(title: Optional[Union[str, List[str]]], colors_num: int) -> List[str]:
109107
if title is None:
110108
return [f"{i}" for i in range(colors_num)]
111-
elif isinstance(title, str):
109+
if isinstance(title, str):
112110
if colors_num == 1:
113111
return [title]
114-
else:
115-
return [f"{title} [{i}]" for i in range(colors_num)]
116-
elif isinstance(title, list) and len(title) == colors_num:
112+
return [f"{title} [{i}]" for i in range(colors_num)]
113+
if isinstance(title, list) and len(title) == colors_num:
117114
return title
118-
else:
119-
raise ValueError(
120-
f"Invalid title type: {type(title)}. Expected str or list[str]."
121-
)
115+
raise ValueError(f"Invalid title type: {type(title)}. Expected str or list[str].")
122116

123117

124118
def _to_node_sizes(
125119
node_size: Optional[Union[int, float, List[Union[int, float]]]],
126120
) -> List[float]:
127121
if isinstance(node_size, (int, float)):
128122
return [node_size]
129-
elif isinstance(node_size, list):
123+
if isinstance(node_size, list):
130124
return node_size
131-
else:
132-
raise ValueError(
133-
f"Invalid node_size type: {type(node_size)}. "
134-
"Expected int, float or list[int, float]."
135-
)
125+
raise ValueError(
126+
f"Invalid node_size type: {type(node_size)}. "
127+
"Expected int, float or list[int, float]."
128+
)
136129

137130

138131
def _get_cmap_rgb(cmap: str):
@@ -152,6 +145,19 @@ def plot_plotly(
152145
width: Optional[int] = None,
153146
height: Optional[int] = None,
154147
) -> go.Figure:
148+
"""
149+
Plot a Mapper graph using Plotly.
150+
151+
:param mapper_plot: The Mapper plot object containing the graph and positions.
152+
:param colors: Colors for the nodes, can be a 1D or 2D numpy array or a list.
153+
:param node_size: Size of the nodes, can be a single value or a list of values.
154+
:param title: Title for the color bar, can be a string or a list of strings.
155+
:param agg: Aggregation function to apply to the colors, defaults to np.nanmean.
156+
:param cmap: Colormap to use for the nodes, can be a string or a list of strings.
157+
:param width: Width of the plot, defaults to None (auto).
158+
:param height: Height of the plot, defaults to None (auto).
159+
:return: A Plotly Figure object containing the Mapper graph.
160+
"""
155161
cmaps = _to_cmaps(cmap)
156162
colors = _to_colors(colors)
157163
colors_num = colors.shape[1]

src/tdamapper/vptree_flat/ball_search.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
VP-tree Ball Search Module.
33
4-
This module provides a BallSearch class for searching points within a specified distance (epsilon)
5-
from a given point in a VP-tree. It uses an iterative approach to traverse the VP-tree
6-
and collect points that meet the distance criteria.
4+
This module provides a BallSearch class for searching points within a specified
5+
distance (epsilon) from a given point in a VP-tree. It uses an iterative
6+
approach to traverse the VP-tree and collect points that meet the distance
7+
criteria.
78
"""
89

910
from typing import Generic, List, TypeVar
@@ -18,16 +19,16 @@ class BallSearch(Generic[T]):
1819
BallSearch class for searching points within a specified distance (epsilon)
1920
from a given point in a VP-tree.
2021
21-
This class performs a search in a VP-tree to find all points that are within
22-
a specified distance (epsilon) from a given point. It uses an iterative
23-
approach to traverse the VP-tree and collect points that meet the distance
24-
criteria.
22+
This class performs a search in a VP-tree to find all points that are
23+
within a specified distance (epsilon) from a given point. It uses an
24+
iterative approach to traverse the VP-tree and collect points that meet the
25+
distance criteria.
2526
2627
:param vpt: VPTreeType instance containing distance function and parameters.
2728
:param point: The point from which the search is performed.
2829
:param eps: The distance threshold (epsilon) for the search.
29-
:param inclusive: If True, points exactly at distance eps are included in the
30-
results. If False, they are excluded. Defaults to True.
30+
:param inclusive: If True, points exactly at distance eps are included in
31+
the results. If False, they are excluded. Defaults to True.
3132
"""
3233

3334
def __init__(
@@ -45,13 +46,14 @@ def __init__(
4546

4647
def search(self) -> List[T]:
4748
"""
48-
Perform the search for points within the specified distance from the point.
49+
Perform the search for points within the specified distance from the
50+
point.
4951
5052
This method initiates the search process and returns a list of points
5153
that are within the specified distance (epsilon) from the given point.
5254
53-
:return: A list of points that are within the specified distance from the
54-
given point.
55+
:return: A list of points that are within the specified distance from
56+
the given point.
5557
"""
5658
return self._search_iter()
5759

src/tdamapper/vptree_flat/builder.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
2-
VP-tree Builder Module
3-
This module provides a Builder class for constructing a VP-tree from a collection of items.
4-
It supports different pivoting strategies and allows customization of the tree's parameters.
2+
VP-tree Builder Module.
3+
4+
This module provides a Builder class for constructing a VP-tree from a
5+
collection of items. It supports different pivoting strategies and allows
6+
customization of the tree's parameters.
57
"""
68

79
from __future__ import annotations
@@ -25,10 +27,11 @@ class Builder(Generic[T]):
2527
"""
2628
Builder for constructing a VP-tree from a collection of items.
2729
28-
This class takes a VPTreeType and an iterable of items, and builds a VP-tree
29-
using the specified pivoting strategy and parameters.
30+
This class takes a VPTreeType and an iterable of items, and builds a
31+
VP-tree using the specified pivoting strategy and parameters.
3032
31-
:param vpt: VPTreeType instance containing distance function and parameters.
33+
:param vpt: VPTreeType instance containing distance function and
34+
parameters.
3235
:param items: Iterable of items to be included in the VP-tree.
3336
"""
3437

@@ -92,9 +95,12 @@ def _update(self, start: int, end: int) -> None:
9295

9396
def build(self) -> VPArray[T]:
9497
"""
95-
Build the VP-tree from the given items.
98+
Build the VP-tree from the items provided during initialization.
99+
100+
This method constructs the VP-tree iteratively, starting from the root
101+
node.
96102
97-
:return: VPArray containing the VP-tree structure.
103+
:return: The VPArray instance containing the constructed VP-tree.
98104
"""
99105
self._build_iter()
100106
return self._arr

src/tdamapper/vptree_flat/common.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1-
""" """
1+
"""
2+
Common types and structures for VP-tree implementation.
3+
4+
This module defines the basic types and structures used in the VP-tree
5+
implementation, including the VPArray for managing the dataset and distances,
6+
the Node and Leaf classes for representing the tree structure, and the
7+
VPTreeType protocol for type checking.
8+
"""
29

310
from __future__ import annotations
411

5-
from typing import Callable, Generic, Iterable, List, Optional, Protocol, TypeVar
12+
from typing import (
13+
Callable,
14+
Generic,
15+
Iterable,
16+
List,
17+
Optional,
18+
Protocol,
19+
TypeVar,
20+
)
621

722
import numpy as np
823
from numpy.typing import NDArray
@@ -78,7 +93,7 @@ class VPArray(Generic[T]):
7893
:param dataset: A list of points of type T.
7994
:param distances: A NumPy array of distances corresponding to the points.
8095
:param indices: A NumPy array of indices mapping points to their positions
81-
in the dataset
96+
in the dataset.
8297
:param is_terminal: A NumPy array indicating whether each point is terminal
8398
(True) or not (False). A terminal point is a leaf node in the VP-tree.
8499
It is used to determine if a point is a leaf node in the VP-tree.

0 commit comments

Comments
 (0)