2626- Cosine: A distance on unit vectors based on cosine similarity.
2727"""
2828
29- from typing import Callable , Union
29+ from enum import Enum
30+ from typing import Callable , List , Union
3031
3132import numpy as np
3233
3334import tdamapper ._metrics as _metrics
3435
35- _EUCLIDEAN = "euclidean"
36- _MANHATTAN = "manhattan"
37- _MINKOWSKI = "minkowski"
3836_MINKOWSKI_P = "p"
39- _CHEBYSHEV = "chebyshev"
40- _COSINE = "cosine"
4137
4238
43- def get_supported_metrics ():
44- """
45- Return a list of supported metric names.
46-
47- :return: A list of supported metric names.
48- :rtype: list of str
49- """
50- return [
51- _EUCLIDEAN ,
52- _MANHATTAN ,
53- _MINKOWSKI ,
54- _CHEBYSHEV ,
55- _COSINE ,
56- ]
class Metric(str, Enum):
    """
    Names of the built-in distance metrics.

    Every member also subclasses ``str``, so it compares equal to its plain
    string value (for example ``Metric.COSINE == "cosine"``), and calling
    ``Metric("cosine")`` looks the member up from that string.
    """

    # Declaration order is the order in which the members are iterated.
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"
    MINKOWSKI = "minkowski"
    CHEBYSHEV = "chebyshev"
    COSINE = "cosine"
5745
5846
59- def euclidean () :
47+ def euclidean (* args , ** kwargs ) -> Callable :
6048 """
6149 Return the Euclidean distance function for vectors.
6250
@@ -69,7 +57,7 @@ def euclidean():
6957 return _metrics .euclidean
7058
7159
72- def manhattan () :
60+ def manhattan (* args , ** kwargs ) -> Callable :
7361 """
7462 Return the Manhattan distance function for vectors.
7563
@@ -82,7 +70,7 @@ def manhattan():
8270 return _metrics .manhattan
8371
8472
85- def chebyshev () :
73+ def chebyshev (* args , ** kwargs ) -> Callable :
8674 """
8775 Return the Chebyshev distance function for vectors.
8876
@@ -95,7 +83,7 @@ def chebyshev():
9583 return _metrics .chebyshev
9684
9785
98- def minkowski (p ) :
86+ def minkowski (* args , ** kwargs ) -> Callable :
9987 """
10088 Return the Minkowski distance function for order p on vectors.
10189
@@ -104,12 +92,11 @@ def minkowski(p):
10492 when p = 2, it is equivalent to the Euclidean distance. When p is infinite,
10593 it is equivalent to the Chebyshev distance.
10694
107- :param p: The order of the Minkowski distance.
108- :type p: int
109-
11095 :return: The Minkowski distance function.
11196 :rtype: callable
11297 """
98+ p = kwargs .get (_MINKOWSKI_P , 2 )
99+
113100 if p == 1 :
114101 return manhattan ()
115102 elif p == 2 :
@@ -123,7 +110,7 @@ def dist(x, y):
123110 return dist
124111
125112
126- def cosine () :
113+ def cosine (* args , ** kwargs ) -> Callable :
127114 """
128115 Return the cosine distance function for vectors.
129116
@@ -145,7 +132,42 @@ def cosine():
145132 return _metrics .cosine
146133
147134
148- def get_metric (metric : Union [str , Callable ], ** kwargs ) -> Callable :
def _get_supported_metrics() -> List[str]:
    """
    List the string name of every supported metric.

    :return: The value of each ``Metric`` member, in declaration order.
    :rtype: list of str
    """
    names: List[str] = []
    for member in Metric:
        names.append(member.value)
    return names
143+
144+
145+ def get_metric_function (metric : Metric , * args , ** kwargs ) -> Callable :
146+ """
147+ Return the distance function for the specified metric.
148+
149+ :param metric: The metric to use, as a string from the supported metrics.
150+ :type metric: Metric
151+
152+ :return: The selected distance metric function.
153+ :rtype: callable
154+
155+ :raises ValueError: If an invalid metric string is provided.
156+ """
157+ match metric :
158+ case Metric .EUCLIDEAN :
159+ return euclidean (* args , ** kwargs )
160+ case Metric .MANHATTAN :
161+ return manhattan (* args , ** kwargs )
162+ case Metric .MINKOWSKI :
163+ return minkowski (* args , ** kwargs )
164+ case Metric .CHEBYSHEV :
165+ return chebyshev (* args , ** kwargs )
166+ case Metric .COSINE :
167+ return cosine (* args , ** kwargs )
168+
169+
def get_metric(metric: Union[str, Metric, Callable], *args, **kwargs) -> Callable:
    """
    Return a distance function based on the specified string or callable.

    A callable is returned unchanged. A ``Metric`` member, or a string naming
    one, is resolved to the corresponding built-in distance function; any
    extra positional and keyword arguments are forwarded to
    ``get_metric_function``.

    :param metric: A callable distance function, a ``Metric`` member, or the
        string name of a supported metric.
    :type metric: str, Metric or callable

    :return: The selected distance function.
    :rtype: callable

    :raises ValueError: If the string does not name a supported metric, or if
        ``metric`` is neither a string, a ``Metric`` member nor a callable.
    """
    if callable(metric):
        return metric
    # Metric members are also str instances, so test for them first.
    if isinstance(metric, Metric):
        return get_metric_function(metric, *args, **kwargs)
    if isinstance(metric, str):
        try:
            metric_enum = Metric(metric)
        except ValueError:
            # Metric(...) itself rejects unknown names; replace its generic
            # message with one that lists the accepted values.
            raise ValueError(
                f"Unsupported metric: {metric}. "
                f"Supported metrics are: {', '.join(_get_supported_metrics())}"
            ) from None
        return get_metric_function(metric_enum, *args, **kwargs)
    raise ValueError("metric must be a string or callable")
202+
203+
def _first_run() -> None:
    """
    Call every built-in metric once on a small pair of vectors.

    This is intended to make the metric functions compile with Numba up
    front rather than on their first real use.
    """
    first = np.array([0.0, 1.0])
    second = np.array([1.0, 0.0])
    for member in Metric:
        distance = get_metric_function(member)
        # The result is discarded; only the call itself matters here.
        distance(first, second)


_first_run()  # Ensure the functions are compiled on first import
0 commit comments