Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions sklearnex/cluster/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@
from onedal.cluster import KMeans as onedal_KMeans
from onedal.utils.validation import _is_arraylike_not_scalar, _is_csr

from .._config import get_config
from .._device_offload import dispatch, wrap_output_data
from .._utils import PatchingConditionsChain
from ..base import oneDALEstimator
from ..utils._array_api import enable_array_api, get_namespace
from ..utils.validation import validate_data

if sklearn_check_version("1.9"):
from sklearn.utils._array_api import check_same_namespace

@enable_array_api
@control_n_jobs(decorated_methods=["fit", "fit_transform", "predict", "score"])
class KMeans(oneDALEstimator, _sklearn_KMeans):
Expand Down Expand Up @@ -383,7 +385,10 @@ def predict(
)

def _onedal_predict(self, X, sample_weight=None, queue=None):

if sklearn_check_version("1.9"):
check_same_namespace(
X, self, attribute="cluster_centers_", method="predict"
)
xp, _ = get_namespace(X)
X = validate_data(
self,
Expand Down Expand Up @@ -450,7 +455,10 @@ def transform(self, X):
)

def _onedal_transform(self, X, queue=None):

if sklearn_check_version("1.9"):
check_same_namespace(
X, self, attribute="cluster_centers_", method="transform"
)
xp, is_array_api = get_namespace(X)
X = validate_data(
self,
Expand Down Expand Up @@ -492,7 +500,10 @@ def score(self, X, y=None, sample_weight=None):
)

def _onedal_score(self, X, y=None, sample_weight=None, queue=None):

if sklearn_check_version("1.9"):
check_same_namespace(
X, self, attribute="cluster_centers_", method="score"
)
xp, _ = get_namespace(X)
X = validate_data(
self,
Expand Down
37 changes: 37 additions & 0 deletions sklearnex/cluster/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from onedal.tests.utils._dataframes_support import (
_as_numpy,
_convert_to_dataframe,
dpnp_available,
get_dataframes_and_queues,
get_queues,
)
from onedal.tests.utils._device_selection import is_sycl_device_available
from sklearnex import config_context
from sklearnex.cluster import KMeans
from sklearnex.tests.utils import _IS_INTEL
Expand Down Expand Up @@ -224,3 +226,38 @@ def test_array_api_dispatch_output_type(dataframe, queue):
assert type(trans) == type(X)
assert type(km.cluster_centers_) == type(X)
assert isinstance(sc, float)


@pytest.mark.skipif(
not sklearn_check_version("1.9"),
reason="Relies on functionality introduced in later scikit-learn versions.",
)
@pytest.mark.skipif(not dpnp_available, reason="Functionality to test requires DPNP.")
@pytest.mark.skipif(
not is_sycl_device_available("gpu"), reason="Test for GPU-specific functionality."
)
def test_cov_error_on_incompatible_devices(with_array_api):
import dpnp

rng = np.random.default_rng(seed=123)
X = rng.random(size=(50, 3), dtype=np.float32)
X_cpu = dpnp.array(X, device="cpu")
X_gpu = dpnp.array(X, device="gpu")

err_match = "device|queue"

model = KMeans(algorithm="lloyd").fit(X_gpu)
with pytest.raises(ValueError, match=err_match):
_ = model.predict(X_cpu)
with pytest.raises(ValueError, match=err_match):
_ = model.transform(X_cpu)
with pytest.raises(ValueError, match=err_match):
_ = model.score(X_cpu)

model.fit(X_cpu)
with pytest.raises(ValueError, match=err_match):
_ = model.predict(X_gpu)
with pytest.raises(ValueError, match=err_match):
_ = model.transform(X_gpu)
with pytest.raises(ValueError, match=err_match):
_ = model.score(X_gpu)
Loading