2929
3030if sklearn_check_version ("1.9" ):
3131 from sklearn .utils ._sparse import _align_api_if_sparse
32+ from sklearn .utils ._array_api import get_namespace_and_device , move_to
3233
3334from onedal ._device_offload import _transfer_to_host
3435from onedal .utils ._array_api import _is_numpy_namespace
3738from .._utils import PatchingConditionsChain
3839from ..base import oneDALEstimator
3940from ..utils ._array_api import get_namespace
40- from ..utils .validation import validate_data
4141
4242
4343class KNeighborsDispatchingBase (oneDALEstimator ):
@@ -51,11 +51,20 @@ def _get_weights(self, dist, weights):
5151 # if user attempts to classify a point that was zero distance from one
5252 # or more training points, those training points are weighted as 1.0
5353 # and the other points as 0.0
54- with xp .errstate (divide = "ignore" ):
55- dist = 1.0 / dist
54+ if _is_numpy_namespace (xp ):
55+ with xp .errstate (divide = "ignore" ):
56+ dist = 1.0 / dist
57+ else :
58+ with warnings .catch_warnings ():
59+ warnings .simplefilter ("ignore" )
60+ dist = 1.0 / dist
5661 inf_mask = xp .isinf (dist )
5762 inf_row = xp .any (inf_mask , axis = 1 )
58- dist [inf_row ] = inf_mask [inf_row ]
63+ if _is_numpy_namespace (xp ):
64+ # Note: older numpy do not have 'np.astype'
65+ dist [inf_row ] = inf_mask [inf_row ]
66+ else :
67+ dist [inf_row ] = xp .astype (inf_mask [inf_row ], dist .dtype )
5968 return dist
6069 elif callable (weights ):
6170 return weights (dist )
@@ -84,11 +93,19 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t
8493 array-like
8594 Predicted values.
8695 """
87- xp , _ = get_namespace (y_train )
88- if not _is_numpy_namespace (xp ):
96+ # Note: in theory, the logic should be that 'y_train' should be converted
97+ # to the namespace of 'neigh_dist', but by this point, 'y_train' should
98+ # already have been moved to X's namespace, so it's fine to move 'neigh_dist'.
99+ if sklearn_check_version ("1.9" ):
100+ xp , _ , device = get_namespace_and_device (y_train )
101+ neigh_dist = move_to (neigh_dist , xp = xp , device = device )
102+ neigh_ind = move_to (neigh_ind , xp = xp , device = device )
103+ else :
104+ xp , _ = get_namespace (y_train )
89105 device = getattr (y_train , "device" , None )
90- neigh_dist = xp .asarray (neigh_dist , device = device )
91- neigh_ind = xp .asarray (neigh_ind , device = device )
106+ if not _is_numpy_namespace (xp ):
107+ neigh_dist = xp .asarray (neigh_dist , device = device )
108+ neigh_ind = xp .asarray (neigh_ind , device = device )
92109
93110 weights = self ._get_weights (neigh_dist , weights_param )
94111
@@ -113,9 +130,7 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t
113130 y_pred_shape = (neigh_ind .shape [0 ], _y .shape [1 ])
114131 if not _is_numpy_namespace (xp ):
115132 # Array API: pass device to ensure same device as input
116- y_pred = xp .empty (
117- y_pred_shape , dtype = neigh_dist .dtype , device = neigh_ind .device
118- )
133+ y_pred = xp .empty (y_pred_shape , dtype = neigh_dist .dtype , device = device )
119134 else :
120135 # Numpy: no device parameter
121136 y_pred = xp .empty (y_pred_shape , dtype = neigh_dist .dtype )
@@ -164,11 +179,16 @@ def _compute_class_probabilities(
164179 array-like
165180 Class probabilities.
166181 """
167- xp , _ = get_namespace (y_train )
168- if not _is_numpy_namespace (xp ):
182+ if sklearn_check_version ("1.9" ):
183+ xp , _ , device = get_namespace_and_device (y_train )
184+ neigh_dist = move_to (neigh_dist , xp = xp , device = device )
185+ neigh_ind = move_to (neigh_ind , xp = xp , device = device )
186+ else :
187+ xp , _ = get_namespace (y_train )
169188 device = getattr (y_train , "device" , None )
170- neigh_dist = xp .asarray (neigh_dist , device = device )
171- neigh_ind = xp .asarray (neigh_ind , device = device )
189+ if not _is_numpy_namespace (xp ):
190+ neigh_dist = xp .asarray (neigh_dist , device = device )
191+ neigh_ind = xp .asarray (neigh_ind , device = device )
172192
173193 _y = y_train
174194 classes_ = classes
@@ -207,9 +227,9 @@ def _compute_class_probabilities(
207227 proba_k = xp .zeros (
208228 (n_classes , n_queries ),
209229 dtype = neigh_dist .dtype ,
210- device = neigh_dist . device ,
230+ device = device ,
211231 )
212- zero = xp .zeros (1 , dtype = neigh_dist .dtype , device = neigh_dist . device )
232+ zero = xp .zeros (1 , dtype = neigh_dist .dtype , device = device )
213233 for c in range (n_classes ):
214234 mask = pred_labels == c
215235 proba_k [c , :] = xp .sum (xp .where (mask , weights_k , zero ), axis = 1 )
@@ -654,6 +674,8 @@ def _onedal_gpu_supported(self, method_name, *data):
654674 def _onedal_cpu_supported (self , method_name , * data ):
655675 return self ._onedal_supported ("cpu" , method_name , * data )
656676
677+ # Note: since this transfers the data to host, it doesn't validate
678+ # that the array namespaces and devices of 'X' and '_fit_X' match.
657679 def kneighbors_graph (self , X = None , n_neighbors = None , mode = "connectivity" ):
658680 check_is_fitted (self )
659681 if n_neighbors is None :
0 commit comments