fix(tree): solve the TreeSHAP Vandermonde systems exactly at every depth #547
Open
42logos wants to merge 3 commits into
…TreeSHAP-IQ

The polynomial TreeSHAP machinery computed inv(vander(points).T) @ rhs at four call sites (TreeSHAPIQ N/N_cii/N_id matrices and LinearTreeSHAP's N_v2). The interpolation nodes are the first i entries of a depth-sized Chebyshev grid, so the Vandermonde systems are ill-conditioned even for moderate sizes: precision degrades silently from roughly size 30 (residuals > 1e-6 around 29, > 1e-3 around 37, O(1) around 45 on a depth-grid basis) and the matrix becomes exactly singular for very deep trees, which crashed with an unexplained LinAlgError (observed for trees of depth ~55-60).

This change centralises the solve in tree/_numerics.solve_vandermonde():
- np.linalg.solve instead of forming an explicit inverse,
- a RuntimeWarning when the condition number exceeds 1e12 (precision loss),
- a least-squares fallback instead of a hard crash when the system is singular, with a warning that values may be inaccurate.

Deep trees now explain (approximately) instead of crashing, and users are told when the exactness contract can no longer be honoured.

Minimal reproduction: fit DecisionTreeRegressor(max_depth=60) on noise and construct LinearTreeSHAP(tree). Previously this raised LinAlgError; now it emits a warning.
Pull request overview
This PR hardens the polynomial interpolation (Vandermonde) solves used by TreeSHAPIQ and LinearTreeSHAP to avoid silent numerical corruption and prevent LinAlgError crashes on deep trees, adding coalesced diagnostics warnings and unit test coverage.
Changes:
- Introduces `src/shapiq/tree/_numerics.py` to centralize guarded Vandermonde solving with grid certification + diagnostics aggregation.
- Switches `TreeSHAPIQ` and `LinearTreeSHAP` N-matrix construction to use `solve_vandermonde(...)` and emit one summary `RuntimeWarning` per matrix.
- Adds unit tests covering fast-path equivalence, warning behavior, least-squares fallback, and an end-to-end deep-tree regression; updates `CHANGELOG.md`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/shapiq/tree/_numerics.py` | New numerics helper for certified/checked Vandermonde solves plus coalesced warning emission. |
| `src/shapiq/tree/treeshapiq.py` | Replaces explicit inverses with `solve_vandermonde` and coalesced diagnostics emission for TreeSHAPIQ N-matrices. |
| `src/shapiq/tree/linear/explainer.py` | Replaces explicit inverse in `get_N_v2` with guarded solve + coalesced diagnostics emission. |
| `tests/shapiq/tests_unit/tests_explainer/tests_tree_explainer/test_tree_numerics.py` | New tests validating numerical behavior, warnings, and deep-tree regression coverage. |
| `CHANGELOG.md` | Documents the behavior change (new warnings, least-squares fallback) under Unreleased. |
Replaces the guarded float64 solves with exact ones, removing all warning and least-squares fallback paths:
- O(n^2) Bjorck-Pereyra dual recursion in scaled-integer fixed-point arithmetic (standard-library int), with a convergence certificate: each system is solved at increasing precision until two independent computations agree bitwise, so the result is the float64 rounding of the exact rational solution at any depth (a depth-100 grid's full prefix workload takes ~0.3 s). Nodes that collapse onto the same scaled integer at the first precision rung climb the ladder instead of being misreported as coincident.
- The previous inversion drifted at the ~1e-7 level from interpolation degree ~20, returned silently wrong values from ~32, and crashed (LinAlgError) at ~60+; the exact solves carry no conditioning error at any degree.
- The one remaining limit is representational: the monomial-basis N entries grow exponentially with the interpolation degree and the downstream float64 pipeline cancels them, so beyond a measured magnitude bound (degree ~29 on the default grids, pinned by a boundary test) construction raises an explanatory RepresentationLimitError instead of returning silently wrong values; for order-1 Shapley values TreeExplainer re-routes affected trees to TreeSHAPIQ when its feature-bounded degree still fits.
- Tests: bitwise agreement with an exact rational-elimination oracle (including the previously rank-deficient sizes 28-45 and a clustered custom grid), the exact 29/30 accept/refuse boundary, deep-chain completeness at depths 20/24/28, refusal at 35/60 on both the LinearTreeSHAP and TreeSHAPIQ paths, re-routing for SV and order-1 SII, precision-ladder climbing for nodes finer than the first rung's resolution, and cache isolation.
Fixes #545.
What was wrong
The polynomial TreeSHAP machinery (`TreeSHAPIQ`, `LinearTreeSHAP`) builds its interpolation N matrices by explicitly inverting Vandermonde systems, `inv(np.vander(D[:i]).T)`, over prefixes of the interpolation grid. These systems are severely ill-conditioned in double precision: the conditioning is non-monotonic in the prefix length and peaks at interior prefixes ($i \approx n/2$). On the default Chebyshev grids the explicit inverse drifts at the $\sim 10^{-7}$ level from interpolation degree ~20, returns silently wrong values from degree ~32 (rank-deficient to machine precision), and crashes with an unexplained `LinAlgError` near degree 60, all without a single warning.

Fix: solve the systems exactly
The ill-conditioning is purely a floating-point artifact: the interpolation nodes are distinct, so the systems are exactly solvable over the rationals. The new `solve_vandermonde` therefore returns the float64 rounding of the exact solution at every depth, using only the standard library:
- … `int`, scale … `ValueError`.
- … `(grid, rhs)`: every tree of equal depth in an ensemble issues identical solves, so a forest pays one tree's worth of work.
- … `build_n_matrix` helper replaces the four duplicated construction loops and checks each row against the representation limit as soon as it is computed.

This removes all conditioning error from the solves: up to the representation limit below, the returned coefficients are the correctly rounded exact solutions (values for degrees above ~20 may therefore shift at the $\sim 10^{-7}$ level relative to previous releases).
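To make the "exactly solvable over the rationals" argument concrete, here is a minimal sketch of an exact solve of `vander(grid).T @ x = rhs` with memoization on `(grid, rhs)`. It is not the PR's implementation: it swaps in plain Gauss-Jordan elimination over `fractions.Fraction` (closer in spirit to the rational-elimination oracle mentioned in the tests) instead of the Bjorck-Pereyra recursion in scaled-integer arithmetic, and the function name `solve_transposed_vandermonde_exact` is hypothetical.

```python
from fractions import Fraction
from functools import lru_cache


@lru_cache(maxsize=None)  # identical (grid, rhs) pairs are solved only once
def solve_transposed_vandermonde_exact(grid, rhs):
    """Solve vander(grid).T @ x = rhs exactly, then round each entry to float64.

    Hypothetical helper for illustration. `grid` and `rhs` must be tuples
    (hashable) of equal length; floats convert to Fractions without error.
    """
    n = len(grid)
    nodes = [Fraction(g) for g in grid]
    if len(set(nodes)) != n:
        raise ValueError("interpolation nodes must be distinct")

    # Row k holds nodes[j] ** (n - 1 - k): decreasing powers, as in np.vander.
    aug = [
        [nodes[j] ** (n - 1 - k) for j in range(n)] + [Fraction(rhs[k])]
        for k in range(n)
    ]

    # Gauss-Jordan elimination in exact rational arithmetic.
    for col in range(n):
        pivot = next(r for r in range(col, n) if aug[r][col] != 0)
        aug[col], aug[pivot] = aug[pivot], aug[col]
        inv_pivot = 1 / aug[col][col]
        aug[col] = [v * inv_pivot for v in aug[col]]
        for r in range(n):
            if r != col and aug[r][col] != 0:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]

    # float(Fraction) is a single correctly rounded conversion per entry.
    return tuple(float(aug[k][n]) for k in range(n))
```

Rational elimination costs far more than an O(n^2) recursion, so this form is better suited to checking results than to production use. The cache illustrates why trees of equal depth in an ensemble share one solve.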
The remaining limit is representational — and is now enforced honestly
The exact coefficients grow exponentially with the interpolation degree (e.g. $\max|N| \approx 10^{13}$ at degree 36), and the downstream float64 pipeline consumes them in inner products whose cancellation error tracks $\max|N| \cdot 10^{-13}$ (measured on chain trees of depth 20–40, within one order of magnitude). Beyond an empirically calibrated bound (`max|N| > 3e10`, expected loss in the 0.3–3 % range) construction raises `RepresentationLimitError` (a `ValueError`) instead of returning silently wrong values. On the default grids the accept/refuse boundary is exactly degree 29/30 for `LinearTreeSHAP` and 25/26 for `TreeSHAPIQ` (whose identity N matrix, built for every index, saturates first), both pinned by tests.

`TreeExplainer` handles the limit transparently where possible: `LinearTreeSHAP`'s degree is the full tree depth, while `TreeSHAPIQ`'s is `min(depth, features in the tree)`, so when only the former trips the limit (deep trees over few features), order-1 explanations are re-routed to `TreeSHAPIQ`, which computes the same Shapley values at a feasible degree. When both degrees exceed the limit (e.g. a depth-30 tree that actually uses 30+ features), the error propagates: the user gets an explanatory exception instead of the silently wrong numbers previous releases returned.

Curing the representation limit itself requires rewriting the polynomial algebra in a value-space (barycentric) basis, which is out of scope here and tracked as follow-up work.
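The gate and the re-routing rule described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `check_row` and `choose_order1_backend` are hypothetical names, and the routing is simplified to compare degrees against the boundaries quoted for the default grids, whereas the description says the real gate is the magnitude check applied while the N matrix is built.

```python
class RepresentationLimitError(ValueError):
    """Exact coefficients are too large for the downstream float64 pipeline."""


MAX_ABS_N = 3e10            # magnitude bound quoted in the description
LINEAR_MAX_DEGREE = 29      # quoted LinearTreeSHAP boundary (default grids)
TREESHAPIQ_MAX_DEGREE = 25  # quoted TreeSHAPIQ boundary (default grids)


def check_row(row, degree):
    """Hypothetical per-row gate: refuse a row whose peak magnitude is too large."""
    peak = max(abs(v) for v in row)
    if peak > MAX_ABS_N:
        raise RepresentationLimitError(
            f"max|N| = {peak:.3g} at degree {degree} exceeds {MAX_ABS_N:.0e}"
        )
    return row


def choose_order1_backend(depth, n_features_used):
    """Hypothetical routing for order-1 explanations (simplified to degrees)."""
    linear_degree = depth                   # LinearTreeSHAP: full tree depth
    iq_degree = min(depth, n_features_used)  # TreeSHAPIQ: feature-bounded
    if linear_degree <= LINEAR_MAX_DEGREE:
        return "LinearTreeSHAP"
    if iq_degree <= TREESHAPIQ_MAX_DEGREE:
        return "TreeSHAPIQ"  # same Shapley values at a feasible degree
    raise RepresentationLimitError(
        f"degrees {linear_degree} and {iq_degree} both exceed the limit"
    )
```

For example, a depth-40 tree that uses only 10 features would be routed to `TreeSHAPIQ` under this sketch, while a depth-30 tree over 30 features would raise.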
How this resolves the symptoms reported in #545
[Symptom/resolution table flattened in extraction; surviving cells: `RepresentationLimitError` (with the re-route above) | `LinAlgError` crash | `RepresentationLimitError` | `RepresentationLimitError`; nothing remains in the silent-failure class]

Note that the issue proposed a weaker fix (certify the grid, degrade to least squares, emit a coalesced `RuntimeWarning`). This PR deliberately goes further, exact-or-raise instead of warn-and-approximate: within the representable range the values carry no conditioning error at all, and beyond it no approximate values are returned.

Performance
Measured on the N-matrix build for `chebpts2` grids (min of 7 runs):

[Timing table not recoverable from the extraction; only the column label `inv` survives]

Single trees pay a one-time millisecond-scale cost; for ensembles the memoized path is slightly faster than the old inverse, so a 100-tree forest builds its N matrices in about the same time as before (~34 ms vs ~38 ms at depth 12).
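The "min of 7 runs" methodology can be reproduced along these lines. This is a minimal sketch with a hypothetical helper `best_of`; the workload passed to it is up to the reader (for instance, a callable that builds the N matrices for a chosen depth), and nothing here reproduces the PR's actual benchmark script.

```python
import timeit


def best_of(func, repeat=7):
    """Return the minimum wall-clock time in seconds over `repeat` single calls.

    Taking the minimum rather than the mean reduces the influence of
    scheduler noise and one-off interruptions on the reported figure.
    """
    return min(timeit.repeat(func, repeat=repeat, number=1))
```

When timing a memoized solver this way, clear its cache inside the callable if the cold path is the quantity of interest; otherwise every run after the first measures the cached path.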
How has this been tested
`test_tree_numerics.py` (34 tests, all passing; the full `tests_tree_explainer` suite passes with 253 passed / 1 skipped):
- … `inv(vander)`-based construction is recomputed inline at grid size 32 with the library's own right-hand side and shown to be off by an absolute error of ~2e2 (fatal after downstream cancellation) where the new solver is bitwise exact; and an end-to-end sklearn scenario (a one-hot-style fit forcing a depth-39 chain over 39 features) must raise `RepresentationLimitError` where previous releases silently returned values with a completeness error of ~2e5. Both tests fail on the pre-fix code by construction.
- … `LinearTreeSHAP`); degree 25 constructs and 26 is refused (`TreeSHAPIQ`), so the figures quoted in the error message and docs cannot silently decouple from behavior.
- … `RepresentationLimitError`) at depths 35/60, both on a deterministic chain-tree generator that cannot silently skip. The gate is exercised on the `LinearTreeSHAP` path and inside `TreeSHAPIQ`'s own N matrices (higher-order indices).
- … `LinearTreeSHAP` but explained exactly through `TreeExplainer` (falls back to `TreeSHAPIQ`, completeness to … `index="SV"` and order-1 `"SII"`.

Behavioral notes (also in the CHANGELOG under "Changed")
- … `LinAlgError`; both now raise `RepresentationLimitError` at construction (with the `TreeExplainer` re-routing above).