Skip to content

Commit 0eb73fe

Browse files
committed
Added types
1 parent 0c9e853 commit 0eb73fe

2 files changed

Lines changed: 91 additions & 52 deletions

File tree

src/tdamapper/_common.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22
This module provides common functionalities for internal use.
33
"""
44

5+
from __future__ import annotations
6+
57
import cProfile
68
import io
79
import pstats
810
import warnings
11+
from typing import Any, Callable, Dict
912

1013
import numpy as np
14+
from numpy.typing import NDArray
1115

1216
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
1317

1418

15-
def deprecated(msg):
19+
def deprecated(msg: str) -> Callable:
1620
def deprecated_func(func):
1721
def wrapper(*args, **kwargs):
1822
warnings.warn(msg, DeprecationWarning, stacklevel=2)
@@ -23,17 +27,17 @@ def wrapper(*args, **kwargs):
2327
return deprecated_func
2428

2529

26-
def warn_user(msg):
30+
def warn_user(msg: str) -> None:
    """
    Emit ``msg`` as a :class:`UserWarning`.

    :param msg: The text of the warning.
    :type msg: str
    """
    # stacklevel=2 attributes the warning to the caller of warn_user
    # instead of to this helper.
    warnings.warn(message=msg, category=UserWarning, stacklevel=2)
2832

2933

3034
class EstimatorMixin:
3135

32-
def _is_sparse(self, X):
36+
def _is_sparse(self, X: NDArray) -> bool:
3337
# simple alternative use scipy.sparse.issparse
3438
return hasattr(X, "toarray")
3539

36-
def _validate_X_y(self, X, y):
40+
def _validate_X_y(self, X: NDArray, y: NDArray) -> tuple[NDArray, NDArray]:
3741
if self._is_sparse(X):
3842
raise ValueError("Sparse data not supported.")
3943

@@ -70,7 +74,7 @@ def _validate_X_y(self, X, y):
7074

7175
return X, y
7276

73-
def _set_n_features_in(self, X):
77+
def _set_n_features_in(self, X: NDArray) -> None:
7478
self.n_features_in_ = X.shape[1]
7579

7680

@@ -80,16 +84,16 @@ class ParamsMixin:
8084
scikit-learn `get_params` and `set_params`.
8185
"""
8286

83-
def _is_param_public(self, k):
87+
def _is_param_public(self, k: str) -> bool:
8488
return (not k.startswith("_")) and (not k.endswith("_"))
8589

86-
def _split_param(self, k):
90+
def _split_param(self, k: str) -> tuple[str, str]:
8791
k_split = k.split("__")
8892
outer = k_split[0]
8993
inner = "__".join(k_split[1:])
9094
return outer, inner
9195

92-
def get_params(self, deep=True):
96+
def get_params(self, deep: bool = True) -> Dict[str, Any]:
9397
"""
9498
Get all public parameters of the object as a dictionary.
9599
@@ -105,7 +109,7 @@ def get_params(self, deep=True):
105109
params[f"{k}__{_k}"] = _v
106110
return params
107111

108-
def set_params(self, **params):
112+
def set_params(self, **params: Dict[str, Any]) -> ParamsMixin:
109113
"""
110114
Set public parameters. Only updates attributes that already exist.
111115
"""
@@ -124,7 +128,7 @@ def set_params(self, **params):
124128
k_attr.set_params(**{k_inner: v})
125129
return self
126130

127-
def __repr__(self):
131+
def __repr__(self) -> str:
128132
obj_noargs = type(self)()
129133
args_repr = []
130134
for k, v in self.__dict__.items():
@@ -136,7 +140,7 @@ def __repr__(self):
136140
return f"{self.__class__.__name__}({', '.join(args_repr)})"
137141

138142

139-
def clone(obj):
143+
def clone(obj: Any) -> Any:
140144
"""
141145
Clone an estimator, returning a new one, unfitted, having the same public
142146
parameters.
@@ -152,7 +156,7 @@ def clone(obj):
152156
return obj_noargs
153157

154158

155-
def profile(n_lines=10):
159+
def profile(n_lines: int = 10) -> Callable:
156160
def decorator(func):
157161
def wrapper(*args, **kwargs):
158162
profiler = cProfile.Profile()

src/tdamapper/metrics.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -26,37 +26,25 @@
2626
- Cosine: A distance on unit vectors based on cosine similarity.
2727
"""
2828

29-
from typing import Callable, Union
29+
from enum import Enum
30+
from typing import Callable, List, Union
3031

3132
import numpy as np
3233

3334
import tdamapper._metrics as _metrics
3435

35-
_EUCLIDEAN = "euclidean"
36-
_MANHATTAN = "manhattan"
37-
_MINKOWSKI = "minkowski"
3836
_MINKOWSKI_P = "p"
39-
_CHEBYSHEV = "chebyshev"
40-
_COSINE = "cosine"
4137

4238

43-
def get_supported_metrics():
44-
"""
45-
Return a list of supported metric names.
46-
47-
:return: A list of supported metric names.
48-
:rtype: list of str
49-
"""
50-
return [
51-
_EUCLIDEAN,
52-
_MANHATTAN,
53-
_MINKOWSKI,
54-
_CHEBYSHEV,
55-
_COSINE,
56-
]
39+
class Metric(str, Enum):
    """
    Names of the built-in distance metrics.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Declaration order is the iteration order of the enum.
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"
    MINKOWSKI = "minkowski"
    CHEBYSHEV = "chebyshev"
    COSINE = "cosine"
5745

5846

59-
def euclidean():
47+
def euclidean(*args, **kwargs) -> Callable:
6048
"""
6149
Return the Euclidean distance function for vectors.
6250
@@ -69,7 +57,7 @@ def euclidean():
6957
return _metrics.euclidean
7058

7159

72-
def manhattan():
60+
def manhattan(*args, **kwargs) -> Callable:
7361
"""
7462
Return the Manhattan distance function for vectors.
7563
@@ -82,7 +70,7 @@ def manhattan():
8270
return _metrics.manhattan
8371

8472

85-
def chebyshev():
73+
def chebyshev(*args, **kwargs) -> Callable:
8674
"""
8775
Return the Chebyshev distance function for vectors.
8876
@@ -95,7 +83,7 @@ def chebyshev():
9583
return _metrics.chebyshev
9684

9785

98-
def minkowski(p):
86+
def minkowski(*args, **kwargs) -> Callable:
9987
"""
10088
Return the Minkowski distance function for order p on vectors.
10189
@@ -104,12 +92,11 @@ def minkowski(p):
10492
when p = 2, it is equivalent to the Euclidean distance. When p is infinite,
10593
it is equivalent to the Chebyshev distance.
10694
107-
:param p: The order of the Minkowski distance.
108-
:type p: int
109-
11095
:return: The Minkowski distance function.
11196
:rtype: callable
11297
"""
98+
p = kwargs.get(_MINKOWSKI_P, 2)
99+
113100
if p == 1:
114101
return manhattan()
115102
elif p == 2:
@@ -123,7 +110,7 @@ def dist(x, y):
123110
return dist
124111

125112

126-
def cosine():
113+
def cosine(*args, **kwargs) -> Callable:
127114
"""
128115
Return the cosine distance function for vectors.
129116
@@ -145,7 +132,42 @@ def cosine():
145132
return _metrics.cosine
146133

147134

148-
def get_metric(metric: Union[str, Callable], **kwargs) -> Callable:
135+
def _get_supported_metrics() -> List[str]:
    """
    List the string values of every member of ``Metric``.

    :return: The metric names, in the enum's declaration order.
    :rtype: list of str
    """
    names = []
    for member in Metric:
        names.append(member.value)
    return names
143+
144+
145+
def get_metric_function(metric: Metric, *args, **kwargs) -> Callable:
146+
"""
147+
Return the distance function for the specified metric.
148+
149+
:param metric: The metric to use, as a string from the supported metrics.
150+
:type metric: Metric
151+
152+
:return: The selected distance metric function.
153+
:rtype: callable
154+
155+
:raises ValueError: If an invalid metric string is provided.
156+
"""
157+
match metric:
158+
case Metric.EUCLIDEAN:
159+
return euclidean(*args, **kwargs)
160+
case Metric.MANHATTAN:
161+
return manhattan(*args, **kwargs)
162+
case Metric.MINKOWSKI:
163+
return minkowski(*args, **kwargs)
164+
case Metric.CHEBYSHEV:
165+
return chebyshev(*args, **kwargs)
166+
case Metric.COSINE:
167+
return cosine(*args, **kwargs)
168+
169+
170+
def get_metric(metric: Union[str, Metric, Callable], *args, **kwargs) -> Callable:
149171
"""
150172
Return a distance function based on the specified string or callable.
151173
@@ -165,16 +187,29 @@ def get_metric(metric: Union[str, Callable], **kwargs) -> Callable:
165187
"""
166188
if callable(metric):
167189
return metric
168-
elif metric == _EUCLIDEAN:
169-
return euclidean()
170-
elif metric == _MANHATTAN:
171-
return manhattan()
172-
elif metric == _MINKOWSKI:
173-
p = kwargs.get(_MINKOWSKI_P, 2)
174-
return minkowski(p)
175-
elif metric == _CHEBYSHEV:
176-
return chebyshev()
177-
elif metric == _COSINE:
178-
return cosine()
190+
elif isinstance(metric, str):
191+
metric_enum = Metric(metric)
192+
if metric_enum not in _get_supported_metrics():
193+
raise ValueError(
194+
f"Unsupported metric: {metric}. "
195+
f"Supported metrics are: {', '.join(_get_supported_metrics())}"
196+
)
197+
return get_metric_function(metric_enum, *args, **kwargs)
198+
elif isinstance(metric, Metric):
199+
return get_metric_function(metric, *args, **kwargs)
179200
else:
180201
raise ValueError("metric must be a string or callable")
202+
203+
204+
def _first_run() -> None:
    """
    Call every built-in metric once on a pair of small sample vectors.

    The original note states the purpose is to have the metric functions
    compiled with Numba on the first run; the results are discarded.
    """
    u = np.array([0.0, 1.0])
    v = np.array([1.0, 0.0])
    for member in Metric:
        distance = get_metric_function(member)
        # The value is irrelevant: the call itself is what triggers the
        # one-time compilation.
        distance(u, v)
213+
214+
215+
_first_run() # Ensure the functions are compiled on first import

0 commit comments

Comments
 (0)