Skip to content

Commit 0888129

Browse files
committed
Refactor and optimize CPU QuadratureSHAP
1 parent 4203d66 commit 0888129

7 files changed

Lines changed: 374 additions & 378 deletions

File tree

doc/parameter.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,10 @@ Parameters for Tree Booster
192192
* ``shap_algorithm`` string [default= ``treeshap``]
193193

194194
- CPU algorithm used for ``pred_contribs`` with tree boosters.
195-
- Choices: ``treeshap``, ``v6``.
195+
- Choices: ``treeshap``, ``quadratureshap``.
196196

197197
- ``treeshap``: Existing exact TreeSHAP implementation.
198-
- ``v6``: Quadrature plus telescoping SHAP implementation for CPU prediction.
198+
- ``quadratureshap``: Quadrature plus telescoping SHAP implementation for CPU prediction.
199199

200200
* ``scale_pos_weight`` [default=1]
201201

src/predictor/cpu_predictor.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ class CPUPredictor : public Predictor {
748748
void Configure(Args const &cfg) override {
749749
for (auto const &kv : cfg) {
750750
if (kv.first == "shap_algorithm") {
751-
CHECK(kv.second == "treeshap" || kv.second == "v6")
751+
CHECK(kv.second == "treeshap" || kv.second == "quadratureshap")
752752
<< "Unknown SHAP algorithm: " << kv.second;
753753
shap_algorithm_ = kv.second;
754754
}
@@ -878,9 +878,9 @@ class CPUPredictor : public Predictor {
878878
if (approximate) {
879879
interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model,
880880
ntree_limit, tree_weights);
881-
} else if (shap_algorithm_ == "v6" && condition == 0 && condition_feature == 0) {
882-
interpretability::cpu_impl::V6ShapValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit,
883-
tree_weights);
881+
} else if (shap_algorithm_ == "quadratureshap" && condition == 0 && condition_feature == 0) {
882+
interpretability::cpu_impl::QuadratureShapValues(this->ctx_, p_fmat, out_contribs, model,
883+
ntree_limit, tree_weights);
884884
} else {
885885
interpretability::ShapValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit,
886886
tree_weights, condition, condition_feature);

src/predictor/interpretability/shap.cc

Lines changed: 262 additions & 88 deletions
Large diffs are not rendered by default.

src/predictor/interpretability/shap.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector<float>* ou
1919
gbm::GBTreeModel const& model, bst_tree_t tree_end,
2020
std::vector<float> const* tree_weights, int condition, unsigned condition_feature);
2121

22-
void V6ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
23-
gbm::GBTreeModel const& model, bst_tree_t tree_end,
24-
std::vector<float> const* tree_weights);
22+
void QuadratureShapValues(Context const* ctx, DMatrix* p_fmat,
23+
HostDeviceVector<float>* out_contribs, gbm::GBTreeModel const& model,
24+
bst_tree_t tree_end, std::vector<float> const* tree_weights);
2525

2626
void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat,
2727
HostDeviceVector<float>* out_contribs, gbm::GBTreeModel const& model,

src/predictor/treeshap.cc

Lines changed: 0 additions & 232 deletions
This file was deleted.

src/predictor/treeshap.h

Lines changed: 0 additions & 33 deletions
This file was deleted.

0 commit comments

Comments
 (0)