Skip to content

Commit 438f7d1

Browse files
SHAP Support + Various Bug Fixes (#45)
* fix `ValueError: y_true and y_pred contain different number of classes` * use number instead of str * Set default LogisticRegression solver to saga to make it working for RFE * Add a script to inspect and plot output db * Add `y_true_collector` and `y_pred_proba_collector` metrics To collect true binary labels and predicted probabilities for ROC curve generation * Move getting the best replicate into a function * Remove empty lines * Return `best_models_dict` by `get_best_replicate` * Plot ROC * Keep only the models we are interested in * Improve figure titles * Plot the mean ± std ROC curve for each model across replicates. * Unify titles across figures * include metric (e.g., 'balanced_accuracy (test)' or 'balanced_accuracy (validate)') to ROC figure fnames * Add max_iter parameter to log_reg configuration * Add proper argparse * Get unique targets dynamically and iterate over them * Initial commit; added SHAP value calculation to the list of available metrics. * Swapped SHAP values from list to dict (bound by feature name) * Fixed error when a model cannot natively be parsed by SHAP. * Added new "VarianceDrop" data hook, allowing low-variance features to be dropped as part of a trial's run. * Updated iris testing dataset + config with new encoders. * Added Jupyter Notebooks to the git ignore, as we occasionally use them for visual validation of tests. * Added catch for homogeneity when running SHAP tests. * Removed "inspect_output_db", as it is too specific to Jan's analysis. * Pinned to pre-3.0 version of Pandas until `dtype` issues can be addressed. --------- Co-authored-by: valosekj <jan.valosek@upol.cz> Co-authored-by: Jan Valosek <39456460+valosekj@users.noreply.github.com>
1 parent 20d6297 commit 438f7d1

10 files changed

Lines changed: 370 additions & 161 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*__pycache__/
88
.vscode/
99
.DS_Store
10+
/.ipynb_checkpoints/
1011

1112
# Environments
1213
.env

data/hooks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _decorator(cls: Type[DataHook]):
3030
# TODO: Find a more elegant way to do this
3131
from data.hooks.feature_selection import (
3232
SampleNullityDrop, FeatureNullityDrop, ExplicitDrop, ExplicitKeep, PrincipalComponentAnalysis,
33-
RecursiveFeatureElimination
33+
RecursiveFeatureElimination, VarianceDrop
3434
)
3535
from data.hooks.imputation import SimpleImputation
3636
from data.hooks.encoding import OneHotEncoding

data/hooks/feature_selection.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
from optuna import Trial
88
from sklearn.decomposition import PCA
9-
from sklearn.feature_selection import RFE
9+
from sklearn.feature_selection import RFE, VarianceThreshold
1010
from sklearn.linear_model import LogisticRegression
1111

1212
from config.utils import default_as, is_float, is_list, parse_data_config_entry
@@ -135,6 +135,81 @@ def run(self, x: BaseDataManager, y: Optional[BaseDataManager] = None) -> BaseDa
135135
return x.drop_features(drop_idx)
136136

137137

138+
## Feature selection by homogeneity
139+
@registered_data_hook("drop_low_variance")
140+
class VarianceDrop(FittedDataHook):
141+
"""
142+
Thin wrapper for SciKit-Learn's VarianceThreshold class, for use as a data hook within MOOP.
143+
144+
Runs additional checks on top of the default implementation provided by SciKit-Learn:
145+
* Ensures the resulting dataset always contains at least 1 feature.
146+
147+
Example usage:
148+
{
149+
"type": "drop_low_variance",
150+
"threshold": 0.1
151+
}
152+
"""
153+
def __init__(self, config, **kwargs):
154+
# TODO: make this tunable
155+
super().__init__(config, **kwargs)
156+
157+
# Get the variance threshold
158+
threshold = parse_data_config_entry(
159+
"threshold", config,
160+
default_as(0.0, self.logger), is_float(self.logger)
161+
)
162+
163+
# Build the wrapped VarianceThreshold object
164+
self.threshold = threshold
165+
self.selected_features: list[str] | None = None
166+
167+
@classmethod
168+
def from_config(cls, config: dict, logger: Logger = Logger.root) -> Self:
169+
return cls(config=config, logger=logger)
170+
171+
def run(self, x: BaseDataManager, y: Optional[BaseDataManager] = None) -> BaseDataManager:
172+
# If x contains only one feature already, just return that feature, as RFE has a stroke otherwise
173+
if x.n_features() == 1:
174+
self.logger.warning("Only one feature in the dataset was found; "
175+
"dropping any further would result in a null dataset."
176+
"Original (unmodified) dataset returned instead.")
177+
return x
178+
179+
# Fit the model to the dataset
180+
vt = VarianceThreshold(threshold=self.threshold)
181+
vt.fit(x.as_array(), np.ravel(y.as_array())) # Ravel prevents some warning spam
182+
183+
# Select only the features with variance less than the threshold
184+
self.selected_features = vt.get_feature_names_out(x.features())
185+
186+
# Ensure that at least one feature was kept
187+
if self.selected_features.shape[0] < 1:
188+
# Find the
189+
highest_var = np.max(vt.variances_)
190+
highest_var_feature = list(x.features())[np.argmax(vt.variances_)]
191+
self.selected_features = [highest_var_feature]
192+
self.logger.warning(
193+
f"Low-variance filter almost dropped all features; kept highest variance "
194+
f"feature ({highest_var_feature}, variance {highest_var}) alone to prevent crash!"
195+
)
196+
197+
# Return the copy of x containing only these features
198+
x_out = x.get_features(self.selected_features)
199+
return x_out
200+
201+
def run_fitted(self, x_train: BaseDataManager, x_test: Optional[BaseDataManager],
202+
y_train: Optional[BaseDataManager] = None, y_test: Optional[BaseDataManager] = None) -> \
203+
tuple[BaseDataManager, BaseDataManager]:
204+
# Run the fitted analysis first
205+
train_out = self.run(x_train, y_train)
206+
207+
# Use the same set of features to filter the x_test set
208+
test_out = x_test.get_features(self.selected_features)
209+
210+
return train_out, test_out
211+
212+
138213
### Principal Component Analysis ###
139214
@registered_data_hook("principal_component_analysis")
140215
class PrincipalComponentAnalysis(Tunable, FittedDataHook):
@@ -263,8 +338,7 @@ def from_config(cls, config: dict, logger: Logger = Logger.root) -> Self:
263338
def tune(self, trial: Trial):
264339
self.prop_tuner.tune(trial)
265340
# Generate the new backing model based on this setup
266-
# TODO: Generalize this to work with continuous targets as well
267-
new_lor = LogisticRegression()
341+
new_lor = LogisticRegression(solver='saga')
268342
self.backing_rfe = RFE(estimator=new_lor, n_features_to_select=self.prop_tuner.value)
269343

270344
def tunable_params(self) -> list[TunableParam]:

environment.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ dependencies:
77
- ca-certificates
88
- openssl
99
- scikit-learn
10-
- pandas
10+
- pandas<3
1111
- pytest
12+
- shap

study/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
"sk_f1_weighted_avg": sk_f1_weighted_avg,
1818
"sk_f1_perclass": sk_f1_perclass,
1919
"importance_by_permutation": importance_by_permutation,
20+
"shap_additive": shap_additive,
2021
"correct_samples": correct_samples,
21-
"incorrect_samples": incorrect_samples
22+
"incorrect_samples": incorrect_samples,
23+
"y_true_collector": y_true_collector,
24+
"y_pred_proba_collector": y_pred_proba_collector
2225
}

study/metrics.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""
22
Metric-reporting closures for use in this framework.
33
"""
4+
import sys
5+
46
import numpy as np
7+
import shap
58
from sklearn.inspection import permutation_importance
69
from sklearn.metrics import balanced_accuracy_score, log_loss, roc_auc_score, precision_score, recall_score, f1_score
710

@@ -18,7 +21,8 @@ def clean_val_for_db(val):
1821
def sk_log_loss(manager: OptunaModelManager, x: BaseDataManager, y: BaseDataManager):
1922
# Log Loss
2023
py = manager.predict_proba(x.as_array())
21-
return log_loss(y.as_array(), py)
24+
y_labels = [i for i in range(py.shape[1])]
25+
return log_loss(y.as_array(), py, labels=y_labels)
2226

2327
def sk_balanced_accuracy(manager: OptunaModelManager, x: BaseDataManager, y: BaseDataManager):
2428
# Balanced Accuracy
@@ -101,6 +105,95 @@ def importance_by_permutation(manager: OptunaModelManager, x: BaseDataManager, y
101105
importance_vals = clean_val_for_db(importance_vals)
102106
return importance_vals
103107

108+
def shap_additive(manager: OptunaModelManager, x: BaseDataManager, _: BaseDataManager):
109+
"""
110+
To restore the (raw) values for a given run, run the following snippet:
111+
112+
```
113+
from io import StringIO
114+
115+
# You can omit the Numpy import if you manually
116+
# parse the inner string in the list comp
117+
import numpy as np
118+
119+
# This is the SHAP value within the database you want to parse
120+
val = ...
121+
122+
# Strip the brackets first
123+
val = val.strip("{").strip("}")
124+
# Split by commas
125+
entry_strs = val.split(", ")
126+
# Split the dataset by feature name
127+
shap_map = dict()
128+
for entry_str in entry_strs:
129+
# Split the text along the colon to get the feature label back
130+
feature_label, shap_value_str = entry_str.split(": ")
131+
# Parse the shap value string back into numeric form
132+
shap_vals = [list(np.fromstring(x, sep=" ")) for x in shap_value_str.split("\n")]
133+
# Add it to the map
134+
shap_map[feature_label] = shap_vals
135+
```
136+
137+
Each entry in `shap_map` will be Numpy array of with the following dimensions:
138+
* n is the number of samples in the input dataset (train, validate, or test), and
139+
* c is the number of categorical classes used during training;
140+
If this is binary classification, or a continuous target, c=1.
141+
142+
For categorical targets with more than 2 classes, each class is treated as
143+
unique feature by SHAP for the purpose of calculating SHAP values.
144+
145+
TODO: Save the `shap_values` directly via pickle into a SQLite blob
146+
"""
147+
# Initialize the explainer, using the x data as both the mask and feature list
148+
x_arr = x.as_array()
149+
model = manager.get_model()
150+
151+
if np.unique(x_arr).shape[0] < 2:
152+
# SHAP cannot run on a dataset which is entirely homogenous;
153+
# return early to avoid an error
154+
return "NULL"
155+
156+
try:
157+
# Default to the "generic" explainer
158+
explainer = shap.Explainer(
159+
model, x_arr, feature_names=x.features()
160+
)
161+
# Calculate the Shapley values from this dataset
162+
shap_values = explainer(x_arr)
163+
except TypeError as err:
164+
# If that failed, try to use the model's "predict" function instead
165+
if hasattr(model, "predict"):
166+
explainer = shap.Explainer(
167+
model.predict, x_arr, feature_names=x.features()
168+
)
169+
# Calculate the Shapley values from this dataset
170+
shap_values = explainer(x_arr)
171+
else:
172+
raise err
173+
174+
shap_list = list()
175+
for i, v in enumerate(shap_values.feature_names):
176+
# SHAP auto-reduces the shape of its features if it is targeting
177+
# a binary classification OR a continuous metric
178+
if len(shap_values.values.shape) < 3:
179+
val_str = np.array2string(shap_values.values[:, i], max_line_width=sys.maxsize, threshold=sys.maxsize)
180+
else:
181+
val_str = np.array2string(shap_values.values[:, i, :], max_line_width=sys.maxsize, threshold=sys.maxsize)
182+
# Remove the brackets; despite Numpy adding them, it cannot parse them after...
183+
val_str = val_str.replace("[", "").replace("]", "")
184+
val_str = f"{v}: {val_str}"
185+
shap_list.append(val_str)
186+
187+
# This nonsense is required because Python maps
188+
# "/n" to "//n" if you string convert a dict;
189+
# why the hell does it do that?!?!?
190+
full_str = "{"
191+
full_str += ", ".join(shap_list)
192+
full_str += "}"
193+
194+
# Return the result to be saved
195+
return full_str
196+
104197

105198
""" Sample Reporting """
106199
def correct_samples(manager: OptunaModelManager, x: BaseDataManager, y: BaseDataManager):
@@ -133,4 +226,16 @@ def incorrect_samples(manager: OptunaModelManager, x: BaseDataManager, y: BaseDa
133226
# Strip quotation marks from the result so the DB backend doesn't explode
134227
bad_samples = clean_val_for_db(bad_samples)
135228

136-
return bad_samples
229+
return bad_samples
230+
231+
""" ROC Curve """
232+
def y_true_collector(_: OptunaModelManager, __: BaseDataManager, y: BaseDataManager):
233+
""" Collects the true binary labels for ROC curve generation. """
234+
return clean_val_for_db(list(y.as_array().flatten()))
235+
236+
def y_pred_proba_collector(manager: OptunaModelManager, x: BaseDataManager, _: BaseDataManager):
237+
""" Collects predicted probabilities for the positive class. """
238+
py = manager.predict_proba(x.as_array())
239+
if py.shape[1] != 2:
240+
raise ValueError(f"Expected binary classification with two probability columns; found {py.shape[1]}.")
241+
return clean_val_for_db(list(py[:, 1])) # Probabilities for the positive class

testing/iris_data/iris_config.json

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
}
1616
],
1717
"post_split_hooks": [
18+
{
19+
"type": "imputation_simple",
20+
"strategy": "most_frequent",
21+
"features": ["color", "flower_category", "is_flower", "size"]
22+
},
1823
{
1924
"type": "one_hot_encode",
20-
"features": ["color", "flower_category"]
25+
"features": ["color", "flower_category", "is_flower"]
2126
},
2227
{
2328
"type": "ladder_encode",
@@ -33,6 +38,19 @@
3338
"type": "standard_scaling",
3439
"run_per_cross": true
3540
},
41+
{
42+
"type": "drop_low_variance",
43+
"threshold": 0.0
44+
},
45+
{
46+
"type": "principal_component_analysis",
47+
"proportion": {
48+
"label": "pca_feature_proportion",
49+
"type": "float",
50+
"low": 0.1,
51+
"high": 0.9
52+
}
53+
},
3654
{
3755
"type": "recursive_feature_elimination",
3856
"proportion": {

0 commit comments

Comments
 (0)