|
| 1 | +""" |
| 2 | +Data Valuation with Nearest Neighbor Explainers |
| 3 | +=============================================== |
| 4 | +
|
| 5 | +This notebook shows how explainers of nearest-neighbor (NN) models can be used for Data Valuation, the task of evaluating the usefulness of individual training data points in classification problems. |
| 6 | +When explaining NN models, a game is defined by first choosing an explanation point :math:`x_\\text{explain}` and class :math:`y_\\text{explain}`; the training data points :math:`\\mathcal{D} := \\mathcal{X} \\times \\mathcal{Y}` are the game's players, and the definition of the utility :math:`\\nu(S)` of a coalition :math:`S \\subseteq \\mathcal{D}` is based on the probability of the model predicting class :math:`y_\\text{explain}` on :math:`x_\\text{explain}` if its training data were limited to :math:`S`. |
| 7 | +""" |
| 8 | + |
| 9 | +# %% |
| 10 | +# There is support for explaining the ``KNeighborsClassifier`` model (with ``'uniform'`` or ``'distance'`` weights) and the ``RadiusNeighborsClassifier`` model from the ``scikit-learn`` library. |
| 11 | +# The algorithms are based on the publications from `Jia et al. (2019) <https://doi.org/10.48550/arXiv.1908.08619>`__, |
| 12 | +# `Wang et al. (2024) <https://doi.org/10.48550/arXiv.2401.11103>`__ |
| 13 | +# and `Wang et al. (2023) <https://doi.org/10.48550/arXiv.2308.15709>`__, respectively. |
| 14 | +# |
| 15 | +# Let's start by generating a synthetic classification dataset and fitting a simple ``KNeighborsClassifier`` to it. |
| 16 | + |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +import numpy as np |
| 19 | +from sklearn.datasets import make_classification |
| 20 | +from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier |
| 21 | +from util_plot import plot_datasets |
| 22 | + |
| 23 | +X_train, y_train = make_classification( |
| 24 | + n_samples=30, |
| 25 | + n_features=2, |
| 26 | + n_redundant=0, |
| 27 | + n_clusters_per_class=1, |
| 28 | + n_informative=2, |
| 29 | + n_classes=2, |
| 30 | + random_state=45, |
| 31 | +) |
| 32 | + |
| 33 | +fig, ax = plt.subplots(figsize=(6, 6)) |
| 34 | +plot_datasets(ax, X_train, y_train) |
| 35 | + |
| 36 | +model = KNeighborsClassifier(n_neighbors=3) |
| 37 | +model.fit(X_train, y_train) |
| 38 | + |
| 39 | +x_explain = np.array([[-0.75, -0.4]]) |
| 40 | +y_explain_pred = model.predict(x_explain)[0] |
| 41 | +print(f"Prediction: class {y_explain_pred}") |
| 42 | + |
| 43 | +y_explain_proba = model.predict_proba(x_explain)[0] |
| 44 | +print(f"Prediction probabilities: {y_explain_proba}") |
| 45 | + |
| 46 | + |
| 47 | +# %% |
| 48 | +# Using the ``KNNExplainer`` for Unweighted :math:`k`-Nearest Neighbor Models |
| 49 | +# --------------------------------------------------------------------------- |
| 50 | +# |
| 51 | +# To explain the prediction, we create an explainer for the model by passing it to the constructor of ``Explainer``, which will automatically dispatch to the adequate subclass ``KNNExplainer``. |
| 52 | + |
| 53 | +from shapiq import Explainer |
| 54 | + |
| 55 | +explainer = Explainer(model, class_index=y_explain_pred, index="SV", max_order=1) |
| 56 | +print(type(explainer)) |
| 57 | + |
| 58 | + |
| 59 | +# %% |
| 60 | +# Note that we set ``class_index=y_explain_pred``, since for now, we want to quantify the contribution of the training data to the class that was actually predicted. (We could also set a different class index if we wished to see how much the data points contribute to shifting the prediction towards another class.) |
| 61 | +# |
| 62 | +# Now we can get an explanation for the prediction we saw above: |
| 63 | + |
| 64 | +iv = explainer.explain(x_explain) |
| 65 | +print(iv) |
| 66 | + |
| 67 | + |
| 68 | +# %% |
| 69 | +# Explaining Weighted :math:`k`-Nearest Neighbor and Threshold Nearest Neighbor Models |
| 70 | +# ------------------------------------------------------------------------------------ |
| 71 | + |
| 72 | +# %% |
| 73 | +# There are separate explainers for weighted :math:`k`-NN and threshold NN models, which are selected automatically when an ``Explainer`` is instantiated with a corresponding model: |
| 74 | + |
| 75 | +wknn_model = KNeighborsClassifier(n_neighbors=3, weights="distance") |
| 76 | +wknn_model.fit(X_train, y_train) |
| 77 | +wknn_explainer = Explainer(wknn_model, class_index=0, index="SV", max_order=1) |
| 78 | +print(type(wknn_explainer)) |
| 79 | + |
| 80 | +tnn_model = RadiusNeighborsClassifier() |
| 81 | +tnn_model.fit(X_train, y_train) |
| 82 | +tnn_explainer = Explainer(tnn_model, class_index=0, index="SV", max_order=1) |
| 83 | +print(type(tnn_explainer)) |
| 84 | + |
| 85 | + |
| 86 | +# %% |
| 87 | +# They can be used just the same way: |
| 88 | + |
| 89 | +print(wknn_explainer.explain(x_explain)) |
| 90 | +print(tnn_explainer.explain(x_explain)) |
| 91 | + |
| 92 | + |
| 93 | +# %% |
| 94 | +# Large numbers of training samples |
| 95 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 96 | +# |
| 97 | +# Since the algorithms are pretty efficient, we can run them on large sets of training data. |
| 98 | + |
| 99 | +from time import time |
| 100 | + |
| 101 | + |
def print_explain_times(model, n, n_test) -> None:
    """Fit ``model`` on synthetic data and print the average time per explanation.

    A synthetic two-class dataset with ``n`` samples is generated. The first
    ``n_test`` samples are held out as the points to explain and ``model`` is
    fitted on the remaining samples. Each held-out point is explained once and
    the mean and standard deviation of the measured durations are printed in
    milliseconds.

    Args:
        model: Classifier that is fitted in place and then passed to
            ``Explainer``.
        n: Total number of generated samples (held-out points included).
        n_test: Number of leading samples that are held out and explained.
    """
    # ``perf_counter`` is a monotonic, high-resolution clock meant for timing
    # short intervals; it is imported locally so the function does not rely on
    # any additional module-level import.
    from time import perf_counter

    X_train, y_train = make_classification(
        n_samples=n,
        n_features=5,
        n_redundant=0,
        n_clusters_per_class=1,
        n_informative=3,
        n_classes=2,
        random_state=45,
    )
    # Hold out the first ``n_test`` samples so that the explained points are
    # not part of the data the model is fitted on.
    X_test = X_train[:n_test]
    X_train = X_train[n_test:]
    y_train = y_train[n_test:]
    model.fit(X_train, y_train)
    explainer = Explainer(model, class_index=0, index="SV", max_order=1)

    times = np.zeros((n_test,))
    for i, x_test in enumerate(X_test):
        t_start = perf_counter()
        explainer.explain(x_test)
        times[i] = perf_counter() - t_start

    # Durations are collected in seconds; report them in milliseconds.
    mean = np.mean(times) * 1000
    std = np.std(times) * 1000
    print(f"{explainer.__class__.__name__} on {n} samples: average {mean:.1f}±{std:.1f}ms")
| 127 | + |
| 128 | + |
| 129 | +# %% |
| 130 | +# The cell below which uses the KNN explainer takes roughly 0.15 s to explain a single data point on a consumer-grade laptop with a 12th Gen Intel i5 processor. |
| 131 | + |
| 132 | +print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="uniform"), n=100_000, n_test=50) |
| 133 | + |
| 134 | + |
| 135 | +# %% |
| 136 | +# Since the algorithm of the WKNN explainer is less efficient, featuring a quadratic runtime complexity, the number of data points needs to be limited. |
| 137 | + |
| 138 | +print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="distance"), n=200, n_test=10) |
| 139 | + |
| 140 | + |
| 141 | +# %% |
| 142 | +# The TNN algorithm, on the other hand, is faster: |
| 143 | + |
| 144 | +print_explain_times(RadiusNeighborsClassifier(radius=5), n=100_000, n_test=50) |
| 145 | + |
| 146 | + |
| 147 | +# %% |
| 148 | +# Identifying corrupted training samples |
| 149 | +# ----------------------------------------- |
| 150 | +# |
| 151 | +# We can estimate the usefulness of each point of a training dataset by calculating Shapley values for a set of test data points and averaging the results. This will allow us to identify potentially mislabeled data points. |
| 152 | +# |
| 153 | +# First, let's create a classification dataset and split it into train and test sets. We will corrupt the training data by changing the class of a few randomly selected data points. |
| 154 | + |
| 155 | +from sklearn.model_selection import train_test_split |
| 156 | + |
| 157 | +X, y = make_classification( |
| 158 | + n_samples=100, |
| 159 | + n_features=2, |
| 160 | + n_redundant=0, |
| 161 | + n_clusters_per_class=1, |
| 162 | + n_informative=2, |
| 163 | + n_classes=2, |
| 164 | + flip_y=0, |
| 165 | + random_state=49, |
| 166 | + class_sep=1.5, |
| 167 | +) |
| 168 | + |
| 169 | +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42) |
| 170 | + |
| 171 | +y_train_corrupted = y_train.copy() |
| 172 | +n_corrupt = 7 |
| 173 | +rng = np.random.default_rng(seed=43) |
| 174 | +corrupted = rng.choice(np.arange(X_train.shape[0]), size=n_corrupt, replace=False) |
| 175 | +# Since our only class indices are 0 and 1, this is a quick way to flip the class |
| 176 | +y_train_corrupted[corrupted] = 1 - y_train[corrupted] |
| 177 | + |
| 178 | +fig, ax = plt.subplots(figsize=(6, 6)) |
| 179 | +plot_datasets(ax, X_train, y_train_corrupted, X_test, y_test) |
| 180 | +# Mark corrupted datapoints |
| 181 | +ax.scatter( |
| 182 | + X_train[corrupted, 0], |
| 183 | + X_train[corrupted, 1], |
| 184 | + marker="o", |
| 185 | + edgecolors="#b1170c", |
| 186 | + facecolors="none", |
| 187 | + s=100, |
| 188 | +) |
| 189 | +# %% |
| 190 | +# Now, we can use the ``KNNExplainer`` to compute the training points' Shapley values based on the entire test dataset by averaging the Shapley values computed using each test point. |
| 191 | + |
| 192 | +# Train the model with the corrupted training data |
| 193 | +model = KNeighborsClassifier(n_neighbors=5) |
| 194 | +model.fit(X_train, y_train_corrupted) |
| 195 | + |
| 196 | +sv_test = np.zeros(X_train.shape[0], dtype=np.float64) |
| 197 | + |
| 198 | +for x_test_current, y_test_current in zip(X_test, y_test, strict=True): |
| 199 | + explainer = Explainer(model, class_index=y_test_current, index="SV", max_order=1) |
| 200 | + iv = explainer.explain(x_test_current) |
| 201 | + sv_test += iv.to_first_order_array() |
| 202 | + |
| 203 | +sv_test /= X_test.shape[0] |
| 204 | + |
| 205 | + |
| 206 | +# %% |
| 207 | +# We can reasonably assume that the corrupted training data points will on average make the model's prediction worse, resulting in negative Shapley values. So let's filter out just those indices where the Shapley value is below zero and compare with our original array of corrupted indices: |
| 208 | + |
| 209 | +print(f"Corrupted: {np.sort(corrupted)}") # Sort for easier comparison |
| 210 | +print(f"Negative Shapley values: {np.where(sv_test < 0)[0]}") |
| 211 | + |
| 212 | + |
| 213 | +# %% |
| 214 | +# We have identified the set of corrupted samples almost exactly. |
0 commit comments