Skip to content

Commit f9565f0

Browse files
Zaphooodmmschlk
andauthored
Add explainers for nearest neighbor models (#487)
* feat: copy nn explainers from shapiq_student * feat: integrate nn explainers into dynamic explainer dispatch * feat: remove unused property `mode` from nn explainers * feat: clean up NN explainer base class * refactor: turn BruteForceKNNExplainer into benchmark game * refactor: factor out base class for nn explainer benchmarks * refactor: delete unused brute force knn explainer * refactor: turn BruteForceWKNNExplainer into benchmark game * feat: merge WeightedKNNExplainer with its base class * docs: improve wording in WeightedKNNExplainer docstring * refactor: clean up nn benchmark games * refacgtor delete unused brute force tnn explainer * refactor: delete unused lookup game * refactor: merge CommonKNNExplainer into its subclasses * feat: check that index and order are valid for nn explainers * refactor: remove custom exception classes * refactor: clean up nn explainers' constructors * fix: set normalization value in NN benchmark game base * fix: sort coalition in KNN benchmark game * tests: add unit test for KNN explainer * tests: add unit test for WKNN explainer * refactor: rename WKNN test function * feat: add TNN benchmark game * fix: utility of coalition with no points in radius of threshold nn classifier * tests: add unit test for TNN explainer * tests: test NN explainers with all train points instead of just 1 * fix: handle case N < k in KNN benchmark and add test * feat: handle access to sklearn private members gracefully * fix: remove kwarg `class_index` from KNN's explain_function * refactor: factor out InteractionValues from/to array helpers * feat: add notebook for NN explainers * chore: remove obsolete TODO * tests: use randomly generated test points for testing NN explainers * feat: use footnote citations in docstrings * feat: improve NN explainer base class docstring * fix: unify spelling of 'nearest neighbor' without hyphen * feat: improve wording in notebook * document new explainers in changelog * improve wording in changelog * fix: unterminated f-string literal * fix: execute data valuation notebook * fix: name tests correctly to make them discoverable * refactor: place nn explainer benchmark games alongside efficient explainers * refactor: parametrize nn explainer unit tests * tests: tests error handling of nn explainer base class * feat: test knn explainer error handling * tests: increase radius of TNN classifier * tests: test error handling of tnn classifier * chore: delete duplicate function * feat: add more data checks to nn explainer games * tests: test nn explainer game base * tests: test index/max_order verification util * tests: test warning for ignored parameters * tests: add case for automatic dispatch to nn explainers * feat: add stress test for knn explainer to notebook * refactor: rename binary weighted knn explainer game * refactor: add conversion from/to array into InteractionValues object * fix: value of empty coalition of tnn game * feat: replace binom fraction with product * feat: optimize tnn explainer * refactor: rename knn data valuation notebook * fix: re-run knn notebook * feat: demonstrate performance of wknn, tnn explainers in notebook * chore: remove first order index check when creating iv from array * tests: add unit tests from iv from/to array * refactor: require setting `class_index` for nn explainers * chore: fix imports * chore: fix type errors * updated CHANGELOG.md Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> * Made NN Explainer only work for SVs. Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> * updated reference urls Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> * fixed issue in background_clf_dataset_small Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> * updated docs inconsistency Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> * wip: add sphinx-gallery script for nn example * fix: set index to SV in nn explainer example * fix: track plotting helpers for nn example * chore: delete old nn notebook * fix rest syntax * rename plot helpers * fix imports * delete lightgbm example * fix underline length * add sphinx gallery ignore pattern * rename nn example to start with `plot_` * Revert "delete lightgbm example" This reverts commit c47ea53. --------- Signed-off-by: Maximilian <maximilian.muschalik@gmail.com> Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
1 parent 8f23fbd commit f9565f0

31 files changed

Lines changed: 1909 additions & 46 deletions

CHANGELOG.md

Lines changed: 70 additions & 38 deletions
Large diffs are not rendered by default.

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"doc_module": ("shapiq",),
5858
"reference_url": {"shapiq": None},
5959
"filename_pattern": r"plot_.*\.py",
60+
"ignore_pattern": r"util_.*\.py",
6061
"plot_gallery": True,
6162
"download_all_examples": False,
6263
"show_signature": False,

docs/source/references.bib

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,31 @@ @inproceedings{Yu.2022
285285
booktitle = {Advances in Neural Information Processing Systems 35: Annual Conference on Neural Information Processing Systems 2022 ({NeurIPS} 2022)},
286286
url = {http://papers.nips.cc/paper\_files/paper/2022/hash/a5a3b1ef79520b7cd122d888673a3ebc-Abstract-Conference.html}
287287
}
288+
@article{Jia.2019,
289+
title = {Efficient task-specific data valuation for nearest neighbor algorithms},
290+
author = {Jia, Ruoxi and Dao, David and Wang, Boxin and Hubis, Frances Ann and Gurel, Nezihe Merve and Li, Bo and Zhang, Ce and Spanos, Costas J and Song, Dawn},
291+
journal = {arXiv preprint arXiv:1908.08619},
292+
year = {2019},
293+
url = {https://doi.org/10.48550/arXiv.1908.08619}
294+
}
295+
@article{Wang.2023,
296+
title = {A privacy-friendly approach to data valuation},
297+
author = {Wang, Jiachen Tianhao and Zhu, Yuqing and Wang, Yu-Xiang and Jia, Ruoxi and Mittal, Prateek},
298+
journal = {Advances in Neural Information Processing Systems},
299+
volume = {36},
300+
pages = {60429--60467},
301+
year = {2023},
302+
url = {https://openreview.net/forum?id=FAZ3i0hvm0}
303+
}
304+
@inproceedings{Wang.2024,
305+
title = {Efficient data shapley for weighted nearest neighbor algorithms},
306+
author = {Wang, Jiachen T and Mittal, Prateek and Jia, Ruoxi},
307+
booktitle = {International Conference on Artificial Intelligence and Statistics},
308+
pages = {2557--2565},
309+
year = {2024},
310+
organization = {PMLR},
311+
url = {https://arxiv.org/abs/2401.11103}
312+
}
288313
@article{Castro.2009,
289314
title = {Polynomial calculation of the {Shapley} value based on sampling},
290315
author = {Javier Castro and Daniel G{\'o}mez and Juan Tejada},
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Nearest Neighbor Models
2+
=======================
3+
4+
Examples for data valuation using efficient explanations of nearest-neighbor models.
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""
2+
Data Valuation with Nearest Neighbor Explainers
3+
===============================================
4+
5+
This notebook shows how explainers of nearest-neighbor (NN) models can be used for Data Valuation, the task of evaluating the usefulness of individual training data points in classification problems.
6+
When explaining NN models, a game is defined by first choosing an explanation point :math:`x_\\text{explain}` and class :math:`y_\\text{explain}`; the training data points :math:`\\mathcal{D} := \\mathcal{X} \\times \\mathcal{Y}` are the game's players, and the definition of the utility :math:`\\nu(S)` of a coalition :math:`S \\subseteq \\mathcal{D}` is based on the probability of the model predicting class :math:`y_\\text{explain}` on :math:`x_\\text{explain}` if it's training data were limited to :math:`S`.
7+
"""
8+
9+
# %%
10+
# There is support for explaining the the ``KNeighborsClassifier`` model (with ``'uniform'`` or ``'distance'`` weights) and ``RadiusNeighborsClassifier`` model from the ``scikit-learn`` library.
11+
# The algorithms are based on the publications from `Jia et al. (2019) <https://doi.org/10.48550/arXiv.1908.08619/>`__,
12+
# `Wang et al. (2024) <https://doi.org/10.48550/arXiv.1908.08619>`__
13+
# and `Wang et al. (2023) <https://doi.org/10.48550/arXiv.2308.15709>`__, respectively.
14+
#
15+
# Let's start by generating a synthetic classification datset and fitting a simple `KNeighborsClassifier` to it.
16+
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
from sklearn.datasets import make_classification
20+
from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier
21+
from util_plot import plot_datasets
22+
23+
X_train, y_train = make_classification(
24+
n_samples=30,
25+
n_features=2,
26+
n_redundant=0,
27+
n_clusters_per_class=1,
28+
n_informative=2,
29+
n_classes=2,
30+
random_state=45,
31+
)
32+
33+
fig, ax = plt.subplots(figsize=(6, 6))
34+
plot_datasets(ax, X_train, y_train)
35+
36+
model = KNeighborsClassifier(n_neighbors=3)
37+
model.fit(X_train, y_train)
38+
39+
x_explain = np.array([[-0.75, -0.4]])
40+
y_explain_pred = model.predict(x_explain)[0]
41+
print(f"Prediction: class {y_explain_pred}")
42+
43+
y_explain_proba = model.predict_proba(x_explain)[0]
44+
print(f"Prediction probabilities: {y_explain_proba}")
45+
46+
47+
# %%
48+
# Using the ``KNNExplainer`` for Unweighted :math:`k`-Nearest Neighbor Models
49+
# ---------------------------------------------------------------------------
50+
#
51+
# To explain the prediction, we create an explainer for the model by passing it to the constructor of ``Explainer``, which will automatically dispatch to the adequate subclass ``KNNExplainer``.
52+
53+
from shapiq import Explainer
54+
55+
explainer = Explainer(model, class_index=y_explain_pred, index="SV", max_order=1)
56+
print(type(explainer))
57+
58+
59+
# %%
60+
# Note that we set ``class_index=y_explain_pred``, since for now, we want to quantify the contribution of the training data to the class that was actually predicted. (We could also set a different class index if we wished to see how much the data points contribute to shifting the prediction towards another class.)
61+
#
62+
# Now we can get an explanation for the prediction we saw above:
63+
64+
iv = explainer.explain(x_explain)
65+
print(iv)
66+
67+
68+
# %%
69+
# Explaining Weighted :math:`k`-Nearest Neighbor and Threshold Nearest Neighbor Models
70+
# ------------------------------------------------------------------------------------
71+
72+
# %%
73+
# There are separate explainers for weighted :math:`k`-NN and threshold NN models, which are selected automatically when an `Explainer` is instantiated with a corresponding model:
74+
75+
wknn_model = KNeighborsClassifier(n_neighbors=3, weights="distance")
76+
wknn_model.fit(X_train, y_train)
77+
wknn_explainer = Explainer(wknn_model, class_index=0, index="SV", max_order=1)
78+
print(type(wknn_explainer))
79+
80+
tnn_model = RadiusNeighborsClassifier()
81+
tnn_model.fit(X_train, y_train)
82+
tnn_explainer = Explainer(tnn_model, class_index=0, index="SV", max_order=1)
83+
print(type(tnn_explainer))
84+
85+
86+
# %%
87+
# They can be used just the same way:
88+
89+
print(wknn_explainer.explain(x_explain))
90+
print(tnn_explainer.explain(x_explain))
91+
92+
93+
# %%
94+
# Large numbers of training samples
95+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
96+
#
97+
# Since the algorithms are pretty efficient, we can run them on large sets of training data.
98+
99+
from time import time
100+
101+
102+
def print_explain_times(model, n, n_test) -> None:
103+
X_train, y_train = make_classification(
104+
n_samples=n,
105+
n_features=5,
106+
n_redundant=0,
107+
n_clusters_per_class=1,
108+
n_informative=3,
109+
n_classes=2,
110+
random_state=45,
111+
)
112+
X_test = X_train[:n_test]
113+
X_train = X_train[n_test:]
114+
y_train = y_train[n_test:]
115+
model.fit(X_train, y_train)
116+
explainer = Explainer(model, class_index=0, index="SV", max_order=1)
117+
118+
times = np.zeros((n_test,))
119+
for i, x_test in enumerate(X_test):
120+
t_start = time()
121+
explainer.explain(x_test)
122+
t_end = time()
123+
times[i] = t_end - t_start
124+
mean = np.mean(times) * 1000
125+
std = np.std(times) * 1000
126+
print(f"{explainer.__class__.__name__} on {n} samples: average {mean:.1f}±{std:.1f}ms")
127+
128+
129+
# %%
130+
# The cell below which uses the KNN explainer takes roughly 0.15 s to explain a single data point on a consumer-grade laptop with a 12th Gen Intel i5 processor.
131+
132+
print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="uniform"), n=100_000, n_test=50)
133+
134+
135+
# %%
136+
# Since the algorithm of the WKNN explainer is less efficient, featuring a quadratic runtime complexity, the number of data points needs to be limited.
137+
138+
print_explain_times(KNeighborsClassifier(n_neighbors=5, weights="distance"), n=200, n_test=10)
139+
140+
141+
# %%
142+
# The TNN algorithm, on the other hand, is faster:
143+
144+
print_explain_times(RadiusNeighborsClassifier(radius=5), n=100_000, n_test=50)
145+
146+
147+
# %%
148+
# ## Identifying corrupted training samples
149+
# -----------------------------------------
150+
#
151+
# We can estimate the usefulness of each point of a training datset by calculating Shapley values for a set of test data points and averaging the results. This will allow us to identify potentially mislabeled data points.
152+
#
153+
# First, let's create a classification datset and split it into train and test sets. We will corrupt the training data by changing the class of a few randomly selected data points.
154+
155+
from sklearn.model_selection import train_test_split
156+
157+
X, y = make_classification(
158+
n_samples=100,
159+
n_features=2,
160+
n_redundant=0,
161+
n_clusters_per_class=1,
162+
n_informative=2,
163+
n_classes=2,
164+
flip_y=0,
165+
random_state=49,
166+
class_sep=1.5,
167+
)
168+
169+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
170+
171+
y_train_corrupted = y_train.copy()
172+
n_corrupt = 7
173+
rng = np.random.default_rng(seed=43)
174+
corrupted = rng.choice(np.arange(X_train.shape[0]), size=n_corrupt, replace=False)
175+
# Since our only class indices are 0 and 1, this is a quick way to flip the class
176+
y_train_corrupted[corrupted] = 1 - y_train[corrupted]
177+
178+
fig, ax = plt.subplots(figsize=(6, 6))
179+
plot_datasets(ax, X_train, y_train_corrupted, X_test, y_test)
180+
# Mark corrupted datapoints
181+
ax.scatter(
182+
X_train[corrupted, 0],
183+
X_train[corrupted, 1],
184+
marker="o",
185+
edgecolors="#b1170c",
186+
facecolors="none",
187+
s=100,
188+
)
189+
# %%
190+
# Now, we can use the `KNNExplainer` to compute the training points' Shapley values based on the entire test dataset by averaging the Shapley values computed using each test point.
191+
192+
# Train the model with the corrupted training data
193+
model = KNeighborsClassifier(n_neighbors=5)
194+
model.fit(X_train, y_train_corrupted)
195+
196+
sv_test = np.zeros(X_train.shape[0], dtype=np.float64)
197+
198+
for x_test_current, y_test_current in zip(X_test, y_test, strict=True):
199+
explainer = Explainer(model, class_index=y_test_current, index="SV", max_order=1)
200+
iv = explainer.explain(x_test_current)
201+
sv_test += iv.to_first_order_array()
202+
203+
sv_test /= X_test.shape[0]
204+
205+
206+
# %%
207+
# We can reasonably assume that the corrupted training data points will on average make the model's prediction worse, resulting in negative Shapley values. So let's filter out just those indices where the Shapley value is below zero and compare with our original array of corrupted indices:
208+
209+
print(f"Corrupted: {np.sort(corrupted)}") # Sort for easier comparison
210+
print(f"Negative Shapley values: {np.where(sv_test < 0)[0]}")
211+
212+
213+
# %%
214+
# We have identified the set corrupted samples almost exactly.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Helper functions for plotting to be used by the notebooks."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
import numpy as np
9+
import numpy.typing as npt
10+
from matplotlib.axes import Axes
11+
12+
import matplotlib.pyplot as plt
13+
from matplotlib.lines import Line2D
14+
15+
16+
def plot_datasets(
17+
ax: Axes,
18+
X_train: npt.NDArray[np.floating],
19+
y_train: npt.NDArray[np.floating],
20+
X_test: npt.NDArray[np.floating] | None = None,
21+
y_test: npt.NDArray[np.floating] | None = None,
22+
title: str | None = None,
23+
) -> None:
24+
"""Plots train and test datasets in the same figure."""
25+
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
26+
27+
if title is not None:
28+
ax.set_title(title)
29+
ax.scatter(
30+
X_train[:, 0],
31+
X_train[:, 1],
32+
c=[colors[i] for i in y_train],
33+
label="Training Points",
34+
marker="o",
35+
)
36+
if X_test is not None and y_test is not None:
37+
ax.scatter(
38+
X_test[:, 0],
39+
X_test[:, 1],
40+
c=[colors[i] for i in y_test],
41+
label="Test Points",
42+
marker="x",
43+
)
44+
45+
handles = [
46+
Line2D(
47+
[0],
48+
[0],
49+
marker="o",
50+
color="w",
51+
markerfacecolor=colors[i],
52+
markersize=10,
53+
label=f"Class {i} (Train)",
54+
)
55+
for i in set(y_train)
56+
]
57+
if y_test is not None:
58+
handles += [
59+
Line2D(
60+
[0],
61+
[0],
62+
marker="x",
63+
linewidth=0,
64+
color=colors[i],
65+
markerfacecolor=colors[i],
66+
markersize=10,
67+
label=f"Class {i} (Test)",
68+
)
69+
for i in set(y_train)
70+
]
71+
ax.legend(handles=handles, loc="upper right", title="Data Points")
72+
73+
ax.set_xlabel("Feature 1")
74+
ax.set_ylabel("Feature 2")

src/shapiq/explainer/custom_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from typing import Literal
66

77
ExplainerIndices = Literal["SV", "SII", "k-SII", "STII", "FSII", "BV", "BII", "FBII"]
8+
ValidNNExplainerIndices = Literal["SV"]
89
ValidProductKernelExplainerIndices = Literal["SV"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Explainers for nearest neighbor models."""
2+
3+
from .knn import KNNExplainer
4+
from .threshold_nn import ThresholdNNExplainer
5+
from .weighted_knn import WeightedKNNExplainer
6+
7+
__all__ = [
8+
"KNNExplainer",
9+
"ThresholdNNExplainer",
10+
"WeightedKNNExplainer",
11+
]

src/shapiq/explainer/nn/_util.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Utility function for the NormalKNNExplainer and the WeightedKNNExplainer."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import TYPE_CHECKING, Any
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Iterable, Mapping
10+
11+
from shapiq.explainer.custom_types import ValidNNExplainerIndices
12+
13+
logger = logging.getLogger()
14+
15+
16+
def warn_ignored_parameters(
17+
local_vars: Mapping[str, Any], ignored_parameter_names: Iterable[str], class_name: str
18+
) -> None:
19+
for param in ignored_parameter_names:
20+
if local_vars[param] is not None:
21+
logger.warning(
22+
"A non-None value was passed as parameter `%s` to the constructor of %s, which will be ignored.",
23+
param,
24+
class_name,
25+
)
26+
27+
28+
def assert_valid_index_and_order(index: ValidNNExplainerIndices, max_order: int) -> None:
29+
"""Check that the explainer index and max_order are valid for NN models, raise otherwise.
30+
31+
The only valid index is ``'SV'``; the only valid max. order is ``1``.
32+
33+
Args:
34+
index: The explainer index to validate.
35+
max_order: The max. order to validate.
36+
37+
Raises:
38+
ValueError: If either of the parameters does not satisfy the requirements.
39+
"""
40+
if index != "SV":
41+
msg = f"Explainer index '{index}' is invalid for nearest neighbor models. The only valid index is 'SV'."
42+
raise ValueError(msg)
43+
44+
if max_order != 1:
45+
msg = f"Explanation order of {max_order} is invalid; the only valid order is 1."
46+
raise ValueError(msg)

0 commit comments

Comments
 (0)