@@ -99,23 +99,10 @@ def __init__(
9999 dissimilarity : Literal ["euclidean" , "precomputed" ] = "euclidean" ,
100100 svd_seed : Optional [int ] = None ,
101101 ) -> None :
102- # Check inputs
103- if n_components is not None :
104- if not isinstance (n_components , int ):
105- msg = "n_components must be an integer, not {}." .format (
106- type (n_components )
107- )
108- raise TypeError (msg )
109- elif n_components <= 0 :
110- msg = "n_components must be >= 1 or None."
111- raise ValueError (msg )
102+ # Store parameters without validation (sklearn convention)
103+ # Validation will be done in fit() method
112104 self .n_components = n_components
113-
114- if dissimilarity not in ["euclidean" , "precomputed" ]:
115- msg = "Dissimilarity measure must be either 'euclidean' or 'precomputed'."
116- raise ValueError (msg )
117105 self .dissimilarity = dissimilarity
118-
119106 self .n_elbows = n_elbows
120107 self .svd_seed = svd_seed
121108
@@ -174,6 +161,29 @@ def fit(self, X: np.ndarray, y: Optional[Any] = None) -> "ClassicalMDS":
174161 self : object
175162 Returns an instance of self.
176163 """
164+ # Validate parameters (sklearn convention: validate in fit, not __init__)
165+ if self .n_components is not None :
166+ if not isinstance (self .n_components , int ):
167+ msg = "n_components must be an integer, not {}." .format (
168+ type (self .n_components )
169+ )
170+ raise TypeError (msg )
171+ elif self .n_components < 0 :
172+ msg = "n_components must be >= 0 or None."
173+ raise ValueError (msg )
174+
175+ if self .dissimilarity not in ["euclidean" , "precomputed" ]:
176+ msg = "Dissimilarity measure must be either 'euclidean' or 'precomputed'."
177+ raise ValueError (msg )
178+
179+ if not isinstance (self .n_elbows , int ) or self .n_elbows < 0 :
180+ msg = "n_elbows must be a non-negative integer."
181+ raise ValueError (msg )
182+
183+ if self .svd_seed is not None and (not isinstance (self .svd_seed , int ) or self .svd_seed < 0 ):
184+ msg = "svd_seed must be a non-negative integer or None."
185+ raise ValueError (msg )
186+
177187 # Check X type
178188 if not isinstance (X , np .ndarray ):
179189 msg = "X must be a numpy array, not {}." .format (type (X ))
@@ -184,6 +194,16 @@ def fit(self, X: np.ndarray, y: Optional[Any] = None) -> "ClassicalMDS":
184194 if self .n_components > n_samples :
185195 msg = "n_components must be <= n_samples."
186196 raise ValueError (msg )
197+ # Handle special case of n_components=0
198+ if self .n_components == 0 :
199+ self .n_components_ = 0
200+ self .components_ = np .empty (
201+ (0 , X .shape [1 ] if X .ndim == 2 else X .shape [0 ])
202+ )
203+ self .singular_values_ = np .empty (0 )
204+ self .dissimilarity_matrix_ = np .empty ((n_samples , n_samples ))
205+ self .n_features_in_ = X .shape [1 ] if X .ndim == 2 else X .shape [0 ]
206+ return self
187207
188208 # Handle dissimilarity
189209 if self .dissimilarity == "precomputed" :
@@ -244,6 +264,10 @@ def fit_transform(self, X: np.ndarray, y: Optional[Any] = None) -> np.ndarray:
244264 """
245265 self .fit (X )
246266
267+ # Handle special case of n_components=0
268+ if self .n_components_ == 0 :
269+ return np .empty ((X .shape [0 ], 0 ))
270+
247271 X_new = self .components_ @ np .diag (self .singular_values_ )
248272
249273 return X_new
0 commit comments