Skip to content

Commit 4fb1b63

Browse files
check for namespace and device compatibility (#3149)
1 parent 2862e31 commit 4fb1b63

2 files changed

Lines changed: 52 additions & 4 deletions

File tree

sklearnex/cluster/k_means.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@
4040
from onedal.cluster import KMeans as onedal_KMeans
4141
from onedal.utils.validation import _is_arraylike_not_scalar, _is_csr
4242

43-
from .._config import get_config
4443
from .._device_offload import dispatch, wrap_output_data
4544
from .._utils import PatchingConditionsChain
4645
from ..base import oneDALEstimator
4746
from ..utils._array_api import enable_array_api, get_namespace
4847
from ..utils.validation import validate_data
4948

49+
if sklearn_check_version("1.9"):
50+
from sklearn.utils._array_api import check_same_namespace
51+
5052
@enable_array_api
5153
@control_n_jobs(decorated_methods=["fit", "fit_transform", "predict", "score"])
5254
class KMeans(oneDALEstimator, _sklearn_KMeans):
@@ -383,7 +385,10 @@ def predict(
383385
)
384386

385387
def _onedal_predict(self, X, sample_weight=None, queue=None):
386-
388+
if sklearn_check_version("1.9"):
389+
check_same_namespace(
390+
X, self, attribute="cluster_centers_", method="predict"
391+
)
387392
xp, _ = get_namespace(X)
388393
X = validate_data(
389394
self,
@@ -450,7 +455,10 @@ def transform(self, X):
450455
)
451456

452457
def _onedal_transform(self, X, queue=None):
453-
458+
if sklearn_check_version("1.9"):
459+
check_same_namespace(
460+
X, self, attribute="cluster_centers_", method="transform"
461+
)
454462
xp, is_array_api = get_namespace(X)
455463
X = validate_data(
456464
self,
@@ -492,7 +500,10 @@ def score(self, X, y=None, sample_weight=None):
492500
)
493501

494502
def _onedal_score(self, X, y=None, sample_weight=None, queue=None):
495-
503+
if sklearn_check_version("1.9"):
504+
check_same_namespace(
505+
X, self, attribute="cluster_centers_", method="score"
506+
)
496507
xp, _ = get_namespace(X)
497508
X = validate_data(
498509
self,

sklearnex/cluster/tests/test_kmeans.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
from onedal.tests.utils._dataframes_support import (
2727
_as_numpy,
2828
_convert_to_dataframe,
29+
dpnp_available,
2930
get_dataframes_and_queues,
3031
get_queues,
3132
)
33+
from onedal.tests.utils._device_selection import is_sycl_device_available
3234
from sklearnex import config_context
3335
from sklearnex.cluster import KMeans
3436
from sklearnex.tests.utils import _IS_INTEL
@@ -224,3 +226,38 @@ def test_array_api_dispatch_output_type(dataframe, queue):
224226
assert type(trans) == type(X)
225227
assert type(km.cluster_centers_) == type(X)
226228
assert isinstance(sc, float)
229+
230+
231+
@pytest.mark.skipif(
232+
not sklearn_check_version("1.9"),
233+
reason="Relies on functionality introduced in later scikit-learn versions.",
234+
)
235+
@pytest.mark.skipif(not dpnp_available, reason="Functionality to test requires DPNP.")
236+
@pytest.mark.skipif(
237+
not is_sycl_device_available("gpu"), reason="Test for GPU-specific functionality."
238+
)
239+
def test_cov_error_on_incompatible_devices(with_array_api):
240+
import dpnp
241+
242+
rng = np.random.default_rng(seed=123)
243+
X = rng.random(size=(50, 3), dtype=np.float32)
244+
X_cpu = dpnp.array(X, device="cpu")
245+
X_gpu = dpnp.array(X, device="gpu")
246+
247+
err_match = "device|queue"
248+
249+
model = KMeans(algorithm="lloyd").fit(X_gpu)
250+
with pytest.raises(ValueError, match=err_match):
251+
_ = model.predict(X_cpu)
252+
with pytest.raises(ValueError, match=err_match):
253+
_ = model.transform(X_cpu)
254+
with pytest.raises(ValueError, match=err_match):
255+
_ = model.score(X_cpu)
256+
257+
model.fit(X_cpu)
258+
with pytest.raises(ValueError, match=err_match):
259+
_ = model.predict(X_gpu)
260+
with pytest.raises(ValueError, match=err_match):
261+
_ = model.transform(X_gpu)
262+
with pytest.raises(ValueError, match=err_match):
263+
_ = model.score(X_gpu)

0 commit comments

Comments
 (0)