Skip to content

Commit 9fc3d4a

Browse files
authored
Upgrade TreeSHAP algorithm (#12179)
1 parent e0d3dfd commit 9fc3d4a

7 files changed

Lines changed: 752 additions & 319 deletions

File tree

R-package/src/Makevars.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ OBJECTS= \
8686
$(PKGROOT)/src/predictor/predictor.o \
8787
$(PKGROOT)/src/predictor/cpu_predictor.o \
8888
$(PKGROOT)/src/predictor/interpretability/shap.o \
89-
$(PKGROOT)/src/predictor/treeshap.o \
9089
$(PKGROOT)/src/tree/constraints.o \
9190
$(PKGROOT)/src/tree/param.o \
9291
$(PKGROOT)/src/tree/fit_stump.o \

R-package/src/Makevars.win.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ OBJECTS= \
8585
$(PKGROOT)/src/predictor/predictor.o \
8686
$(PKGROOT)/src/predictor/cpu_predictor.o \
8787
$(PKGROOT)/src/predictor/interpretability/shap.o \
88-
$(PKGROOT)/src/predictor/treeshap.o \
8988
$(PKGROOT)/src/tree/constraints.o \
9089
$(PKGROOT)/src/tree/param.o \
9190
$(PKGROOT)/src/tree/fit_stump.o \
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/**
2+
* Copyright 2017-2026, XGBoost Contributors
3+
*/
4+
#ifndef XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_
5+
#define XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_
6+
7+
#include <algorithm>
8+
#include <array>
9+
#include <cmath>
10+
#include <cstddef>
11+
#include <utility>
12+
#include <vector>
13+
14+
#include "xgboost/logging.h"
15+
16+
namespace xgboost::interpretability::detail {
17+
18+
constexpr double kPi = 3.141592653589793238462643383279502884;
19+
20+
template <std::size_t MaxPoints>
21+
struct EndpointQuadratureRule {
22+
std::size_t points{0};
23+
std::array<double, MaxPoints> nodes{};
24+
std::array<double, MaxPoints> weights{};
25+
};
26+
27+
inline double LegendrePolynomial(std::size_t n, double x) {
28+
double p0 = 1.0;
29+
if (n == 0) {
30+
return p0;
31+
}
32+
double p1 = x;
33+
if (n == 1) {
34+
return p1;
35+
}
36+
for (std::size_t k = 2; k <= n; ++k) {
37+
double pk =
38+
((2.0 * static_cast<double>(k) - 1.0) * x * p1 - (static_cast<double>(k) - 1.0) * p0) /
39+
static_cast<double>(k);
40+
p0 = p1;
41+
p1 = pk;
42+
}
43+
return p1;
44+
}
45+
46+
inline double LegendreDerivative(std::size_t n, double x, double pn) {
47+
auto n_d = static_cast<double>(n);
48+
return n_d * (x * pn - LegendrePolynomial(n - 1, x)) / (x * x - 1.0);
49+
}
50+
51+
template <std::size_t MaxPoints>
52+
inline EndpointQuadratureRule<MaxPoints> MakeEndpointQuadrature(std::size_t n,
53+
double convergence_eps) {
54+
CHECK_GE(n, 2);
55+
CHECK_LE(n, MaxPoints);
56+
57+
EndpointQuadratureRule<MaxPoints> rule;
58+
rule.points = n;
59+
std::vector<std::pair<double, double>> nodes_weights;
60+
nodes_weights.reserve(n);
61+
62+
for (std::size_t i = 0; i < n; ++i) {
63+
double theta = kPi * (static_cast<double>(i) + 0.75) / (static_cast<double>(n) + 0.5);
64+
double x = std::cos(theta);
65+
for (std::size_t iter = 0; iter < 64; ++iter) {
66+
auto pn = LegendrePolynomial(n, x);
67+
auto dpn = LegendreDerivative(n, x, pn);
68+
auto dx = pn / dpn;
69+
x -= dx;
70+
if (std::abs(dx) < convergence_eps) {
71+
break;
72+
}
73+
}
74+
75+
auto pn = LegendrePolynomial(n, x);
76+
auto dpn = LegendreDerivative(n, x, pn);
77+
auto w = 2.0 / ((1.0 - x * x) * dpn * dpn);
78+
double s = 0.5 * (x + 1.0);
79+
double ws = 0.5 * w;
80+
nodes_weights.emplace_back(s * s, 2.0 * s * ws);
81+
}
82+
83+
std::sort(nodes_weights.begin(), nodes_weights.end(),
84+
[](auto const &l, auto const &r) { return l.first < r.first; });
85+
for (std::size_t i = 0; i < n; ++i) {
86+
rule.nodes[i] = nodes_weights[i].first;
87+
rule.weights[i] = nodes_weights[i].second;
88+
}
89+
return rule;
90+
}
91+
92+
template <std::size_t Points>
93+
inline EndpointQuadratureRule<Points> MakeEndpointQuadrature(double convergence_eps) {
94+
return MakeEndpointQuadrature<Points>(Points, convergence_eps);
95+
}
96+
97+
} // namespace xgboost::interpretability::detail
98+
99+
#endif // XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_

0 commit comments

Comments
 (0)