@@ -258,6 +258,39 @@ class labels for fitting
258258
259259 return self
260260
261+ def _predict_proba (self , X ):
262+ """Predict class probabilities for sequences in X.
263+
264+ private _predict_proba containing the core logic, called from predict_proba
265+
266+ State required:
267+ Requires state to be "fitted".
268+
269+ Accesses in self:
270+ Fitted model attributes ending in "_"
271+
272+ Parameters
273+ ----------
274+ X : guaranteed to be of a type in self.get_tag("X_inner_mtype")
275+ if self.get_tag("X_inner_mtype") = "numpy3D":
276+ 3D np.ndarray of shape = [n_instances, n_dimensions, series_length]
277+ if self.get_tag("X_inner_mtype") = "nested_univ":
278+ pd.DataFrame with each column a dimension, each cell a pd.Series
279+ for list of other mtypes, see datatypes.SCITYPE_REGISTER
280+ for specifications, see examples/AA_datatypes_and_datasets.ipynb
281+
282+ Returns
283+ -------
284+ y : 2D array of shape [n_instances, n_classes] - predicted class probabilities
285+ """
286+ if not self .refit :
287+ raise RuntimeError (
288+ f"In { self .__class__ .__name__ } , refit must be True to make predictions,"
289+ f" but found refit=False. If refit=False, { self .__class__ .__name__ } can"
290+ " be used only to tune hyper-parameters, as a parameter estimator."
291+ )
292+ return super ()._predict_proba (X = X )
293+
261294 def _predict (self , X ):
262295 """Predict labels for sequences in X.
263296
@@ -317,15 +350,16 @@ def get_test_params(cls, parameter_set="default"):
317350
318351 params_gridsearch = {
319352 "estimator" : DummyClassifier (),
353+ "cv" : KFold (n_splits = 2 , shuffle = False ),
320354 "optimizer" : GridSearchSk (
321- param_grid = {"strategy" : ["most_frequent" , "stratified " ]}
355+ param_grid = {"strategy" : ["most_frequent" , "prior " ]}
322356 ),
323357 }
324358 params_randomsearch = {
325359 "estimator" : DummyClassifier (),
326- "cv" : 2 ,
360+ "cv" : KFold ( n_splits = 2 , shuffle = False ) ,
327361 "optimizer" : RandomSearchSk (
328- param_distributions = {"strategy" : ["most_frequent" , "stratified " ]},
362+ param_distributions = {"strategy" : ["most_frequent" , "prior " ]},
329363 ),
330364 "scoring" : accuracy_score ,
331365 }
0 commit comments