88import io
99import pstats
1010import warnings
11- from typing import Any , Callable , Dict
11+ from typing import Any , Callable , Dict , List , Union
1212
1313import numpy as np
1414from numpy .typing import NDArray
1515
1616warnings .filterwarnings ("default" , category = DeprecationWarning , module = r"^tdamapper\." )
1717
1818
19+ PointLike = Union [Any , NDArray [np .float64 ]]
20+
21+ ArrayLike = Union [List [Any ], NDArray [np .float64 ]]
22+
23+
1924def deprecated (msg : str ) -> Callable :
25+ """
26+ Decorator to mark a function as deprecated.
27+
28+ :param msg: A message to be shown when the function is called.
29+ :return: A decorator that wraps the function and issues a warning when called.
30+ """
31+
2032 def deprecated_func (func ):
2133 def wrapper (* args , ** kwargs ):
2234 warnings .warn (msg , DeprecationWarning , stacklevel = 2 )
@@ -28,54 +40,91 @@ def wrapper(*args, **kwargs):
2840
2941
3042def warn_user (msg : str ) -> None :
43+ """
44+ Issues a warning to the user.
45+
46+ :param msg: A message to be shown to the user.
47+ """
3148 warnings .warn (msg , UserWarning , stacklevel = 2 )
3249
3350
3451class EstimatorMixin :
52+ """
53+ Mixin to add common functionalities to estimators, such as validation of
54+ input data, setting the number of features, and checking for sparse data.
55+ This mixin is intended to be used with scikit-learn compatible estimators.
56+ """
57+
58+ def _is_sparse (self , x_arr : ArrayLike ) -> bool :
59+ """
60+ Checks if the input array `x_arr` is sparse.
3561
36- def _is_sparse (self , X : NDArray ) -> bool :
62+ :param x_arr: The input array to check.
63+ :return: True if `x_arr` is sparse, False otherwise.
64+ """
3765 # simple alternative use scipy.sparse.issparse
38- return hasattr (X , "toarray" )
66+ return hasattr (x_arr , "toarray" )
3967
40- def _validate_X_y (self , X : NDArray , y : NDArray ) -> tuple [NDArray , NDArray ]:
41- if self ._is_sparse (X ):
68+ def _validate_x_y (
69+ self ,
70+ x_arr : ArrayLike ,
71+ y_arr : ArrayLike ,
72+ ) -> tuple [NDArray , NDArray ]:
73+ """
74+ Validates the input arrays `x_arr` and `y_arr`.
75+
76+ :param x_arr: The input features array.
77+ :param y_arr: The target values array.
78+ :return: A tuple of validated numpy arrays (x_arr, y_arr).
79+ :raises ValueError: If the input arrays are not valid, e.g., if they
80+ are sparse, empty, 1-dimensional, contain complex numbers, or have
81+ NaNs or infinite values.
82+ """
83+ if self ._is_sparse (x_arr ):
4284 raise ValueError ("Sparse data not supported." )
4385
44- X = np .asarray (X )
45- y = np .asarray (y )
86+ x_arr_ = np .asarray (x_arr )
87+ y_arr_ = np .asarray (y_arr )
4688
47- if X .size == 0 :
48- msg = f"0 feature(s) (shape={ X .shape } ) while a minimum of 1 is " "required."
89+ if x_arr_ .size == 0 :
90+ msg = (
91+ f"0 feature(s) (shape={ x_arr_ .shape } ) while a minimum of 1 is "
92+ "required."
93+ )
4994 raise ValueError (msg )
5095
51- if y .size == 0 :
52- msg = f"0 feature(s) (shape={ y .shape } ) while a minimum of 1 is " "required."
96+ if y_arr_ .size == 0 :
97+ msg = (
98+ f"0 feature(s) (shape={ y_arr_ .shape } ) while a minimum of 1 is "
99+ "required."
100+ )
53101 raise ValueError (msg )
54102
55- if X .ndim == 1 :
103+ if x_arr_ .ndim == 1 :
56104 raise ValueError ("1d-arrays not supported." )
57105
58- if np .iscomplexobj (X ) or np .iscomplexobj (y ):
106+ if np .iscomplexobj (x_arr_ ) or np .iscomplexobj (y_arr_ ):
59107 raise ValueError ("Complex data not supported." )
60108
61- if X .dtype == np .object_ :
62- X = np .array (X , dtype = float )
109+ if x_arr_ .dtype == np .object_ :
110+ x_arr_ = np .array (x_arr_ , dtype = float )
63111
64- if y .dtype == np .object_ :
65- y = np .array (y , dtype = float )
112+ if y_arr_ .dtype == np .object_ :
113+ y_arr_ = np .array (y_arr_ , dtype = float )
66114
67115 if (
68- np .isnan (X ).any ()
69- or np .isinf (X ).any ()
70- or np .isnan (y ).any ()
71- or np .isinf (y ).any ()
116+ np .isnan (x_arr_ ).any ()
117+ or np .isinf (x_arr_ ).any ()
118+ or np .isnan (y_arr_ ).any ()
119+ or np .isinf (y_arr_ ).any ()
72120 ):
73121 raise ValueError ("NaNs or infinite values not supported." )
74122
75- return X , y
123+ return x_arr_ , y_arr_
76124
77- def _set_n_features_in (self , X : NDArray ) -> None :
78- self .n_features_in_ = X .shape [1 ]
125+ def _set_n_features_in (self , arr : ArrayLike ) -> None :
126+ if hasattr (arr , "shape" ):
127+ self .n_features_in_ = arr .shape [1 ]
79128
80129
81130class ParamsMixin :
@@ -156,6 +205,14 @@ def clone(obj: Any) -> Any:
156205
157206
158207def profile (n_lines : int = 10 ) -> Callable :
208+ """
209+ Decorator to profile a function using cProfile and print the top `n_lines`
210+ lines of the profiling report.
211+
212+ :param n_lines: The number of lines to print from the profiling report.
213+ :return: A decorator that wraps the function and profiles its execution.
214+ """
215+
159216 def decorator (func ):
160217 def wrapper (* args , ** kwargs ):
161218 profiler = cProfile .Profile ()
0 commit comments