Skip to content

Commit 4e38e57

Browse files
authored
[BUG] fix TSCOptCV integration for metric function input (#190)
fixes an unreported bug in the `TSCOptCV` integration, for metric function input
1 parent a2972d4 commit 4e38e57

3 files changed

Lines changed: 9 additions & 13 deletions

File tree

src/hyperactive/experiment/integrations/_skl_metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def _coerce_to_scorer(scoring, estimator):
7171
scorer = scoring # passthrough scorer signature
7272
else:
7373
scorer = make_scorer(scoring)
74+
elif isinstance(estimator, str):
75+
metric = _default_metric_for(estimator)
76+
scorer = make_scorer(metric)
7477
else:
7578
# string (scorer name)
7679
scorer = check_scoring(estimator, scoring=scoring)

src/hyperactive/integrations/sktime/_classification.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,11 @@ class labels for fitting
232232
-------
233233
self : Reference to self.
234234
"""
235-
from sklearn.dummy import DummyClassifier
236-
from sklearn.metrics import check_scoring
237-
238235
estimator = self.estimator.clone()
239236

240-
# use dummy classifier from sklearn to get default coercion behaviour
241-
# for classificatoin metrics
242-
scoring = check_scoring(DummyClassifier(), self.scoring)
243-
# scoring_name = f"test_{scoring.name}"
244-
245237
experiment = SktimeClassificationExperiment(
246238
estimator=estimator,
247-
scoring=scoring,
239+
scoring=self.scoring,
248240
cv=self.cv,
249241
X=X,
250242
y=y,
@@ -316,6 +308,7 @@ def get_test_params(cls, parameter_set="default"):
316308
"""
317309
from sklearn.metrics import accuracy_score
318310
from sklearn.model_selection import KFold
311+
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
319312
from sktime.classification.dummy import DummyClassifier
320313

321314
from hyperactive.opt.gfo import HillClimbing
@@ -337,10 +330,10 @@ def get_test_params(cls, parameter_set="default"):
337330
"scoring": accuracy_score,
338331
}
339332
params_hillclimb = {
340-
"estimator": DummyClassifier(strategy="stratified"),
333+
"estimator": KNeighborsTimeSeriesClassifier(),
341334
"cv": KFold(n_splits=2, shuffle=False),
342335
"optimizer": HillClimbing(
343-
search_space={"strategy": ["most_frequent", "stratified"]},
336+
search_space={"n_neighbors": [1, 2, 4]},
344337
n_iter=10,
345338
n_neighbours=5,
346339
),

src/hyperactive/integrations/sktime/tests/test_sktime_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from skbase.utils.dependencies import _check_soft_dependencies
66

77
if _check_soft_dependencies("sktime", severity="none"):
8-
from hyperactive.integrations.sktime import ForecastingOptCV
8+
from hyperactive.integrations.sktime import ForecastingOptCV, TSCOptCV
99

10-
EST_TO_TEST = [ForecastingOptCV]
10+
EST_TO_TEST = [ForecastingOptCV, TSCOptCV]
1111
else:
1212
EST_TO_TEST = []
1313

0 commit comments

Comments
 (0)