Skip to content

Commit 936a7dc

Browse files
committed
Added Array protocol
1 parent fa618f0 commit 936a7dc

1 file changed

Lines changed: 42 additions & 17 deletions

File tree

src/tdamapper/_common.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,41 @@
22
This module provides common functionalities for internal use.
33
"""
44

5+
from __future__ import annotations
6+
57
import cProfile
68
import io
79
import pstats
810
import warnings
11+
from typing import Any, Callable, Protocol
912

1013
import numpy as np
14+
from numpy.typing import NDArray
1115

1216
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
1317

1418

15-
def deprecated(msg):
16-
def deprecated_func(func):
17-
def wrapper(*args, **kwargs):
19+
class Array(Protocol):
20+
21+
def __getitem__(self, index: int) -> Any:
22+
"""
23+
Get an item from the array.
24+
"""
25+
26+
def __len__(self) -> int:
27+
"""
28+
Get the length of the array.
29+
"""
30+
31+
def __setitem__(self, index: int, value: Any) -> None:
32+
"""
33+
Set an item in the array.
34+
"""
35+
36+
37+
def deprecated(msg: str) -> Callable[..., Any]:
38+
def deprecated_func(func: Callable[..., Any]) -> Callable[..., Any]:
39+
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
1840
warnings.warn(msg, DeprecationWarning, stacklevel=2)
1941
return func(*args, **kwargs)
2042

@@ -23,17 +45,19 @@ def wrapper(*args, **kwargs):
2345
return deprecated_func
2446

2547

26-
def warn_user(msg):
48+
def warn_user(msg: str) -> None:
2749
warnings.warn(msg, UserWarning, stacklevel=2)
2850

2951

3052
class EstimatorMixin:
3153

32-
def _is_sparse(self, X):
54+
def _is_sparse(self, X: Array) -> bool:
3355
# simple alternative use scipy.sparse.issparse
3456
return hasattr(X, "toarray")
3557

36-
def _validate_X_y(self, X, y):
58+
def _validate_X_y(
59+
self, X: Array, y: Array
60+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
3761
if self._is_sparse(X):
3862
raise ValueError("Sparse data not supported.")
3963

@@ -70,8 +94,9 @@ def _validate_X_y(self, X, y):
7094

7195
return X, y
7296

73-
def _set_n_features_in(self, X):
74-
self.n_features_in_ = X.shape[1]
97+
def _set_n_features_in(self, X: Array) -> None:
98+
if hasattr(X, "shape"):
99+
self.n_features_in_ = X.shape[1]
75100

76101

77102
class ParamsMixin:
@@ -80,16 +105,16 @@ class ParamsMixin:
80105
scikit-learn `get_params` and `set_params`.
81106
"""
82107

83-
def _is_param_public(self, k):
108+
def _is_param_public(self, k: str) -> bool:
84109
return (not k.startswith("_")) and (not k.endswith("_"))
85110

86-
def _split_param(self, k):
111+
def _split_param(self, k: str) -> tuple[str, str]:
87112
k_split = k.split("__")
88113
outer = k_split[0]
89114
inner = "__".join(k_split[1:])
90115
return outer, inner
91116

92-
def get_params(self, deep=True):
117+
def get_params(self, deep: bool = True) -> dict[str, Any]:
93118
"""
94119
Get all public parameters of the object as a dictionary.
95120
@@ -105,7 +130,7 @@ def get_params(self, deep=True):
105130
params[f"{k}__{_k}"] = _v
106131
return params
107132

108-
def set_params(self, **params):
133+
def set_params(self, **params: dict[str, Any]) -> ParamsMixin:
109134
"""
110135
Set public parameters. Only updates attributes that already exist.
111136
"""
@@ -124,7 +149,7 @@ def set_params(self, **params):
124149
k_attr.set_params(**{k_inner: v})
125150
return self
126151

127-
def __repr__(self):
152+
def __repr__(self) -> str:
128153
obj_noargs = type(self)()
129154
args_repr = []
130155
for k, v in self.__dict__.items():
@@ -136,7 +161,7 @@ def __repr__(self):
136161
return f"{self.__class__.__name__}({', '.join(args_repr)})"
137162

138163

139-
def clone(obj):
164+
def clone(obj: Any) -> Any:
140165
"""
141166
Clone an estimator, returning a new one, unfitted, having the same public
142167
parameters.
@@ -152,9 +177,9 @@ def clone(obj):
152177
return obj_noargs
153178

154179

155-
def profile(n_lines=10):
156-
def decorator(func):
157-
def wrapper(*args, **kwargs):
180+
def profile(n_lines: int = 10) -> Callable[..., Any]:
181+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
182+
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
158183
profiler = cProfile.Profile()
159184
profiler.enable()
160185
result = func(*args, **kwargs)

0 commit comments

Comments
 (0)