22This module provides common functionalities for internal use.
33"""
44
5+ from __future__ import annotations
6+
57import cProfile
68import io
79import pstats
810import warnings
11+ from typing import Any , Callable , Union
912
1013import numpy as np
14+ from numpy .typing import NDArray
1115
1216warnings .filterwarnings ("default" , category = DeprecationWarning , module = r"^tdamapper\." )
1317
18+ PointLike = Union [NDArray [np .float64 ], list [Any ]]
19+ ArrayLike = Union [NDArray [np .float64 ], list [PointLike ]]
20+
21+
22+ def deprecated (msg : str ) -> Callable :
23+ """
24+ Decorator to mark a function as deprecated.
25+
26+ :param msg: A message to be shown when the function is called.
27+ """
1428
15- def deprecated (msg ):
16- def deprecated_func (func ):
29+ def deprecated_func (func : Callable ) -> Callable :
1730 def wrapper (* args , ** kwargs ):
1831 warnings .warn (msg , DeprecationWarning , stacklevel = 2 )
1932 return func (* args , ** kwargs )
@@ -23,17 +36,45 @@ def wrapper(*args, **kwargs):
2336 return deprecated_func
2437
2538
26- def warn_user (msg ):
39+ def warn_user (msg : str ) -> None :
40+ """
41+ Warn the user with a message.
42+
43+ :param msg: A message to be shown to the user.
44+ """
2745 warnings .warn (msg , UserWarning , stacklevel = 2 )
2846
2947
3048class EstimatorMixin :
49+ """
50+ Mixin to add common functionalities to estimators, such as validation of
51+ input data, setting the number of features, and checking for sparse data.
52+ """
53+
54+ def _is_sparse (self , X : ArrayLike ) -> bool :
55+ """
56+ Check if the input data is sparse.
3157
32- def _is_sparse (self , X ):
58+ :param X: Input data.
59+ :type X: array-like, shape (n_samples, n_features)
60+ """
3361 # simple alternative use scipy.sparse.issparse
3462 return hasattr (X , "toarray" )
3563
36- def _validate_X_y (self , X , y ):
64+ def _validate_X_y (
65+ self , X : ArrayLike , y : ArrayLike
66+ ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
67+ """
68+ Validate input data and target values.
69+
70+ :param X: Input data.
71+ :type X: array-like, shape (n_samples, n_features)
72+ :param y: Target values.
73+ :type y: array-like, shape (n_samples,)
74+ :return: Validated input data and target values.
75+ :rtype: tuple of (X, y)
76+ :raises ValueError: If the input data or target values are invalid.
77+ """
3778 if self ._is_sparse (X ):
3879 raise ValueError ("Sparse data not supported." )
3980
@@ -70,8 +111,9 @@ def _validate_X_y(self, X, y):
70111
71112 return X , y
72113
73- def _set_n_features_in (self , X ):
74- self .n_features_in_ = X .shape [1 ]
114+ def _set_n_features_in (self , X : ArrayLike ) -> None :
115+ if hasattr (X , "shape" ):
116+ self .n_features_in_ = X .shape [1 ]
75117
76118
77119class ParamsMixin :
@@ -80,16 +122,16 @@ class ParamsMixin:
80122 scikit-learn `get_params` and `set_params`.
81123 """
82124
83- def _is_param_public (self , k ) :
125+ def _is_param_public (self , k : str ) -> bool :
84126 return (not k .startswith ("_" )) and (not k .endswith ("_" ))
85127
86- def _split_param (self , k ):
128+ def _split_param (self , k : str ):
87129 k_split = k .split ("__" )
88130 outer = k_split [0 ]
89131 inner = "__" .join (k_split [1 :])
90132 return outer , inner
91133
92- def get_params (self , deep = True ):
134+ def get_params (self , deep : bool = True ) -> dict [ str , Any ] :
93135 """
94136 Get all public parameters of the object as a dictionary.
95137
@@ -105,7 +147,7 @@ def get_params(self, deep=True):
105147 params [f"{ k } __{ _k } " ] = _v
106148 return params
107149
108- def set_params (self , ** params ) :
150+ def set_params (self , ** params : dict [ str , Any ]) -> ParamsMixin :
109151 """
110152 Set public parameters. Only updates attributes that already exist.
111153 """
@@ -124,7 +166,7 @@ def set_params(self, **params):
124166 k_attr .set_params (** {k_inner : v })
125167 return self
126168
127- def __repr__ (self ):
169+ def __repr__ (self ) -> str :
128170 obj_noargs = type (self )()
129171 args_repr = []
130172 for k , v in self .__dict__ .items ():
@@ -136,7 +178,7 @@ def __repr__(self):
136178 return f"{ self .__class__ .__name__ } ({ ', ' .join (args_repr )} )"
137179
138180
139- def clone (obj ) :
181+ def clone (obj : Any ) -> Any :
140182 """
141183 Clone an estimator, returning a new one, unfitted, having the same public
142184 parameters.
@@ -152,8 +194,15 @@ def clone(obj):
152194 return obj_noargs
153195
154196
155- def profile (n_lines = 10 ):
156- def decorator (func ):
197+ def profile (n_lines : int = 10 ) -> Callable :
198+ """
199+ Decorator to profile a function using cProfile and print the top `n_lines`
200+ lines of the profiling report.
201+
202+ :param n_lines: Number of lines to print from the profiling report.
203+ """
204+
205+ def decorator (func : Callable ) -> Callable :
157206 def wrapper (* args , ** kwargs ):
158207 profiler = cProfile .Profile ()
159208 profiler .enable ()
0 commit comments