|
40 | 40 | from onedal.cluster import KMeans as onedal_KMeans |
41 | 41 | from onedal.utils.validation import _is_arraylike_not_scalar, _is_csr |
42 | 42 |
|
43 | | - from .._config import get_config |
44 | 43 | from .._device_offload import dispatch, wrap_output_data |
45 | 44 | from .._utils import PatchingConditionsChain |
46 | 45 | from ..base import oneDALEstimator |
47 | 46 | from ..utils._array_api import enable_array_api, get_namespace |
48 | 47 | from ..utils.validation import validate_data |
49 | 48 |
|
| 49 | + if sklearn_check_version("1.9"): |
| 50 | + from sklearn.utils._array_api import check_same_namespace |
| 51 | + |
50 | 52 | @enable_array_api |
51 | 53 | @control_n_jobs(decorated_methods=["fit", "fit_transform", "predict", "score"]) |
52 | 54 | class KMeans(oneDALEstimator, _sklearn_KMeans): |
@@ -383,7 +385,10 @@ def predict( |
383 | 385 | ) |
384 | 386 |
|
385 | 387 | def _onedal_predict(self, X, sample_weight=None, queue=None): |
386 | | - |
| 388 | + if sklearn_check_version("1.9"): |
| 389 | + check_same_namespace( |
| 390 | + X, self, attribute="cluster_centers_", method="predict" |
| 391 | + ) |
387 | 392 | xp, _ = get_namespace(X) |
388 | 393 | X = validate_data( |
389 | 394 | self, |
@@ -450,7 +455,10 @@ def transform(self, X): |
450 | 455 | ) |
451 | 456 |
|
452 | 457 | def _onedal_transform(self, X, queue=None): |
453 | | - |
| 458 | + if sklearn_check_version("1.9"): |
| 459 | + check_same_namespace( |
| 460 | + X, self, attribute="cluster_centers_", method="transform" |
| 461 | + ) |
454 | 462 | xp, is_array_api = get_namespace(X) |
455 | 463 | X = validate_data( |
456 | 464 | self, |
@@ -492,7 +500,10 @@ def score(self, X, y=None, sample_weight=None): |
492 | 500 | ) |
493 | 501 |
|
494 | 502 | def _onedal_score(self, X, y=None, sample_weight=None, queue=None): |
495 | | - |
| 503 | + if sklearn_check_version("1.9"): |
| 504 | + check_same_namespace( |
| 505 | + X, self, attribute="cluster_centers_", method="score" |
| 506 | + ) |
496 | 507 | xp, _ = get_namespace(X) |
497 | 508 | X = validate_data( |
498 | 509 | self, |
|
0 commit comments