Skip to content

Commit 1507c24

Browse files
namespace checks and movements for KNN
1 parent d95c247 commit 1507c24

6 files changed

Lines changed: 210 additions & 28 deletions

File tree

sklearnex/neighbors/_lof.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
from sklearnex.neighbors.knn_unsupervised import NearestNeighbors
3232

3333
from ..utils._array_api import enable_array_api, get_namespace
34-
from ..utils.validation import validate_data
34+
35+
if sklearn_check_version("1.9"):
36+
from sklearn.utils._array_api import check_same_namespace
3537

3638

3739
@enable_array_api
@@ -141,6 +143,10 @@ def fit(self, X, y=None):
141143
self._fit_X = xp.asarray(self._fit_X, device=device)
142144
return self
143145

146+
# Note: this is overriding an internal method from scikit-learn with
147+
# the same signature. In this case, 'validate_data' is called during
148+
# 'decision_function', which calls '.kneighbors()'. Hence, it doesn't
149+
# need to validate the namespace of 'X' with '_fit_X' here.
144150
def _predict(self, X=None):
145151
check_is_fitted(self)
146152

sklearnex/neighbors/common.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
if sklearn_check_version("1.9"):
3131
from sklearn.utils._sparse import _align_api_if_sparse
32+
from sklearn.utils._array_api import get_namespace_and_device, move_to
3233

3334
from onedal._device_offload import _transfer_to_host
3435
from onedal.utils._array_api import _is_numpy_namespace
@@ -37,7 +38,6 @@
3738
from .._utils import PatchingConditionsChain
3839
from ..base import oneDALEstimator
3940
from ..utils._array_api import get_namespace
40-
from ..utils.validation import validate_data
4141

4242

4343
class KNeighborsDispatchingBase(oneDALEstimator):
@@ -51,11 +51,20 @@ def _get_weights(self, dist, weights):
5151
# if user attempts to classify a point that was zero distance from one
5252
# or more training points, those training points are weighted as 1.0
5353
# and the other points as 0.0
54-
with xp.errstate(divide="ignore"):
55-
dist = 1.0 / dist
54+
if _is_numpy_namespace(xp):
55+
with xp.errstate(divide="ignore"):
56+
dist = 1.0 / dist
57+
else:
58+
with warnings.catch_warnings():
59+
warnings.simplefilter("ignore")
60+
dist = 1.0 / dist
5661
inf_mask = xp.isinf(dist)
5762
inf_row = xp.any(inf_mask, axis=1)
58-
dist[inf_row] = inf_mask[inf_row]
63+
if _is_numpy_namespace(xp):
64+
# Note: older NumPy versions do not have 'np.astype'
65+
dist[inf_row] = inf_mask[inf_row]
66+
else:
67+
dist[inf_row] = xp.astype(inf_mask[inf_row], dist.dtype)
5968
return dist
6069
elif callable(weights):
6170
return weights(dist)
@@ -84,11 +93,19 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t
8493
array-like
8594
Predicted values.
8695
"""
87-
xp, _ = get_namespace(y_train)
88-
if not _is_numpy_namespace(xp):
96+
# Note: in principle, 'y_train' should be converted to the namespace of
97+
# 'neigh_dist'. However, by this point 'y_train' should already have been
98+
# moved to X's namespace, so it is fine to move 'neigh_dist' instead.
99+
if sklearn_check_version("1.9"):
100+
xp, _, device = get_namespace_and_device(y_train)
101+
neigh_dist = move_to(neigh_dist, xp=xp, device=device)
102+
neigh_ind = move_to(neigh_ind, xp=xp, device=device)
103+
else:
104+
xp, _ = get_namespace(y_train)
89105
device = getattr(y_train, "device", None)
90-
neigh_dist = xp.asarray(neigh_dist, device=device)
91-
neigh_ind = xp.asarray(neigh_ind, device=device)
106+
if not _is_numpy_namespace(xp):
107+
neigh_dist = xp.asarray(neigh_dist, device=device)
108+
neigh_ind = xp.asarray(neigh_ind, device=device)
92109

93110
weights = self._get_weights(neigh_dist, weights_param)
94111

@@ -113,9 +130,7 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t
113130
y_pred_shape = (neigh_ind.shape[0], _y.shape[1])
114131
if not _is_numpy_namespace(xp):
115132
# Array API: pass device to ensure same device as input
116-
y_pred = xp.empty(
117-
y_pred_shape, dtype=neigh_dist.dtype, device=neigh_ind.device
118-
)
133+
y_pred = xp.empty(y_pred_shape, dtype=neigh_dist.dtype, device=device)
119134
else:
120135
# Numpy: no device parameter
121136
y_pred = xp.empty(y_pred_shape, dtype=neigh_dist.dtype)
@@ -164,11 +179,16 @@ def _compute_class_probabilities(
164179
array-like
165180
Class probabilities.
166181
"""
167-
xp, _ = get_namespace(y_train)
168-
if not _is_numpy_namespace(xp):
182+
if sklearn_check_version("1.9"):
183+
xp, _, device = get_namespace_and_device(y_train)
184+
neigh_dist = move_to(neigh_dist, xp=xp, device=device)
185+
neigh_ind = move_to(neigh_ind, xp=xp, device=device)
186+
else:
187+
xp, _ = get_namespace(y_train)
169188
device = getattr(y_train, "device", None)
170-
neigh_dist = xp.asarray(neigh_dist, device=device)
171-
neigh_ind = xp.asarray(neigh_ind, device=device)
189+
if not _is_numpy_namespace(xp):
190+
neigh_dist = xp.asarray(neigh_dist, device=device)
191+
neigh_ind = xp.asarray(neigh_ind, device=device)
172192

173193
_y = y_train
174194
classes_ = classes
@@ -207,9 +227,9 @@ def _compute_class_probabilities(
207227
proba_k = xp.zeros(
208228
(n_classes, n_queries),
209229
dtype=neigh_dist.dtype,
210-
device=neigh_dist.device,
230+
device=device,
211231
)
212-
zero = xp.zeros(1, dtype=neigh_dist.dtype, device=neigh_dist.device)
232+
zero = xp.zeros(1, dtype=neigh_dist.dtype, device=device)
213233
for c in range(n_classes):
214234
mask = pred_labels == c
215235
proba_k[c, :] = xp.sum(xp.where(mask, weights_k, zero), axis=1)
@@ -654,6 +674,8 @@ def _onedal_gpu_supported(self, method_name, *data):
654674
def _onedal_cpu_supported(self, method_name, *data):
655675
return self._onedal_supported("cpu", method_name, *data)
656676

677+
# Note: since this transfers the data to host, it doesn't validate
678+
# that the array namespaces and devices of 'X' and '_fit_X' match.
657679
def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
658680
check_is_fitted(self)
659681
if n_neighbors is None:

sklearnex/neighbors/knn_classification.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@
3333
from ..utils.validation import validate_data
3434
from .common import KNeighborsDispatchingBase
3535

36+
if sklearn_check_version("1.9"):
37+
from sklearn.utils._array_api import (
38+
check_same_namespace,
39+
get_namespace_and_device,
40+
move_to,
41+
)
42+
3643

3744
@enable_array_api
3845
@control_n_jobs(
@@ -72,7 +79,12 @@ def __init__(
7279
)
7380

7481
def fit(self, X, y):
75-
xp, is_array_api = get_namespace(X)
82+
if sklearn_check_version("1.9"):
83+
xp, is_array_api, device = get_namespace_and_device(X)
84+
else:
85+
xp, is_array_api = get_namespace(X)
86+
device = getattr(X, "device", None)
87+
7688
dispatch(
7789
self,
7890
"fit",
@@ -86,7 +98,6 @@ def fit(self, X, y):
8698
# Ensure _fit_X matches the input namespace so that
8799
# kneighbors(X=None) can use get_namespace(self._fit_X).
88100
if is_array_api and not _is_numpy_namespace(xp):
89-
device = getattr(X, "device", None)
90101
self._fit_X = xp.asarray(self._fit_X, device=device)
91102
return self
92103

@@ -169,7 +180,7 @@ def _onedal_fit(self, X, y, queue=None):
169180
)
170181

171182
# Process classification targets before passing to onedal
172-
self._process_classification_targets(y, skip_validation=False)
183+
self._process_classification_targets(X, y, skip_validation=False)
173184

174185
# Call onedal backend
175186
onedal_params = {
@@ -200,7 +211,7 @@ def _onedal_fit(self, X, y, queue=None):
200211
# Post-processing
201212
self._save_attributes()
202213

203-
def _process_classification_targets(self, y, skip_validation=False):
214+
def _process_classification_targets(self, X, y, skip_validation=False):
204215
"""Process classification targets and set class-related attributes.
205216
206217
Parameters
@@ -246,6 +257,10 @@ def _process_classification_targets(self, y, skip_validation=False):
246257
self.classes_ = self.classes_[0]
247258
self._y = xp.reshape(self._y, (-1,))
248259

260+
if sklearn_check_version("1.9"):
261+
xp_X, _, device = get_namespace_and_device(X)
262+
self._y = move_to(self._y, xp=xp_X, device=device)
263+
249264
def _onedal_predict(self, X, queue=None):
250265
if X is not None:
251266
xp, _ = get_namespace(X)
@@ -256,14 +271,20 @@ def _onedal_predict(self, X, queue=None):
256271
accept_sparse="csr",
257272
reset=False,
258273
)
274+
if sklearn_check_version("1.9"):
275+
check_same_namespace(X, self, attribute="_fit_X", method="predict")
259276

260277
params = self._onedal_estimator._get_onedal_params(X)
261278
params["result_option"] = "responses"
262279
result = self._onedal_estimator._onedal_predict(
263280
self._onedal_estimator._onedal_model, X, params
264281
)
265-
xp, _ = get_namespace(X)
266282
responses = from_table(result.responses, like=X)
283+
if sklearn_check_version("1.9"):
284+
xp, _, device = get_namespace_and_device(self.classes_)
285+
responses = move_to(responses, xp=xp, device=device)
286+
else:
287+
xp, _ = get_namespace(X)
267288
return xp.take(
268289
self.classes_, xp.asarray(xp.reshape(responses, (-1,)), dtype=xp.int64)
269290
)
@@ -278,6 +299,8 @@ def _onedal_predict_proba(self, X, queue=None):
278299
accept_sparse="csr",
279300
reset=False,
280301
)
302+
if sklearn_check_version("1.9"):
303+
check_same_namespace(X, self, attribute="_fit_X", method="predict_proba")
281304

282305
neigh_dist, neigh_ind = self._onedal_estimator.kneighbors(X)
283306

@@ -299,6 +322,8 @@ def _onedal_kneighbors(
299322
accept_sparse="csr",
300323
reset=False,
301324
)
325+
if sklearn_check_version("1.9"):
326+
check_same_namespace(X, self, attribute="_fit_X", method="kneighbors")
302327
else:
303328
query_is_train = True
304329
X = self._fit_X

sklearnex/neighbors/knn_regression.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
from ..utils.validation import validate_data
3232
from .common import KNeighborsDispatchingBase
3333

34+
if sklearn_check_version("1.9"):
35+
from sklearn.utils._array_api import (
36+
check_same_namespace,
37+
get_namespace_and_device,
38+
move_to,
39+
)
40+
3441

3542
@enable_array_api("1.5") # validate_data y_numeric requires sklearn >=1.5
3643
@control_n_jobs(decorated_methods=["fit", "predict", "kneighbors", "score"])
@@ -68,7 +75,11 @@ def __init__(
6875
)
6976

7077
def fit(self, X, y):
71-
xp, is_array_api = get_namespace(X)
78+
if sklearn_check_version("1.9"):
79+
xp, is_array_api, device = get_namespace_and_device(X)
80+
else:
81+
xp, is_array_api = get_namespace(X)
82+
device = getattr(X, "device", None)
7283
dispatch(
7384
self,
7485
"fit",
@@ -82,7 +93,6 @@ def fit(self, X, y):
8293
# Ensure _fit_X matches the input namespace so that
8394
# kneighbors(X=None) can use get_namespace(self._fit_X).
8495
if is_array_api and not _is_numpy_namespace(xp):
85-
device = getattr(X, "device", None)
8696
self._fit_X = xp.asarray(self._fit_X, device=device)
8797
return self
8898

@@ -138,7 +148,10 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
138148
)
139149

140150
def _onedal_fit(self, X, y, queue=None):
141-
xp, _ = get_namespace(X, y)
151+
if sklearn_check_version("1.9"):
152+
xp, _, device = get_namespace_and_device(X)
153+
else:
154+
xp, _ = get_namespace(X, y)
142155
self._set_effective_metric()
143156

144157
X, y = validate_data(
@@ -151,6 +164,9 @@ def _onedal_fit(self, X, y, queue=None):
151164
y_numeric=True,
152165
)
153166

167+
if sklearn_check_version("1.9"):
168+
y = move_to(y, xp=xp, device=device)
169+
154170
self._process_regression_targets(y)
155171
onedal_params = {
156172
"n_neighbors": self.n_neighbors,
@@ -215,6 +231,13 @@ def _predict_gpu(self, X, queue=None):
215231
accept_sparse="csr",
216232
reset=False,
217233
)
234+
# Note: if called before 'validate_data', this check would fail if 'X' is
235+
# a 'DataFrame', since '_fit_X' would have already been converted to NumPy.
236+
# Hence, it must come after the call to 'validate_data'. If the behavior
237+
# of this validator changes in scikit-learn, these checks could be done
238+
# earlier in the code for quicker errors.
239+
if sklearn_check_version("1.9"):
240+
check_same_namespace(X, self, attribute="_fit_X", method="predict")
218241
result = self._onedal_estimator._predict_gpu(X)
219242
return result
220243

@@ -246,6 +269,8 @@ def _predict_skl(self, X, queue=None):
246269
X = validate_data(
247270
self, X, dtype=[xp.float64, xp.float32], accept_sparse="csr", reset=False
248271
)
272+
if sklearn_check_version("1.9"):
273+
check_same_namespace(X, self, attribute="_fit_X", method="predict")
249274
return self._predict_skl_regression(X)
250275

251276
def _onedal_kneighbors(
@@ -262,6 +287,8 @@ def _onedal_kneighbors(
262287
accept_sparse="csr",
263288
reset=False,
264289
)
290+
if sklearn_check_version("1.9"):
291+
check_same_namespace(X, self, attribute="_fit_X", method="kenighbors")
265292
else:
266293
query_is_train = True
267294
X = self._fit_X

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from ..utils.validation import validate_data
3030
from .common import KNeighborsDispatchingBase
3131

32+
if sklearn_check_version("1.9"):
33+
from sklearn.utils._array_api import check_same_namespace, get_namespace_and_device
34+
3235

3336
@enable_array_api
3437
@control_n_jobs(decorated_methods=["fit", "kneighbors", "radius_neighbors"])
@@ -66,7 +69,12 @@ def __init__(
6669
)
6770

6871
def fit(self, X, y=None):
69-
xp, is_array_api = get_namespace(X)
72+
if sklearn_check_version("1.9"):
73+
xp, is_array_api, device = get_namespace_and_device(X)
74+
else:
75+
xp, is_array_api = get_namespace(X)
76+
device = getattr(X, "device", None)
77+
7078
dispatch(
7179
self,
7280
"fit",
@@ -80,7 +88,6 @@ def fit(self, X, y=None):
8088
# Ensure _fit_X matches the input namespace so that
8189
# kneighbors(X=None) can use get_namespace(self._fit_X).
8290
if is_array_api and not _is_numpy_namespace(xp):
83-
device = getattr(X, "device", None)
8491
self._fit_X = xp.asarray(self._fit_X, device=device)
8592
return self
8693

@@ -189,6 +196,8 @@ def _onedal_predict(self, X, queue=None):
189196
reset=False,
190197
force_all_finite=False,
191198
)
199+
if sklearn_check_version("1.9"):
200+
check_same_namespace(X, self, attribute="_fit_X", method="predict")
192201
return self._onedal_estimator.predict(X, queue=queue)
193202

194203
def _onedal_kneighbors(
@@ -205,6 +214,8 @@ def _onedal_kneighbors(
205214
accept_sparse="csr",
206215
reset=False,
207216
)
217+
if sklearn_check_version("1.9"):
218+
check_same_namespace(X, self, attribute="_fit_X", method="kneighbors")
208219
else:
209220
query_is_train = True
210221
X = self._fit_X

0 commit comments

Comments
 (0)