|
20 | 20 | # limitations under the License. |
21 | 21 | # |
22 | 22 | import collections.abc as collections_abc |
| 23 | +import inspect |
23 | 24 |
|
24 | 25 | import pytest |
25 | 26 | import sklearn.utils.estimator_checks |
@@ -69,21 +70,45 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T |
69 | 70 | fast_arguments: List of estimators to use for fast tests. |
70 | 71 | slow_arguments: List of estimators to use for slow tests. |
71 | 72 | generate_only: If True, only generate tests without running them (default: True). |
72 | | - This is passed to sklearn.utils.estimator_checks.check_estimator. |
| 73 | + This parameter is only used with sklearn < 1.5. In newer versions, |
| 74 | + tests are always generated (not run immediately). |
73 | 75 | |
74 | 76 | Returns: |
75 | 77 | A pytest parametrize decorator configured with fast and slow test parameters. |
76 | 78 | """ |
77 | | - fast_params = [ |
78 | | - list( |
79 | | - sklearn.utils.estimator_checks.check_estimator( |
80 | | - fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments |
81 | | - ] |
82 | | - slow_params = [ |
83 | | - list( |
84 | | - sklearn.utils.estimator_checks.check_estimator( |
85 | | - slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments |
86 | | - ] |
| 79 | + # Check if check_estimator supports generate_only parameter (sklearn < 1.5) |
| 80 | + check_estimator_sig = inspect.signature(sklearn.utils.estimator_checks.check_estimator) |
| 81 | + supports_generate_only = 'generate_only' in check_estimator_sig.parameters |
| 82 | + |
| 83 | + if supports_generate_only: |
| 84 | + # Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True |
| 85 | + fast_params = [ |
| 86 | + list( |
| 87 | + sklearn.utils.estimator_checks.check_estimator( |
| 88 | + fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments |
| 89 | + ] |
| 90 | + slow_params = [ |
| 91 | + list( |
| 92 | + sklearn.utils.estimator_checks.check_estimator( |
| 93 | + slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments |
| 94 | + ] |
| 95 | + else: |
| 96 | + # New sklearn API (>= 1.5): use parametrize_with_checks to get test params |
| 97 | + # For each estimator, get the first check |
| 98 | + fast_params = [] |
| 99 | + for fast_arg in fast_arguments: |
| 100 | + decorator = sklearn.utils.estimator_checks.parametrize_with_checks([fast_arg]) |
| 101 | + # Extract the generator from the decorator and get first item |
| 102 | + gen = decorator.mark.args[1] |
| 103 | + fast_params.append(next(gen)) |
| 104 | + |
| 105 | + slow_params = [] |
| 106 | + for slow_arg in slow_arguments: |
| 107 | + decorator = sklearn.utils.estimator_checks.parametrize_with_checks([slow_arg]) |
| 108 | + # Extract the generator from the decorator and get first item |
| 109 | + gen = decorator.mark.args[1] |
| 110 | + slow_params.append(next(gen)) |
| 111 | + |
87 | 112 | return parametrize_slow("estimator,check", fast_params, slow_params) |
88 | 113 |
|
89 | 114 |
|
|
0 commit comments