Skip to content

Commit edc33a9

Browse files
committed
check_estimator
1 parent fb74eb6 commit edc33a9

2 files changed

Lines changed: 97 additions & 2 deletions

File tree

src/hyperactive/tests/test_all_objects.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55

66
from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator
7+
from skbase.testing import QuickTester as _QuickTester
78
from skbase.testing import TestAllObjects as _TestAllObjects
89

910
from hyperactive._registry import all_objects
@@ -154,7 +155,7 @@ class ExperimentFixtureGenerator(BaseFixtureGenerator):
154155
object_type_filter = "experiment"
155156

156157

157-
class TestAllExperiments(ExperimentFixtureGenerator):
158+
class TestAllExperiments(ExperimentFixtureGenerator, _QuickTester):
158159
"""Module level tests for all experiment classes."""
159160

160161
def test_paramnames(self, object_class):
@@ -204,7 +205,7 @@ class OptimizerFixtureGenerator(BaseFixtureGenerator):
204205
object_type_filter = "optimizer"
205206

206207

207-
class TestAllOptimizers(OptimizerFixtureGenerator):
208+
class TestAllOptimizers(OptimizerFixtureGenerator, _QuickTester):
208209
"""Module level tests for all optimizer classes."""
209210

210211
def test_opt_run(self, object_instance):
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
2+
"""Registry and dispatcher for test classes.
3+
4+
Module does not contain tests, only test utilities.
5+
"""
6+
7+
__author__ = ["fkiraly"]
8+
9+
from inspect import isclass
10+
11+
12+
def get_test_class_registry():
13+
"""Return test class registry.
14+
15+
Wrapped in a function to avoid circular imports.
16+
17+
Returns
18+
-------
19+
testclass_dict : dict
20+
test class registry
21+
keys are scitypes, values are test classes TestAll[Scitype]
22+
"""
23+
from hyperactive.tests.test_all_objects import (
24+
TestAllExperiments,
25+
TestAllObjects,
26+
TestAllOptimizers,
27+
)
28+
29+
testclass_dict = dict()
30+
# every object in sktime inherits from BaseObject
31+
# "object" tests are run for all objects
32+
testclass_dict["object"] = TestAllObjects
33+
# more specific base classes
34+
# these inherit either from BaseEstimator or BaseObject,
35+
# so also imply estimator and object tests, or only object tests
36+
testclass_dict["experiment"] = TestAllExperiments
37+
testclass_dict["optimizer"] = TestAllOptimizers
38+
39+
return testclass_dict
40+
41+
42+
def get_test_classes_for_obj(obj):
43+
"""Get all test classes relevant for an object or estimator.
44+
45+
Parameters
46+
----------
47+
obj : object or estimator, descendant of sktime BaseObject or BaseEstimator
48+
object or estimator for which to get test classes
49+
50+
Returns
51+
-------
52+
test_classes : list of test classes
53+
list of test classes relevant for obj
54+
these are references to the actual classes, not strings
55+
if obj was not a descendant of BaseObject or BaseEstimator, returns empty list
56+
"""
57+
from skbase.base import BaseObject
58+
59+
def is_object(obj):
60+
"""Return whether obj is an estimator class or estimator object."""
61+
if isclass(obj):
62+
return issubclass(obj, BaseObject)
63+
else:
64+
return isinstance(obj, BaseObject)
65+
66+
# warning: BaseEstimator does not inherit from BaseObject,
67+
# therefore we need to check both
68+
if not is_object(obj):
69+
return []
70+
71+
testclass_dict = get_test_class_registry()
72+
73+
# we always need to run "object" tests
74+
test_clss = [testclass_dict["object"]]
75+
76+
try:
77+
if isclass(obj):
78+
obj_scitypes = obj.get_class_tag("object_type")
79+
elif hasattr(obj, "get_tag"):
80+
obj_scitypes = obj.get_tag("object_type")
81+
else:
82+
obj_scitypes = []
83+
except Exception:
84+
obj_scitypes = []
85+
86+
if isinstance(obj_scitypes, str):
87+
# if obj_scitypes is a string, convert to list
88+
obj_scitypes = [obj_scitypes]
89+
90+
for obj_scitype in obj_scitypes:
91+
if obj_scitype in testclass_dict:
92+
test_clss += [testclass_dict[obj_scitype]]
93+
94+
return test_clss

0 commit comments

Comments
 (0)