[Feature] SHAP-based explainer for torch models#3049
Open
daidahao wants to merge 147 commits into
Open
Conversation
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
This ensure that the last possible index is always explained when `add_encoders` is used. Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
3 tasks
Contributor
Author
|
FYI, SHAP has just release a new minor version 0.52.0, which include major updates to binary build and distribution. I have re-run the explainer unit tests locally (overriding the Darts newer package cap) and all of them have passed. cc: @CloseChoice |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Checklist before merging this PR:
Fixes #871. Fixes #2788. Fixes #2296. Fixes #2571. Fixes #1262. Fixes #2566. Fixes #1332.
Summary
This PR:
TorchExplainerfor explaining torch models with SHAP,SKLearnExplainer,explain_single()method for explaining a single prediction instance in both explainers,(NEW) Torch Explainer
TorchExplaineris introduced forTorchForecastingModelinstances, with a feature set aligned with the SKLearn explainer:explain().explain_single().summary_plot()andforce_plot().It supports target, past covariates, future covariates, and static covariates (including component-specific/global covariates), and returns SHAP values in
SHAPExplainabilityResult/SHAPSingleExplainabilityResultobjects.Motivation
An increasing number of models in Darts are torch-based (recently #3002, #2980, #2944) and users need a consistent way to explain their forecasts.
For scikit-learn models, the existing
ShapExplainer(nowSKLearnExplainer) provides SHAP-based explanations with method selection based on model type.For torch models, we need a new explainer that can handle the different model architectures, while conforming to existing explainability API patterns.
permutationprovides general applicability and faster explanations thankernelorsampling. Users can choose other SHAP methods if desired.PLForecastingModulein a genericnn.Modulethat can be explained by these methods, in addition to the current numpy-based function wrapper.Design
TorchExplainermirrors theSKLearnExplainerAPI for consistency, withexplain(),summary_plot(), andforce_plot()methods.Implementation Details
PLForecastingModulein a numpy-compatible function which:forward()method to get predictions,SKLearnExplainerfor consistency in querying and visualization.API Reference
TorchExplainer. Other SHAP explainers have similar class signatures.PLForecastingModule._get_batch_prediction()is incorporated intoTorchExplainer._func_wrapper()for SHAP, which handles the conversion between flat numpy arrays and the torch tensors expected by the module.create_lagged_component_names()is the ground-truth for feature naming conventions in Darts.Differences to SKLearnExplainer
TorchForecastingModelvsSKLearnModel.kernel,sampling,partition,permutation; sklearn additionally supports tree/linear/additive where applicable).TorchExplainercan explain likelihood parameters of probabilistic forecasts, whileSKLearnExplainercan only explain the median (quantile) or mean (poisson) predictions.TorchExplaineruses batched tensor to prevent OOM errors, whileSKLearnExplaineruses full-size numpy arrays.Methods
explain()for horizon/component-level explanations over forecastable timestamps.explain_single()for one forecast instance (equivalent prediction context topredict(n=output_chunk_length)).summary_plot()shows distributions of feature contributions.force_plot()shows feature contributions for a specific horizon/component.Use Cases
Summary Plot
Feature-importance distribution analysis per horizon/component for torch models.
Force Plot
Local additive contribution view for a selected horizon and target component.
Explaining Multiple Instances
Batch explanations from foreground data with optional sampling controls for performance.
Explaining Single Instance
Per-instance explanation API (
explain_single()) for local interpretability.Explaining Probabilistic Forecasts
Probabilistic torch models are supported by explaining each likelihood parameter component, treating them as separate targets. This is useful for understanding how features contribute to uncertainty estimates.
(CHANGE) SKLearn Explainer
The previous
ShapExplaineris renamed and aligned with the new naming/API style.Renaming
ShapExplainer->SKLearnExplainer.ShapExplainabilityResult->SHAPExplainabilityResult.SHAPSingleExplainabilityResultforexplain_single()outputs.darts.explainabilitynow exposeSKLearnExplainer,TorchExplainer, and SHAP result classes.Bug Fixes
generate_fit_predict_encodings), improving consistency with forecasting behavior.(NEW) Explaining Single Instance
SKLearnExplainer.explain_single()is added, returning SHAP and feature values for a single prediction instance in the same style as the torch explainer.(NEW) Explainability Notebook
Added
examples/28-Explainability-examples.ipynbcovering:summary_plot()and scatter dependence plots for both explainers (same below).explain()andforce_plot()and common SHAP visualizations.explain_single()and corresponding visualizations.TorchExplainerand visualizing component-specific explanations.ShapExplainertoSKLearnExplainer.Notebook is wired into docs examples (
docs/source/examples.rst) and referenced in docs indexing.Miscellaneous
SHAPcapitalization.darts/tests/explainability/test_sklearn_explainer.pydarts/tests/explainability/test_torch_explainer.pyOther Information
ShapExplainershould migrate toSKLearnExplainer.