Skip to content

Commit 1f307d1

Browse files
CopilotMMathisLab
andcommitted
Fix sklearn version compatibility for parametrize_with_checks_slow
Handle both sklearn 1.4.2 (legacy) and 1.8.0+ (latest) APIs: - Old API: check_estimator with generate_only parameter - New API: parametrize_with_checks (generate_only removed) Fixes TypeError: got an unexpected keyword argument 'generate_only' Co-authored-by: MMathisLab <28102185+MMathisLab@users.noreply.github.com>
1 parent 985cb2c commit 1f307d1

1 file changed

Lines changed: 36 additions & 11 deletions

File tree

tests/_util.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# limitations under the License.
2121
#
2222
import collections.abc as collections_abc
23+
import inspect
2324

2425
import pytest
2526
import sklearn.utils.estimator_checks
@@ -69,21 +70,45 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T
6970
fast_arguments: List of estimators to use for fast tests.
7071
slow_arguments: List of estimators to use for slow tests.
7172
generate_only: If True, only generate tests without running them (default: True).
72-
This is passed to sklearn.utils.estimator_checks.check_estimator.
73+
This parameter is only used with sklearn < 1.5. In newer versions,
74+
tests are always generated (not run immediately).
7375
7476
Returns:
7577
A pytest parametrize decorator configured with fast and slow test parameters.
7678
"""
77-
fast_params = [
78-
list(
79-
sklearn.utils.estimator_checks.check_estimator(
80-
fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments
81-
]
82-
slow_params = [
83-
list(
84-
sklearn.utils.estimator_checks.check_estimator(
85-
slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments
86-
]
79+
# Check if check_estimator supports generate_only parameter (sklearn < 1.5)
80+
check_estimator_sig = inspect.signature(sklearn.utils.estimator_checks.check_estimator)
81+
supports_generate_only = 'generate_only' in check_estimator_sig.parameters
82+
83+
if supports_generate_only:
84+
# Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True
85+
fast_params = [
86+
list(
87+
sklearn.utils.estimator_checks.check_estimator(
88+
fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments
89+
]
90+
slow_params = [
91+
list(
92+
sklearn.utils.estimator_checks.check_estimator(
93+
slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments
94+
]
95+
else:
96+
# New sklearn API (>= 1.5): use parametrize_with_checks to get test params
97+
# For each estimator, get the first check
98+
fast_params = []
99+
for fast_arg in fast_arguments:
100+
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([fast_arg])
101+
# Extract the generator from the decorator and get first item
102+
gen = decorator.mark.args[1]
103+
fast_params.append(next(gen))
104+
105+
slow_params = []
106+
for slow_arg in slow_arguments:
107+
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([slow_arg])
108+
# Extract the generator from the decorator and get first item
109+
gen = decorator.mark.args[1]
110+
slow_params.append(next(gen))
111+
87112
return parametrize_slow("estimator,check", fast_params, slow_params)
88113

89114

0 commit comments

Comments
 (0)