|
20 | 20 | # limitations under the License. |
21 | 21 | # |
22 | 22 | import collections.abc as collections_abc |
23 | | -import packaging.version |
24 | 23 | import functools |
| 24 | + |
| 25 | +import packaging.version |
25 | 26 | import pytest |
26 | 27 | import sklearn.utils.estimator_checks |
27 | 28 | import torch |
@@ -66,12 +67,18 @@ def parametrize_slow(arg_names, fast_arguments, slow_arguments): |
66 | 67 | def parametrize_with_checks_slow(fast_arguments, slow_arguments): |
67 | 68 |
|
68 | 69 | # NOTE(stes): See https://github.com/AdaptiveMotorControlLab/CEBRA/issues/280, sklearn API changed in 1.6. |
69 | | - if packaging.version.parse(sklearn.__version__) <= packaging.version.parse("1.6"): |
70 | | - generate_checks = functools.partial(sklearn.utils.estimator_checks.check_estimator, generate_only=True) |
| 70 | + if packaging.version.parse( |
| 71 | + sklearn.__version__) <= packaging.version.parse("1.6"): |
| 72 | + generate_checks = functools.partial( |
| 73 | + sklearn.utils.estimator_checks.check_estimator, generate_only=True) |
71 | 74 | else: |
72 | 75 | generate_checks = sklearn.utils.estimator_checks.estimator_checks_generator |
73 | | - generate_params = lambda args: [next(generate_checks(arg)) for arg in args] |
74 | | - return parametrize_slow("estimator,check", generate_params(fast_arguments), generate_params(slow_arguments)) |
| 76 | + |
| 77 | + def _generate_params(args): |
| 78 | + return [next(generate_checks(arg)) for arg in args] |
| 79 | + |
| 80 | + return parametrize_slow("estimator,check", _generate_params(fast_arguments), |
| 81 | + _generate_params(slow_arguments)) |
75 | 82 |
|
76 | 83 |
|
77 | 84 | def parametrize_device(func): |
|
0 commit comments