-
Notifications
You must be signed in to change notification settings - Fork 187
FIX, MAINT: Implement 'everything follows X' and namespace checks for KNN #3127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
|
|
||
| if sklearn_check_version("1.9"): | ||
| from sklearn.utils._sparse import _align_api_if_sparse | ||
| from sklearn.utils._array_api import get_namespace_and_device, move_to | ||
|
|
||
| from onedal._device_offload import _transfer_to_host | ||
| from onedal.utils._array_api import _is_numpy_namespace | ||
|
|
@@ -37,7 +38,6 @@ | |
| from .._utils import PatchingConditionsChain | ||
| from ..base import oneDALEstimator | ||
| from ..utils._array_api import get_namespace | ||
| from ..utils.validation import validate_data | ||
|
|
||
|
|
||
| class KNeighborsDispatchingBase(oneDALEstimator): | ||
|
|
@@ -51,11 +51,20 @@ def _get_weights(self, dist, weights): | |
| # if user attempts to classify a point that was zero distance from one | ||
| # or more training points, those training points are weighted as 1.0 | ||
| # and the other points as 0.0 | ||
| with xp.errstate(divide="ignore"): | ||
| dist = 1.0 / dist | ||
| if _is_numpy_namespace(xp): | ||
| with xp.errstate(divide="ignore"): | ||
| dist = 1.0 / dist | ||
| else: | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore") | ||
| dist = 1.0 / dist | ||
| inf_mask = xp.isinf(dist) | ||
| inf_row = xp.any(inf_mask, axis=1) | ||
| dist[inf_row] = inf_mask[inf_row] | ||
| if _is_numpy_namespace(xp): | ||
| # Note: older numpy do not have 'np.astype' | ||
| dist[inf_row] = inf_mask[inf_row] | ||
| else: | ||
| dist[inf_row] = xp.astype(inf_mask[inf_row], dist.dtype) | ||
| return dist | ||
| elif callable(weights): | ||
| return weights(dist) | ||
|
|
@@ -84,11 +93,19 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t | |
| array-like | ||
| Predicted values. | ||
| """ | ||
| xp, _ = get_namespace(y_train) | ||
| if not _is_numpy_namespace(xp): | ||
| # Note: in theory, the logic should be that 'y_train' should be converted | ||
| # to the namespace of 'neigh_dist', but by this point, 'y_train' should | ||
| # already have been moved to X's namespace, so it's fine to move 'neigh_dist'. | ||
| if sklearn_check_version("1.9"): | ||
| xp, _, device = get_namespace_and_device(y_train) | ||
| neigh_dist = move_to(neigh_dist, xp=xp, device=device) | ||
| neigh_ind = move_to(neigh_ind, xp=xp, device=device) | ||
| else: | ||
| xp, _ = get_namespace(y_train) | ||
| device = getattr(y_train, "device", None) | ||
| neigh_dist = xp.asarray(neigh_dist, device=device) | ||
| neigh_ind = xp.asarray(neigh_ind, device=device) | ||
| if not _is_numpy_namespace(xp): | ||
| neigh_dist = xp.asarray(neigh_dist, device=device) | ||
| neigh_ind = xp.asarray(neigh_ind, device=device) | ||
|
|
||
| weights = self._get_weights(neigh_dist, weights_param) | ||
|
|
||
|
|
@@ -113,9 +130,7 @@ def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_t | |
| y_pred_shape = (neigh_ind.shape[0], _y.shape[1]) | ||
| if not _is_numpy_namespace(xp): | ||
| # Array API: pass device to ensure same device as input | ||
| y_pred = xp.empty( | ||
| y_pred_shape, dtype=neigh_dist.dtype, device=neigh_ind.device | ||
| ) | ||
| y_pred = xp.empty(y_pred_shape, dtype=neigh_dist.dtype, device=device) | ||
| else: | ||
| # Numpy: no device parameter | ||
| y_pred = xp.empty(y_pred_shape, dtype=neigh_dist.dtype) | ||
|
|
@@ -164,11 +179,16 @@ def _compute_class_probabilities( | |
| array-like | ||
| Class probabilities. | ||
| """ | ||
| xp, _ = get_namespace(y_train) | ||
| if not _is_numpy_namespace(xp): | ||
| if sklearn_check_version("1.9"): | ||
| xp, _, device = get_namespace_and_device(y_train) | ||
| neigh_dist = move_to(neigh_dist, xp=xp, device=device) | ||
| neigh_ind = move_to(neigh_ind, xp=xp, device=device) | ||
| else: | ||
| xp, _ = get_namespace(y_train) | ||
| device = getattr(y_train, "device", None) | ||
| neigh_dist = xp.asarray(neigh_dist, device=device) | ||
| neigh_ind = xp.asarray(neigh_ind, device=device) | ||
| if not _is_numpy_namespace(xp): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the logic of this if-statement? As I understand it, previously neigh_dist and neigh_ind were originally numpy arrays, and we only needed to convert them if y was not a numpy array. Is that correct? If so, do we need the same logic after the array API update?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is how it was before, if you look at the changes. I guess the purpose is to make them work with the other arrays. |
||
| neigh_dist = xp.asarray(neigh_dist, device=device) | ||
| neigh_ind = xp.asarray(neigh_ind, device=device) | ||
|
|
||
| _y = y_train | ||
| classes_ = classes | ||
|
|
@@ -207,9 +227,9 @@ def _compute_class_probabilities( | |
| proba_k = xp.zeros( | ||
| (n_classes, n_queries), | ||
| dtype=neigh_dist.dtype, | ||
| device=neigh_dist.device, | ||
| device=device, | ||
| ) | ||
| zero = xp.zeros(1, dtype=neigh_dist.dtype, device=neigh_dist.device) | ||
| zero = xp.zeros(1, dtype=neigh_dist.dtype, device=device) | ||
| for c in range(n_classes): | ||
| mask = pred_labels == c | ||
| proba_k[c, :] = xp.sum(xp.where(mask, weights_k, zero), axis=1) | ||
|
|
@@ -654,6 +674,8 @@ def _onedal_gpu_supported(self, method_name, *data): | |
| def _onedal_cpu_supported(self, method_name, *data): | ||
| return self._onedal_supported("cpu", method_name, *data) | ||
|
|
||
| # Note: since this transfers the data to host, it doesn't validate | ||
| # that the array namespaces and devices of 'X' and '_fit_X' match. | ||
| def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"): | ||
| check_is_fitted(self) | ||
| if n_neighbors is None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we also need some explanation of why we need this numpy check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has a different code path for numpy, which uses operations that are not supported by the array API.