Skip to content

Commit 1fc4731

Browse files
nightosongbdpedigoguodong.song
authored
bump hyppo version (#1092)
* bump hyppo * bump lockfile * try another bump * explicitly add future dep * to 0.5.2 need more detail * update lock file * fix mds * fix type check * bump to 3.4.2 --------- Co-authored-by: Ben Pedigo <benjamindpedigo@gmail.com> Co-authored-by: guodong.song <guodong.song@kunlun-inc.com>
1 parent cc202e5 commit 1fc4731

5 files changed

Lines changed: 1796 additions & 1404 deletions

File tree

graspologic/embed/mds.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,10 @@ def __init__(
9999
dissimilarity: Literal["euclidean", "precomputed"] = "euclidean",
100100
svd_seed: Optional[int] = None,
101101
) -> None:
102-
# Check inputs
103-
if n_components is not None:
104-
if not isinstance(n_components, int):
105-
msg = "n_components must be an integer, not {}.".format(
106-
type(n_components)
107-
)
108-
raise TypeError(msg)
109-
elif n_components <= 0:
110-
msg = "n_components must be >= 1 or None."
111-
raise ValueError(msg)
102+
# Store parameters without validation (sklearn convention)
103+
# Validation will be done in fit() method
112104
self.n_components = n_components
113-
114-
if dissimilarity not in ["euclidean", "precomputed"]:
115-
msg = "Dissimilarity measure must be either 'euclidean' or 'precomputed'."
116-
raise ValueError(msg)
117105
self.dissimilarity = dissimilarity
118-
119106
self.n_elbows = n_elbows
120107
self.svd_seed = svd_seed
121108

@@ -174,6 +161,29 @@ def fit(self, X: np.ndarray, y: Optional[Any] = None) -> "ClassicalMDS":
174161
self : object
175162
Returns an instance of self.
176163
"""
164+
# Validate parameters (sklearn convention: validate in fit, not __init__)
165+
if self.n_components is not None:
166+
if not isinstance(self.n_components, int):
167+
msg = "n_components must be an integer, not {}.".format(
168+
type(self.n_components)
169+
)
170+
raise TypeError(msg)
171+
elif self.n_components < 0:
172+
msg = "n_components must be >= 0 or None."
173+
raise ValueError(msg)
174+
175+
if self.dissimilarity not in ["euclidean", "precomputed"]:
176+
msg = "Dissimilarity measure must be either 'euclidean' or 'precomputed'."
177+
raise ValueError(msg)
178+
179+
if not isinstance(self.n_elbows, int) or self.n_elbows < 0:
180+
msg = "n_elbows must be a non-negative integer."
181+
raise ValueError(msg)
182+
183+
if self.svd_seed is not None and (not isinstance(self.svd_seed, int) or self.svd_seed < 0):
184+
msg = "svd_seed must be a non-negative integer or None."
185+
raise ValueError(msg)
186+
177187
# Check X type
178188
if not isinstance(X, np.ndarray):
179189
msg = "X must be a numpy array, not {}.".format(type(X))
@@ -184,6 +194,16 @@ def fit(self, X: np.ndarray, y: Optional[Any] = None) -> "ClassicalMDS":
184194
if self.n_components > n_samples:
185195
msg = "n_components must be <= n_samples."
186196
raise ValueError(msg)
197+
# Handle special case of n_components=0
198+
if self.n_components == 0:
199+
self.n_components_ = 0
200+
self.components_ = np.empty(
201+
(0, X.shape[1] if X.ndim == 2 else X.shape[0])
202+
)
203+
self.singular_values_ = np.empty(0)
204+
self.dissimilarity_matrix_ = np.empty((n_samples, n_samples))
205+
self.n_features_in_ = X.shape[1] if X.ndim == 2 else X.shape[0]
206+
return self
187207

188208
# Handle dissimilarity
189209
if self.dissimilarity == "precomputed":
@@ -244,6 +264,10 @@ def fit_transform(self, X: np.ndarray, y: Optional[Any] = None) -> np.ndarray:
244264
"""
245265
self.fit(X)
246266

267+
# Handle special case of n_components=0
268+
if self.n_components_ == 0:
269+
return np.empty((X.shape[0], 0))
270+
247271
X_new = self.components_ @ np.diag(self.singular_values_)
248272

249273
return X_new

graspologic/plot/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def pairplot_with_gmm(
788788
alpha: float = 0.7,
789789
figsize: Tuple[int, int] = (12, 12),
790790
histplot_kws: Optional[Dict[str, Any]] = None,
791-
) -> Tuple[matplotlib.pyplot.Figure, matplotlib.pyplot.Axes]:
791+
) -> Tuple[matplotlib.pyplot.Figure, np.ndarray[Any, Any]]:
792792
r"""
793793
Plot pairwise relationships in a dataset, also showing a clustering predicted by
794794
a Gaussian mixture model.

0 commit comments

Comments
 (0)