Skip to content

Commit e79fa0c

Browse files
committed
Added type-hints
1 parent c8c1332 commit e79fa0c

8 files changed

Lines changed: 667 additions & 260 deletions

File tree

src/tdamapper/_common.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,31 @@
22
This module provides common functionalities for internal use.
33
"""
44

5+
from __future__ import annotations
6+
57
import cProfile
68
import io
79
import pstats
810
import warnings
11+
from typing import Any, Callable, Union
912

1013
import numpy as np
14+
from numpy.typing import NDArray
1115

1216
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
1317

18+
PointLike = Union[NDArray[np.float64], list[Any]]
19+
ArrayLike = Union[NDArray[np.float64], list[PointLike]]
20+
21+
22+
def deprecated(msg: str) -> Callable:
23+
"""
24+
Decorator to mark a function as deprecated.
25+
26+
:param msg: A message to be shown when the function is called.
27+
"""
1428

15-
def deprecated(msg):
16-
def deprecated_func(func):
29+
def deprecated_func(func: Callable) -> Callable:
1730
def wrapper(*args, **kwargs):
1831
warnings.warn(msg, DeprecationWarning, stacklevel=2)
1932
return func(*args, **kwargs)
@@ -23,17 +36,45 @@ def wrapper(*args, **kwargs):
2336
return deprecated_func
2437

2538

26-
def warn_user(msg):
39+
def warn_user(msg: str) -> None:
40+
"""
41+
Warn the user with a message.
42+
43+
:param msg: A message to be shown to the user.
44+
"""
2745
warnings.warn(msg, UserWarning, stacklevel=2)
2846

2947

3048
class EstimatorMixin:
49+
"""
50+
Mixin to add common functionalities to estimators, such as validation of
51+
input data, setting the number of features, and checking for sparse data.
52+
"""
53+
54+
def _is_sparse(self, X: ArrayLike) -> bool:
55+
"""
56+
Check if the input data is sparse.
3157
32-
def _is_sparse(self, X):
58+
:param X: Input data.
59+
:type X: array-like, shape (n_samples, n_features)
60+
"""
3361
# simple alternative use scipy.sparse.issparse
3462
return hasattr(X, "toarray")
3563

36-
def _validate_X_y(self, X, y):
64+
def _validate_X_y(
65+
self, X: ArrayLike, y: ArrayLike
66+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
67+
"""
68+
Validate input data and target values.
69+
70+
:param X: Input data.
71+
:type X: array-like, shape (n_samples, n_features)
72+
:param y: Target values.
73+
:type y: array-like, shape (n_samples,)
74+
:return: Validated input data and target values.
75+
:rtype: tuple of (X, y)
76+
:raises ValueError: If the input data or target values are invalid.
77+
"""
3778
if self._is_sparse(X):
3879
raise ValueError("Sparse data not supported.")
3980

@@ -70,8 +111,9 @@ def _validate_X_y(self, X, y):
70111

71112
return X, y
72113

73-
def _set_n_features_in(self, X):
74-
self.n_features_in_ = X.shape[1]
114+
def _set_n_features_in(self, X: ArrayLike) -> None:
115+
if hasattr(X, "shape"):
116+
self.n_features_in_ = X.shape[1]
75117

76118

77119
class ParamsMixin:
@@ -80,16 +122,16 @@ class ParamsMixin:
80122
scikit-learn `get_params` and `set_params`.
81123
"""
82124

83-
def _is_param_public(self, k):
125+
def _is_param_public(self, k: str) -> bool:
84126
return (not k.startswith("_")) and (not k.endswith("_"))
85127

86-
def _split_param(self, k):
128+
def _split_param(self, k: str):
87129
k_split = k.split("__")
88130
outer = k_split[0]
89131
inner = "__".join(k_split[1:])
90132
return outer, inner
91133

92-
def get_params(self, deep=True):
134+
def get_params(self, deep: bool = True) -> dict[str, Any]:
93135
"""
94136
Get all public parameters of the object as a dictionary.
95137
@@ -105,7 +147,7 @@ def get_params(self, deep=True):
105147
params[f"{k}__{_k}"] = _v
106148
return params
107149

108-
def set_params(self, **params):
150+
def set_params(self, **params: dict[str, Any]) -> ParamsMixin:
109151
"""
110152
Set public parameters. Only updates attributes that already exist.
111153
"""
@@ -124,7 +166,7 @@ def set_params(self, **params):
124166
k_attr.set_params(**{k_inner: v})
125167
return self
126168

127-
def __repr__(self):
169+
def __repr__(self) -> str:
128170
obj_noargs = type(self)()
129171
args_repr = []
130172
for k, v in self.__dict__.items():
@@ -136,7 +178,7 @@ def __repr__(self):
136178
return f"{self.__class__.__name__}({', '.join(args_repr)})"
137179

138180

139-
def clone(obj):
181+
def clone(obj: Any) -> Any:
140182
"""
141183
Clone an estimator, returning a new one, unfitted, having the same public
142184
parameters.
@@ -152,8 +194,15 @@ def clone(obj):
152194
return obj_noargs
153195

154196

155-
def profile(n_lines=10):
156-
def decorator(func):
197+
def profile(n_lines: int = 10) -> Callable:
198+
"""
199+
Decorator to profile a function using cProfile and print the top `n_lines`
200+
lines of the profiling report.
201+
202+
:param n_lines: Number of lines to print from the profiling report.
203+
"""
204+
205+
def decorator(func: Callable) -> Callable:
157206
def wrapper(*args, **kwargs):
158207
profiler = cProfile.Profile()
159208
profiler.enable()

src/tdamapper/_run_app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
"""
2+
This module is the entry point for running tdamapper app.
3+
"""
4+
15
from tdamapper.app import main
26

37

4-
def run():
8+
def run() -> None:
9+
"""
10+
Run the tdamapper application.
11+
"""
512
main()
613

714

0 commit comments

Comments
 (0)