22This module provides common functionalities for internal use.
33"""
44
5+ from __future__ import annotations
6+
57import cProfile
68import io
79import pstats
810import warnings
11+ from typing import Any , Callable , Protocol
912
1013import numpy as np
14+ from numpy .typing import NDArray
1115
1216warnings .filterwarnings ("default" , category = DeprecationWarning , module = r"^tdamapper\." )
1317
1418
15- def deprecated (msg ):
16- def deprecated_func (func ):
17- def wrapper (* args , ** kwargs ):
19+ class Array (Protocol ):
20+
21+ def __getitem__ (self , index : int ) -> Any :
22+ """
23+ Get an item from the array.
24+ """
25+
26+ def __len__ (self ) -> int :
27+ """
28+ Get the length of the array.
29+ """
30+
31+ def __setitem__ (self , index : int , value : Any ) -> None :
32+ """
33+ Set an item in the array.
34+ """
35+
36+
37+ def deprecated (msg : str ) -> Callable [..., Any ]:
38+ def deprecated_func (func : Callable [..., Any ]) -> Callable [..., Any ]:
39+ def wrapper (* args : list [Any ], ** kwargs : dict [str , Any ]) -> Any :
1840 warnings .warn (msg , DeprecationWarning , stacklevel = 2 )
1941 return func (* args , ** kwargs )
2042
@@ -23,17 +45,19 @@ def wrapper(*args, **kwargs):
2345 return deprecated_func
2446
2547
26- def warn_user (msg ) :
48+ def warn_user (msg : str ) -> None :
2749 warnings .warn (msg , UserWarning , stacklevel = 2 )
2850
2951
3052class EstimatorMixin :
3153
32- def _is_sparse (self , X ) :
54+ def _is_sparse (self , X : Array ) -> bool :
3355 # simple alternative use scipy.sparse.issparse
3456 return hasattr (X , "toarray" )
3557
36- def _validate_X_y (self , X , y ):
58+ def _validate_X_y (
59+ self , X : Array , y : Array
60+ ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
3761 if self ._is_sparse (X ):
3862 raise ValueError ("Sparse data not supported." )
3963
@@ -70,8 +94,9 @@ def _validate_X_y(self, X, y):
7094
7195 return X , y
7296
73- def _set_n_features_in (self , X ):
74- self .n_features_in_ = X .shape [1 ]
97+ def _set_n_features_in (self , X : Array ) -> None :
98+ if hasattr (X , "shape" ):
99+ self .n_features_in_ = X .shape [1 ]
75100
76101
77102class ParamsMixin :
@@ -80,16 +105,16 @@ class ParamsMixin:
80105 scikit-learn `get_params` and `set_params`.
81106 """
82107
83- def _is_param_public (self , k ) :
108+ def _is_param_public (self , k : str ) -> bool :
84109 return (not k .startswith ("_" )) and (not k .endswith ("_" ))
85110
86- def _split_param (self , k ) :
111+ def _split_param (self , k : str ) -> tuple [ str , str ] :
87112 k_split = k .split ("__" )
88113 outer = k_split [0 ]
89114 inner = "__" .join (k_split [1 :])
90115 return outer , inner
91116
92- def get_params (self , deep = True ):
117+ def get_params (self , deep : bool = True ) -> dict [ str , Any ] :
93118 """
94119 Get all public parameters of the object as a dictionary.
95120
@@ -105,7 +130,7 @@ def get_params(self, deep=True):
105130 params [f"{ k } __{ _k } " ] = _v
106131 return params
107132
108- def set_params (self , ** params ) :
133+ def set_params (self , ** params : dict [ str , Any ]) -> ParamsMixin :
109134 """
110135 Set public parameters. Only updates attributes that already exist.
111136 """
@@ -124,7 +149,7 @@ def set_params(self, **params):
124149 k_attr .set_params (** {k_inner : v })
125150 return self
126151
127- def __repr__ (self ):
152+ def __repr__ (self ) -> str :
128153 obj_noargs = type (self )()
129154 args_repr = []
130155 for k , v in self .__dict__ .items ():
@@ -136,7 +161,7 @@ def __repr__(self):
136161 return f"{ self .__class__ .__name__ } ({ ', ' .join (args_repr )} )"
137162
138163
139- def clone (obj ) :
164+ def clone (obj : Any ) -> Any :
140165 """
141166 Clone an estimator, returning a new one, unfitted, having the same public
142167 parameters.
@@ -152,9 +177,9 @@ def clone(obj):
152177 return obj_noargs
153178
154179
155- def profile (n_lines = 10 ):
156- def decorator (func ) :
157- def wrapper (* args , ** kwargs ) :
180+ def profile (n_lines : int = 10 ) -> Callable [..., Any ] :
181+ def decorator (func : Callable [..., Any ]) -> Callable [..., Any ] :
182+ def wrapper (* args : list [ Any ] , ** kwargs : dict [ str , Any ]) -> Any :
158183 profiler = cProfile .Profile ()
159184 profiler .enable ()
160185 result = func (* args , ** kwargs )
0 commit comments