Skip to content

Commit 7831d4c

Browse files
committed
override '_predict_proba'
1 parent 6b7c47e commit 7831d4c

1 file changed

Lines changed: 37 additions & 3 deletions

File tree

src/hyperactive/integrations/sktime/_classification.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,39 @@ class labels for fitting
258258

259259
return self
260260

261+
def _predict_proba(self, X):
262+
"""Predict class probabilities for sequences in X.
263+
264+
private _predict_proba containing the core logic, called from predict_proba
265+
266+
State required:
267+
Requires state to be "fitted".
268+
269+
Accesses in self:
270+
Fitted model attributes ending in "_"
271+
272+
Parameters
273+
----------
274+
X : guaranteed to be of a type in self.get_tag("X_inner_mtype")
275+
if self.get_tag("X_inner_mtype") = "numpy3D":
276+
3D np.ndarray of shape = [n_instances, n_dimensions, series_length]
277+
if self.get_tag("X_inner_mtype") = "nested_univ":
278+
pd.DataFrame with each column a dimension, each cell a pd.Series
279+
for list of other mtypes, see datatypes.SCITYPE_REGISTER
280+
for specifications, see examples/AA_datatypes_and_datasets.ipynb
281+
282+
Returns
283+
-------
284+
y : 2D array of shape [n_instances, n_classes] - predicted class probabilities
285+
"""
286+
if not self.refit:
287+
raise RuntimeError(
288+
f"In {self.__class__.__name__}, refit must be True to make predictions,"
289+
f" but found refit=False. If refit=False, {self.__class__.__name__} can"
290+
" be used only to tune hyper-parameters, as a parameter estimator."
291+
)
292+
return super()._predict_proba(X=X)
293+
261294
def _predict(self, X):
262295
"""Predict labels for sequences in X.
263296
@@ -317,15 +350,16 @@ def get_test_params(cls, parameter_set="default"):
317350

318351
params_gridsearch = {
319352
"estimator": DummyClassifier(),
353+
"cv": KFold(n_splits=2, shuffle=False),
320354
"optimizer": GridSearchSk(
321-
param_grid={"strategy": ["most_frequent", "stratified"]}
355+
param_grid={"strategy": ["most_frequent", "prior"]}
322356
),
323357
}
324358
params_randomsearch = {
325359
"estimator": DummyClassifier(),
326-
"cv": 2,
360+
"cv": KFold(n_splits=2, shuffle=False),
327361
"optimizer": RandomSearchSk(
328-
param_distributions={"strategy": ["most_frequent", "stratified"]},
362+
param_distributions={"strategy": ["most_frequent", "prior"]},
329363
),
330364
"scoring": accuracy_score,
331365
}

0 commit comments

Comments
 (0)