diff --git a/demo/guide-python/quadratureshap_rapids_benchmark.py b/demo/guide-python/quadratureshap_rapids_benchmark.py new file mode 100644 index 000000000000..82b5aceed033 --- /dev/null +++ b/demo/guide-python/quadratureshap_rapids_benchmark.py @@ -0,0 +1,605 @@ +"""RAPIDS-style SHAP benchmark for TreeSHAP and QuadratureSHAP. + +This benchmark keeps the basic structure of the RAPIDS GPUTreeShap benchmark while +benchmarking four explanation paths from the current XGBoost worktree: + +- CPU TreeSHAP +- CPU QuadratureTreeSHAP +- GPU TreeSHAP +- GPU QuadratureTreeSHAP + +It supports additive SHAP values and SHAP interactions and emits both a model-metadata table +and a timing table. +""" + +# pylint: disable=missing-class-docstring,missing-function-docstring,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments,too-many-locals,broad-exception-caught,no-member + +from __future__ import annotations + +import argparse +import gc +import json +import multiprocessing as mp +import statistics +import time +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +import numpy as np +import pandas as pd +import xgboost as xgb +from sklearn import datasets + + +@dataclass(frozen=True) +class TestDataset: + name: str + objective: str + X: object + y: np.ndarray + + def set_params(self, params: dict[str, object]) -> dict[str, object]: + params["objective"] = self.objective + if self.objective == "multi:softmax": + params["num_class"] = int(np.max(self.y) + 1) + return params + + def train_dmatrix(self) -> xgb.DMatrix: + return xgb.QuantileDMatrix(self.X, self.y, enable_categorical=True) + + def test_input(self, num_rows: int, seed: int) -> object: + rs = np.random.RandomState(seed) + row_idx = rs.randint(0, self.X.shape[0], size=num_rows) + if hasattr(self.X, "iloc"): + return self.X.iloc[row_idx, :] + return self.X[row_idx, :] + + def test_dmatrix(self, num_rows: int, seed: int) -> xgb.DMatrix: + return xgb.DMatrix(self.test_input(num_rows, seed), enable_categorical=True) + + +@dataclass(frozen=True) +class ModelSpec: + suffix: str + num_rounds: int + max_depth: int + grow_policy: str | None = None + max_leaves: int | None = None + + def training_params(self) -> dict[str, object]: + params: dict[str, object] = { + "tree_method": "hist", + "device": "cuda", + "eta": 0.01, + "max_depth": self.max_depth, + } + if self.grow_policy is not None: + params["grow_policy"] = self.grow_policy + if self.max_leaves is not None: + params["max_leaves"] = self.max_leaves + return params + + +MODEL_SPECS = { + "small": ModelSpec("small", 10, 6), + "large": ModelSpec("large", 1000, 16), + # "sparse" here means a LightGBM-style leaf-wise tree shape rather than sparse input storage. + "sparse": ModelSpec("sparse", 100, 0, grow_policy="lossguide", max_leaves=512), +} + + +@dataclass(frozen=True) +class Model: + name: str + dataset: TestDataset + spec: ModelSpec + booster: xgb.Booster + trees: int + leaves: int + average_depth: float + mean_max_depth: float + max_max_depth: int + mean_nodes: float + mean_leaves: float + + +@lru_cache(maxsize=1) +def fetch_adult() -> tuple[object, np.ndarray]: + x, y = datasets.fetch_openml("adult", return_X_y=True) + y_binary = np.array([y_i != "<=50K" for y_i in y]) + return x, y_binary + + +@lru_cache(maxsize=1) +def fetch_fashion_mnist() -> tuple[object, np.ndarray]: + x, y = datasets.fetch_openml("Fashion-MNIST", return_X_y=True) + return x, y.astype(np.int64) + + +@lru_cache(maxsize=1) +def get_test_datasets() -> tuple[TestDataset, ...]: + cov_x, cov_y = datasets.fetch_covtype(return_X_y=True) + cal_x, cal_y = datasets.fetch_california_housing(return_X_y=True) + return ( + TestDataset("adult", "binary:logistic", *fetch_adult()), + TestDataset("covtype", "multi:softmax", cov_x, cov_y.astype(np.int64)), + TestDataset( + "cal_housing", + "reg:squarederror", + cal_x.astype(np.float32), + cal_y.astype(np.float32), + ), + TestDataset("fashion_mnist", "multi:softmax", *fetch_fashion_mnist()), + ) + + +def train_model(dataset: TestDataset, spec: ModelSpec) -> xgb.Booster: + dtrain = dataset.train_dmatrix() + params = spec.training_params() + params = dataset.set_params(params) + return xgb.train( + params, + dtrain, + spec.num_rounds, + evals=[(dtrain, "train")], + verbose_eval=False, + ) + + +def tree_stats(model: xgb.Booster) -> dict[str, float]: + dump = model.get_dump(dump_format="json", with_stats=True) + + def walk(node: dict, depth: int = 0) -> tuple[int, int, int]: + children = node.get("children", []) + if not children: + return depth, 1, 1 + max_depth = depth + node_count = 1 + leaf_count = 0 + for child in children: + child_depth, child_nodes, child_leaves = walk(child, depth + 1) + max_depth = max(max_depth, child_depth) + node_count += child_nodes + leaf_count += child_leaves + return max_depth, node_count, leaf_count + + max_depths: list[int] = [] + node_counts: list[int] = [] + leaf_counts: list[int] = [] + for tree_json in dump: + tree = json.loads(tree_json) + max_depth, nodes, leaves = walk(tree) + max_depths.append(max_depth) + node_counts.append(nodes) + leaf_counts.append(leaves) + + return { + "trees": len(dump), + "leaves": int(sum(leaf_counts)), + "average_depth": float(statistics.mean(max_depths)), + "mean_max_depth": float(statistics.mean(max_depths)), + "max_max_depth": int(max(max_depths)), + "mean_nodes": float(statistics.mean(node_counts)), + "mean_leaves": float(statistics.mean(leaf_counts)), + } + + +def get_models(model_filter: str) -> list[Model]: + models: list[Model] = [] + for dataset in get_test_datasets(): + for spec in MODEL_SPECS.values(): + model_name = f"{dataset.name}-{spec.suffix}" + if model_filter not in {"all", spec.suffix} and model_filter != model_name: + continue + print(f"Training {model_name}") + booster = train_model(dataset, spec) + stats = tree_stats(booster) + models.append( + Model( + name=model_name, + dataset=dataset, + spec=spec, + booster=booster, + trees=int(stats["trees"]), + leaves=int(stats["leaves"]), + average_depth=float(stats["average_depth"]), + mean_max_depth=float(stats["mean_max_depth"]), + max_max_depth=int(stats["max_max_depth"]), + mean_nodes=float(stats["mean_nodes"]), + mean_leaves=float(stats["mean_leaves"]), + ) + ) + return models + + +def predict_with_algorithm( + booster: xgb.Booster, + dtest: xgb.DMatrix, + device: str, + algorithm: str, + interactions: bool, +) -> np.ndarray: + params: dict[str, object] = {"device": device} + if algorithm == "quadratureshap": + params["shap_algorithm"] = "quadratureshap" + params["quadratureshap_points"] = 8 + else: + params["shap_algorithm"] = "treeshap" + booster.set_param(params) + if interactions: + return np.asarray(booster.predict(dtest, pred_interactions=True)) + return np.asarray(booster.predict(dtest, pred_contribs=True)) + + +def _benchmark_case_worker( + queue: mp.Queue, + booster_raw: bytes, + x_test: object, + device: str, + algorithm: str, + interactions: bool, + niter: int, + margin: np.ndarray | None, +) -> None: + try: + booster = xgb.Booster() + booster.load_model(bytearray(booster_raw)) + dtest = xgb.DMatrix(x_test, enable_categorical=True) + pred = predict_with_algorithm(booster, dtest, device, algorithm, interactions) + if interactions: + additive = predict_with_algorithm( + booster, dtest, device, algorithm, interactions=False + ) + row_sums = np.sum(pred, axis=pred.ndim - 1) + metrics = { + "max_row_sum_err": float(np.max(np.abs(row_sums - additive))), + "mean_row_sum_err": float(np.mean(np.abs(row_sums - additive))), + "max_asymmetry": float( + np.max(np.abs(pred - np.swapaxes(pred, -1, -2))) + ), + } + else: + assert margin is not None + summed = np.sum(pred, axis=pred.ndim - 1) + metrics = { + "max_additivity_err": float(np.max(np.abs(summed - margin))), + "mean_additivity_err": float(np.mean(np.abs(summed - margin))), + } + + samples = [] + for _ in range(niter): + t0 = time.perf_counter() + predict_with_algorithm(booster, dtest, device, algorithm, interactions) + samples.append(time.perf_counter() - t0) + queue.put( + { + "mean_time_s": float(np.mean(samples)), + "std_time_s": float(np.std(samples)), + "error": None, + **metrics, + } + ) + except Exception as err: # noqa: BLE001 + queue.put( + { + "mean_time_s": None, + "std_time_s": None, + "max_additivity_err": None, + "mean_additivity_err": None, + "max_row_sum_err": None, + "mean_row_sum_err": None, + "max_asymmetry": None, + "error": str(err).splitlines()[0], + } + ) + + +def run_case_with_timeout( + booster: xgb.Booster, + x_test: object, + device: str, + algorithm: str, + interactions: bool, + niter: int, + margin: np.ndarray | None, + timeout_seconds: float | None, +) -> dict[str, object]: + if timeout_seconds is None: + queue: mp.Queue = mp.Queue() + _benchmark_case_worker( + queue, + bytes(booster.save_raw()), + x_test, + device, + algorithm, + interactions, + niter, + margin, + ) + result = queue.get() + queue.close() + return result + + ctx = mp.get_context("spawn") + queue = ctx.Queue() + proc = ctx.Process( + target=_benchmark_case_worker, + args=( + queue, + bytes(booster.save_raw()), + x_test, + device, + algorithm, + interactions, + niter, + margin, + ), + ) + proc.start() + proc.join(timeout_seconds) + if proc.is_alive(): + proc.terminate() + proc.join() + queue.close() + return { + "mean_time_s": None, + "std_time_s": None, + "max_additivity_err": None, + "mean_additivity_err": None, + "max_row_sum_err": None, + "mean_row_sum_err": None, + "max_asymmetry": None, + "error": f"DNF: exceeded {timeout_seconds:g}s", + } + if queue.empty(): + queue.close() + return { + "mean_time_s": None, + "std_time_s": None, + "max_additivity_err": None, + "mean_additivity_err": None, + "max_row_sum_err": None, + "mean_row_sum_err": None, + "max_asymmetry": None, + "error": "DNF: worker exited without result", + } + result = queue.get() + queue.close() + return result + + +def check_accuracy( + booster: xgb.Booster, + dtest: xgb.DMatrix, + device: str, + algorithm: str, + pred: np.ndarray, + margin: np.ndarray, + interactions: bool, +) -> dict[str, float]: + if interactions: + additive = predict_with_algorithm( + booster, dtest, device, algorithm, interactions=False + ) + row_sums = np.sum(pred, axis=pred.ndim - 1) + return { + "max_row_sum_err": float(np.max(np.abs(row_sums - additive))), + "mean_row_sum_err": float(np.mean(np.abs(row_sums - additive))), + "max_asymmetry": float(np.max(np.abs(pred - np.swapaxes(pred, -1, -2)))), + } + + summed = np.sum(pred, axis=pred.ndim - 1) + return { + "max_additivity_err": float(np.max(np.abs(summed - margin))), + "mean_additivity_err": float(np.mean(np.abs(summed - margin))), + } + + +def benchmark_model( + model: Model, + x_test: object, + dtest: xgb.DMatrix, + niter: int, + interactions: bool, + timeout_seconds: float | None, +) -> tuple[dict[str, object], list[dict[str, object]]]: + margin = model.booster.predict(dtest, output_margin=True) + details: list[dict[str, object]] = [] + result_row = { + "model": model.name, + "test_rows": dtest.num_row(), + "TreeSHAP": None, + "QuadratureTreeSHAP": None, + "GPUTreeShap": None, + "QuadratureTreeSHAP (GPU)": None, + "QuadratureTreeSHAP Speedup": None, + "QuadratureTreeSHAP (GPU) Speedup": None, + } + + for algorithm in ["treeshap", "quadratureshap"]: + for device in ["cpu", "cuda"]: + result = run_case_with_timeout( + model.booster, + x_test, + device, + algorithm, + interactions, + niter, + margin if not interactions else None, + timeout_seconds, + ) + details.append( + { + "model": model.name, + "algorithm": algorithm, + "device": device, + **result, + } + ) + if result["mean_time_s"] is not None: + if algorithm == "treeshap" and device == "cpu": + result_row["TreeSHAP"] = float(result["mean_time_s"]) + elif algorithm == "quadratureshap" and device == "cpu": + result_row["QuadratureTreeSHAP"] = float(result["mean_time_s"]) + elif algorithm == "treeshap" and device == "cuda": + result_row["GPUTreeShap"] = float(result["mean_time_s"]) + elif algorithm == "quadratureshap" and device == "cuda": + result_row["QuadratureTreeSHAP (GPU)"] = float( + result["mean_time_s"] + ) + gc.collect() + + if ( + result_row["TreeSHAP"] is not None + and result_row["QuadratureTreeSHAP"] is not None + ): + result_row["QuadratureTreeSHAP Speedup"] = ( + result_row["TreeSHAP"] / result_row["QuadratureTreeSHAP"] + ) + if ( + result_row["GPUTreeShap"] is not None + and result_row["QuadratureTreeSHAP (GPU)"] is not None + ): + result_row["QuadratureTreeSHAP (GPU) Speedup"] = ( + result_row["GPUTreeShap"] / result_row["QuadratureTreeSHAP (GPU)"] + ) + return result_row, details + + +def markdown_table(df: pd.DataFrame, float_fmt: str = ".6f") -> str: + headers = [str(c) for c in df.columns] + rows = [ + "| " + " | ".join(headers) + " |", + "| " + " | ".join(["---"] * len(headers)) + " |", + ] + for _, row in df.iterrows(): + formatted = [] + for value in row: + if value is None or (isinstance(value, float) and pd.isna(value)): + formatted.append("NA") + elif isinstance(value, float): + formatted.append(format(value, float_fmt)) + else: + formatted.append(str(value)) + rows.append("| " + " | ".join(formatted) + " |") + return "\n".join(rows) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="RAPIDS-style benchmark adapted for XGBoost TreeSHAP and QuadratureSHAP." + ) + parser.add_argument("--output", type=Path, required=True, help="JSON summary path") + parser.add_argument( + "--out-models", type=Path, default=None, help="CSV path for model table" + ) + parser.add_argument( + "--out-results", type=Path, default=None, help="CSV path for timing table" + ) + parser.add_argument( + "--out-markdown", type=Path, default=None, help="Markdown table path" + ) + parser.add_argument("--nrows", type=int, default=1000) + parser.add_argument("--niter", type=int, default=3) + parser.add_argument("--seed", type=int, default=432) + parser.add_argument( + "--case-timeout-seconds", + type=float, + default=None, + help="Optional per algorithm/device/model timeout. Timed-out cases are marked DNF.", + ) + parser.add_argument( + "--model", + type=str, + default="all", + help="Model filter: all, small, large, sparse, or a specific dataset-size name", + ) + parser.add_argument("--interactions", action="store_true") + args = parser.parse_args() + + models = get_models(args.model) + model_rows = [ + { + "model": model.name, + "num_rounds": model.spec.num_rounds, + "requested_max_depth": model.spec.max_depth, + "grow_policy": model.spec.grow_policy or "depthwise", + "max_leaves": model.spec.max_leaves, + "num_trees": model.trees, + "num_leaves": model.leaves, + "average_depth": model.average_depth, + "mean_max_depth": model.mean_max_depth, + "max_max_depth": model.max_max_depth, + "mean_nodes": model.mean_nodes, + "mean_leaves_per_tree": model.mean_leaves, + } + for model in models + ] + results_rows: list[dict[str, object]] = [] + details_rows: list[dict[str, object]] = [] + for model in models: + x_test = model.dataset.test_input(args.nrows, args.seed) + dtest = xgb.DMatrix(x_test, enable_categorical=True) + result_row, details = benchmark_model( + model, + x_test, + dtest, + args.niter, + args.interactions, + args.case_timeout_seconds, + ) + results_rows.append(result_row) + details_rows.extend(details) + print( + pd.DataFrame(results_rows).to_string( + index=False, float_format=lambda x: f"{x:.6f}" + ) + ) + + models_df = pd.DataFrame(model_rows) + results_df = pd.DataFrame(results_rows) + payload = { + "nrows": args.nrows, + "niter": args.niter, + "interactions": args.interactions, + "model_filter": args.model, + "model_specs": { + name: { + "num_rounds": spec.num_rounds, + "max_depth": spec.max_depth, + "grow_policy": spec.grow_policy, + "max_leaves": spec.max_leaves, + } + for name, spec in MODEL_SPECS.items() + }, + "models_table": models_df.to_dict(orient="records"), + "results_table": results_df.to_dict(orient="records"), + "details": details_rows, + } + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(json.dumps(payload, indent=2) + "\n") + if args.out_models is not None: + args.out_models.parent.mkdir(parents=True, exist_ok=True) + models_df.to_csv(args.out_models, index=False) + if args.out_results is not None: + args.out_results.parent.mkdir(parents=True, exist_ok=True) + results_df.to_csv(args.out_results, index=False) + if args.out_markdown is not None: + args.out_markdown.parent.mkdir(parents=True, exist_ok=True) + args.out_markdown.write_text( + "## Models\n\n" + + markdown_table(models_df, ".3f") + + "\n\n## Results\n\n" + + markdown_table(results_df, ".6f") + + "\n" + ) + + print("Models:") + print(models_df.to_string(index=False)) + print("Results:") + print(results_df.to_string(index=False, float_format=lambda x: f"{x:.6f}")) + + +if __name__ == "__main__": + main() diff --git a/doc/parameter.rst b/doc/parameter.rst index 46891cbb9736..e1f4e5f26279 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -189,6 +189,19 @@ Parameters for Tree Booster - ``approx``: Approximate greedy algorithm using quantile sketch and gradient histogram. - ``hist``: Faster histogram optimized approximate greedy algorithm. +* ``shap_algorithm`` string [default= ``treeshap``] + + - CPU algorithm used for ``pred_contribs`` with tree boosters. + - Choices: ``treeshap``, ``quadratureshap``. + + - ``treeshap``: Existing exact TreeSHAP implementation. + - ``quadratureshap``: Quadrature plus telescoping SHAP implementation for CPU prediction. + +* ``quadratureshap_points`` [default= ``8``] + + - Experimental fixed quadrature size used by CPU and GPU ``quadratureshap`` variants. + - Current supported value: ``8``. + * ``scale_pos_weight`` [default=1] - Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: ``sum(negative instances) / sum(positive instances)``. See :doc:`Parameters Tuning ` for more discussion. Also, see Higgs Kaggle competition demo for examples: `R `_, `py1 `_, `py2 `_, `py3 `_. diff --git a/experiments/2026-04-21-fashion-mnist-efficiency-sweep/README.md b/experiments/2026-04-21-fashion-mnist-efficiency-sweep/README.md new file mode 100644 index 000000000000..682db19388c9 --- /dev/null +++ b/experiments/2026-04-21-fashion-mnist-efficiency-sweep/README.md @@ -0,0 +1,49 @@ +## Purpose + +Sweep requested depth for a `fashion_mnist` lossguide model and compare efficiency error for: + +- CPU TreeSHAP +- CPU QuadratureSHAP with `4` points +- CPU QuadratureSHAP with `6` points +- CPU QuadratureSHAP with `8` points +- CPU QuadratureSHAP with `16` points + +The experiment records `mean`, `p99`, and `max` efficiency error, where efficiency is checked +against the raw margin: + +`sum(phi) == predict(output_margin=True)` + +The generated result directories are local experiment outputs and are not intended to be tracked. + +## Commands + +Original `max_leaves=128` run: + +```bash +PYTHONPATH=/home/nfs/rorym/xgboost-wt/shapley-value-algorithms/python-package \ +LD_LIBRARY_PATH=/home/nfs/rorym/xgboost-wt/shapley-value-algorithms/lib:${LD_LIBRARY_PATH} \ +/home/nfs/rorym/anaconda3/bin/conda run -n xgboost python \ + /home/nfs/rorym/xgboost-wt/shapley-value-algorithms/experiments/2026-04-21-fashion-mnist-efficiency-sweep/benchmark_fashion_mnist_efficiency.py \ + --out-dir /home/nfs/rorym/xgboost-wt/shapley-value-algorithms/experiments/2026-04-21-fashion-mnist-efficiency-sweep/results \ + --points 4 8 16 +``` + +Follow-up `max_leaves=1024` run: + +```bash +PYTHONPATH=/home/nfs/rorym/xgboost-wt/shapley-value-algorithms/python-package \ +LD_LIBRARY_PATH=/home/nfs/rorym/xgboost-wt/shapley-value-algorithms/lib:${LD_LIBRARY_PATH} \ +/home/nfs/rorym/anaconda3/bin/conda run -n xgboost python \ + /home/nfs/rorym/xgboost-wt/shapley-value-algorithms/experiments/2026-04-21-fashion-mnist-efficiency-sweep/benchmark_fashion_mnist_efficiency.py \ + --out-dir /home/nfs/rorym/xgboost-wt/shapley-value-algorithms/experiments/2026-04-21-fashion-mnist-efficiency-sweep/results-maxleaves1024 \ + --max-leaves 1024 --depths 4 8 12 16 24 32 48 64 --points 4 6 8 16 +``` + +## Generated Outputs + +- `results.json` +- `results.csv` +- `summary.md` +- `efficiency_mean.png` +- `efficiency_p99.png` +- `efficiency_max.png` diff --git a/experiments/2026-04-21-fashion-mnist-efficiency-sweep/benchmark_fashion_mnist_efficiency.py b/experiments/2026-04-21-fashion-mnist-efficiency-sweep/benchmark_fashion_mnist_efficiency.py new file mode 100644 index 000000000000..2106d4298f75 --- /dev/null +++ b/experiments/2026-04-21-fashion-mnist-efficiency-sweep/benchmark_fashion_mnist_efficiency.py @@ -0,0 +1,307 @@ +"""Run Fashion-MNIST SHAP efficiency-error sweeps.""" + +from __future__ import annotations + +# pylint: disable=missing-function-docstring,too-many-locals +import argparse +import csv +import json +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import xgboost as xgb +from sklearn import datasets + +DEFAULT_DEPTHS = [4, 8, 12, 16, 24, 32, 48, 64] +DEFAULT_POINTS = [4, 6, 8, 16] +DEFAULT_SEED = 20260421 +DEFAULT_TEST_ROWS = 512 +DEFAULT_THREADS = 35 +DEFAULT_ROUNDS = 100 +DEFAULT_MAX_LEAVES = 128 + + +def fetch_fashion_mnist() -> tuple[object, np.ndarray]: + x, y = datasets.fetch_openml("Fashion-MNIST", return_X_y=True) + return x, y.astype(np.int64) + + +def tree_stats(model: xgb.Booster) -> dict[str, float]: + dump = model.get_dump(dump_format="json", with_stats=True) + + def walk(node: dict, depth: int = 0) -> tuple[int, int, int]: + children = node.get("children", []) + if not children: + return depth, 1, 1 + max_depth = depth + node_count = 1 + leaf_count = 0 + for child in children: + child_depth, child_nodes, child_leaves = walk(child, depth + 1) + max_depth = max(max_depth, child_depth) + node_count += child_nodes + leaf_count += child_leaves + return max_depth, node_count, leaf_count + + max_depths: list[int] = [] + node_counts: list[int] = [] + leaf_counts: list[int] = [] + for tree_json in dump: + tree_json = re.sub(r"\bnan\b", "0", tree_json) + tree_json = re.sub(r"\binf\b", "0", tree_json) + tree = json.loads(tree_json) + max_depth, nodes, leaves = walk(tree) + max_depths.append(max_depth) + node_counts.append(nodes) + leaf_counts.append(leaves) + + return { + "num_trees": len(dump), + "mean_max_depth": float(np.mean(max_depths)), + "max_max_depth": float(np.max(max_depths)), + "mean_nodes": float(np.mean(node_counts)), + "mean_leaves": float(np.mean(leaf_counts)), + } + + +def train_model( + x_train: object, y_train: np.ndarray, depth: int, seed: int, max_leaves: int +) -> xgb.Booster: + dtrain = xgb.QuantileDMatrix(x_train, y_train, enable_categorical=True) + params: dict[str, object] = { + "objective": "multi:softmax", + "num_class": 10, + "tree_method": "hist", + "device": "cpu", + "grow_policy": "lossguide", + "max_leaves": max_leaves, + "max_depth": depth, + "eta": 0.01, + "seed": seed, + "nthread": DEFAULT_THREADS, + } + return xgb.train(params, dtrain, num_boost_round=DEFAULT_ROUNDS, verbose_eval=False) + + +def sample_rows( + x: object, y: np.ndarray, rows: int, seed: int +) -> tuple[object, np.ndarray]: + rs = np.random.RandomState(seed) + row_idx = rs.choice(len(y), size=rows, replace=False) + if hasattr(x, "iloc"): + return x.iloc[row_idx, :], y[row_idx] + return x[row_idx, :], y[row_idx] + + +def efficiency_metrics(pred: np.ndarray, margin: np.ndarray) -> dict[str, float]: + err = np.abs(np.sum(pred, axis=pred.ndim - 1) - margin).reshape(-1) + return { + "mean_efficiency_err": float(np.mean(err)), + "p99_efficiency_err": float(np.quantile(err, 0.99)), + "max_efficiency_err": float(np.max(err)), + } + + +def predict_contribs( + booster: xgb.Booster, + dtest: xgb.DMatrix, + algorithm: str, + quadrature_points: int | None, +) -> np.ndarray: + params: dict[str, object] = {"device": "cpu", "shap_algorithm": algorithm} + if algorithm == "quadratureshap": + assert quadrature_points is not None + params["quadratureshap_points"] = quadrature_points + booster = booster.copy() + booster.set_param(params) + return np.asarray(booster.predict(dtest, pred_contribs=True)) + + +def write_csv(path: Path, rows: list[dict[str, object]]) -> None: + fieldnames: list[str] = [] + seen: set[str] = set() + for row in rows: + for key in row.keys(): + if key not in seen: + seen.add(key) + fieldnames.append(key) + with path.open("w", newline="") as fd: + writer = csv.DictWriter(fd, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def make_plot(rows: list[dict[str, object]], metric: str, out_path: Path) -> None: + plt.figure(figsize=(7, 4.5)) + series = {} + for row in rows: + series.setdefault(row["algorithm_label"], []).append(row) + for label, vals in series.items(): + vals = sorted(vals, key=lambda r: r["requested_depth"]) + xs = [r["requested_depth"] for r in vals] + ys = [r[metric] for r in vals] + plt.plot(xs, ys, marker="o", linewidth=2, label=label) + plt.yscale("log") + plt.xlabel("Requested max_depth") + plt.ylabel(metric.replace("_", " ")) + plt.title(f"Fashion-MNIST efficiency sweep: {metric}") + plt.grid(True, which="both", alpha=0.25) + plt.legend() + plt.tight_layout() + plt.savefig(out_path, dpi=160) + plt.close() + + +def write_summary(path: Path, rows: list[dict[str, object]]) -> None: + header = ( + "| algorithm | requested_depth | mean_max_depth | max_max_depth | " + "mean_efficiency_err | p99_efficiency_err | max_efficiency_err |" + ) + lines = [ + "## Fashion-MNIST Efficiency Sweep", + "", + header, + "| --- | --- | --- | --- | --- | --- | --- |", + ] + for row in rows: + lines.append( + f"| {row['algorithm_label']} | {row['requested_depth']} | " + f"{row['mean_max_depth']:.3f} | {row['max_max_depth']:.0f} | " + f"{row['mean_efficiency_err']:.6e} | {row['p99_efficiency_err']:.6e} | " + f"{row['max_efficiency_err']:.6e} |" + ) + path.write_text("\n".join(lines) + "\n") + + +def write_outputs( + out_dir: Path, metadata: dict[str, object], rows: list[dict[str, object]] +) -> None: + (out_dir / "results.json").write_text( + json.dumps({"metadata": metadata, "rows": rows}, indent=2) + ) + write_csv(out_dir / "results.csv", rows) + write_summary(out_dir / "summary.md", rows) + make_plot(rows, "mean_efficiency_err", out_dir / "efficiency_mean.png") + make_plot(rows, "p99_efficiency_err", out_dir / "efficiency_p99.png") + make_plot(rows, "max_efficiency_err", out_dir / "efficiency_max.png") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=Path, required=True) + parser.add_argument("--depths", type=int, nargs="+", default=DEFAULT_DEPTHS) + parser.add_argument("--points", type=int, nargs="+", default=DEFAULT_POINTS) + parser.add_argument("--test-rows", type=int, default=DEFAULT_TEST_ROWS) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED) + parser.add_argument("--max-leaves", type=int, default=DEFAULT_MAX_LEAVES) + parser.add_argument("--reuse-json", type=Path, default=None) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + args.out_dir.mkdir(parents=True, exist_ok=True) + if args.reuse_json is not None: + payload = json.loads(args.reuse_json.read_text()) + metadata = payload["metadata"] + rows = payload["rows"] + completed_depths = {row["requested_depth"] for row in rows} + pending_depths = [ + depth for depth in args.depths if depth not in completed_depths + ] + if pending_depths: + x, y = fetch_fashion_mnist() + x_test, y_test = sample_rows(x, y, args.test_rows, args.seed) + dtest = xgb.DMatrix(x_test, y_test, enable_categorical=True) + for depth in pending_depths: + print(f"Training requested depth {depth}") + booster = train_model(x, y, depth, args.seed, args.max_leaves) + stats = tree_stats(booster) + margin = np.asarray(booster.predict(dtest, output_margin=True)) + + treeshap = predict_contribs(booster, dtest, "treeshap", None) + rows.append( + { + "algorithm": "treeshap", + "algorithm_label": "TreeSHAP", + "requested_depth": depth, + **stats, + **efficiency_metrics(treeshap, margin), + } + ) + + for points in args.points: + contribs = predict_contribs( + booster, dtest, "quadratureshap", points + ) + rows.append( + { + "algorithm": "quadratureshap", + "algorithm_label": f"QuadratureSHAP-{points}", + "requested_depth": depth, + "quadrature_points": points, + **stats, + **efficiency_metrics(contribs, margin), + } + ) + metadata["depths"] = sorted( + set(metadata.get("depths", [])) | set(args.depths) + ) + metadata["points"] = sorted( + set(metadata.get("points", [])) | set(args.points) + ) + metadata["max_leaves"] = args.max_leaves + write_outputs(args.out_dir, metadata, rows) + else: + x, y = fetch_fashion_mnist() + x_test, y_test = sample_rows(x, y, args.test_rows, args.seed) + dtest = xgb.DMatrix(x_test, y_test, enable_categorical=True) + + rows = [] + metadata = { + "seed": args.seed, + "test_rows": args.test_rows, + "rounds": DEFAULT_ROUNDS, + "max_leaves": args.max_leaves, + "depths": args.depths, + "points": args.points, + "threads": DEFAULT_THREADS, + } + for depth in args.depths: + print(f"Training requested depth {depth}") + booster = train_model(x, y, depth, args.seed, args.max_leaves) + stats = tree_stats(booster) + margin = np.asarray(booster.predict(dtest, output_margin=True)) + + treeshap = predict_contribs(booster, dtest, "treeshap", None) + rows.append( + { + "algorithm": "treeshap", + "algorithm_label": "TreeSHAP", + "requested_depth": depth, + **stats, + **efficiency_metrics(treeshap, margin), + } + ) + + for points in args.points: + contribs = predict_contribs(booster, dtest, "quadratureshap", points) + rows.append( + { + "algorithm": "quadratureshap", + "algorithm_label": f"QuadratureSHAP-{points}", + "requested_depth": depth, + "quadrature_points": points, + **stats, + **efficiency_metrics(contribs, margin), + } + ) + write_outputs(args.out_dir, metadata, rows) + + write_outputs(args.out_dir, metadata, rows) + + +if __name__ == "__main__": + main() diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 2d1e63133f52..b361b2a0ceae 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -62,6 +62,10 @@ struct GBTreeTrainParam : public XGBoostParameter { TreeProcessType process_type; // tree construction method TreeMethod tree_method; + // CPU SHAP implementation used for pred_contribs. + std::string shap_algorithm; + // Number of quadrature points for QuadratureSHAP variants. + std::size_t quadratureshap_points; // declare parameters DMLC_DECLARE_PARAMETER(GBTreeTrainParam) { DMLC_DECLARE_FIELD(updater_seq).describe("Tree updater sequence.").set_default(""); @@ -80,6 +84,12 @@ struct GBTreeTrainParam : public XGBoostParameter { .add_enum("exact", TreeMethod::kExact) .add_enum("hist", TreeMethod::kHist) .describe("Choice of tree construction method."); + DMLC_DECLARE_FIELD(shap_algorithm) + .set_default("treeshap") + .describe("CPU algorithm used for SHAP feature contributions."); + DMLC_DECLARE_FIELD(quadratureshap_points) + .set_default(8) + .describe("Experimental fixed quadrature size used by CPU and GPU QuadratureSHAP."); } }; diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 85947b907760..4d8418eae09d 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -6,6 +6,7 @@ #include // for size_t #include // for uint32_t, int32_t, uint64_t #include // for unique_ptr, shared_ptr +#include // for invalid_argument, out_of_range #include // for vector #include "../collective/allreduce.h" // for Allreduce @@ -745,6 +746,28 @@ class CPUPredictor : public Predictor { public: explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} + void Configure(Args const &cfg) override { + for (auto const &kv : cfg) { + if (kv.first == "shap_algorithm") { + CHECK(kv.second == "treeshap" || kv.second == "quadratureshap") + << "Unknown SHAP algorithm: " << kv.second; + shap_algorithm_ = kv.second; + } else if (kv.first == "quadratureshap_points") { + std::size_t points{0}; + try { + points = std::stoul(kv.second); + } catch (std::invalid_argument const &) { + LOG(FATAL) << "Invalid quadratureshap_points: " << kv.second; + } catch (std::out_of_range const &) { + LOG(FATAL) << "quadratureshap_points out of range: " << kv.second; + } + CHECK(points == 4 || points == 6 || points == 8 || points == 16) + << "CPU QuadratureSHAP currently supports quadrature sizes of 4, 6, 8, or 16."; + quadrature_shap_points_ = points; + } + } + } + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, gbm::GBTreeModel const &model, bst_tree_t tree_begin, bst_tree_t tree_end = 0, std::vector const *tree_weights = nullptr) const override { @@ -868,6 +891,10 @@ class CPUPredictor : public Predictor { if (approximate) { interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model, ntree_limit, tree_weights); + } else if (shap_algorithm_ == "quadratureshap" && condition == 0 && condition_feature == 0) { + interpretability::cpu_impl::QuadratureShapValues(this->ctx_, p_fmat, out_contribs, model, + ntree_limit, tree_weights, + quadrature_shap_points_); } else { interpretability::ShapValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, tree_weights, condition, condition_feature); @@ -878,9 +905,19 @@ class CPUPredictor : public Predictor { gbm::GBTreeModel const &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate) const override { - interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, - tree_weights, approximate); + if (!approximate && shap_algorithm_ == "quadratureshap") { + interpretability::cpu_impl::QuadratureShapInteractionValues(this->ctx_, p_fmat, out_contribs, + model, ntree_limit, tree_weights, + quadrature_shap_points_); + } else { + interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, + tree_weights, approximate); + } } + + private: + std::string shap_algorithm_{"treeshap"}; + std::size_t quadrature_shap_points_{8}; }; XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor") diff --git a/src/predictor/gpu_data_accessor.cuh b/src/predictor/gpu_data_accessor.cuh index 8fee7a149f08..4fcdb58c9662 100644 --- a/src/predictor/gpu_data_accessor.cuh +++ b/src/predictor/gpu_data_accessor.cuh @@ -32,14 +32,17 @@ struct SparsePageView { num_features{n_features} {} [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { - // Binary search - auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; - auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; - if (end_ptr - begin_ptr == this->NumCols()) { - // Bypass span check for dense data - return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; + auto row_begin = d_row_ptr[ridx]; + auto row_size = d_row_ptr[ridx + 1] - row_begin; + auto begin_ptr = d_data.data() + row_begin; + if (row_size == this->NumCols()) { + // Dense rows are laid out in feature order, so this is a raw-pointer lookup. + return begin_ptr[fidx].fvalue; } - common::Span::iterator previous_middle; + + // Binary search over sparse entries using raw pointers. + auto end_ptr = begin_ptr + row_size; + auto previous_middle = static_cast(nullptr); while (end_ptr != begin_ptr) { auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; if (middle == previous_middle) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index f4c0d00c6aeb..beba570fb11c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -8,6 +8,7 @@ #include // for proclaim_return_type #include // for swap #include +#include #include "../collective/allreduce.h" #include "../common/bitfield.h" @@ -669,6 +670,27 @@ class GPUPredictor : public xgboost::Predictor { } } + void Configure(Args const& cfg) override { + for (auto const& kv : cfg) { + if (kv.first == "shap_algorithm") { + CHECK(kv.second == "treeshap" || kv.second == "quadratureshap") + << "Unknown SHAP algorithm: " << kv.second; + shap_algorithm_ = kv.second; + } else if (kv.first == "quadratureshap_points") { + std::size_t points{0}; + try { + points = std::stoul(kv.second); + } catch (std::invalid_argument const&) { + LOG(FATAL) << "Invalid quadratureshap_points: " << kv.second; + } catch (std::out_of_range const&) { + LOG(FATAL) << "quadratureshap_points out of range: " << kv.second; + } + CHECK_EQ(points, 8) << "GPU QuadratureSHAP currently uses a fixed quadrature size of 8."; + quadrature_shap_points_ = points; + } + } + } + void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, const gbm::GBTreeModel& model, bst_tree_t tree_begin, bst_tree_t tree_end = 0, std::vector const* tree_weights = nullptr) const override { @@ -766,14 +788,20 @@ class GPUPredictor : public xgboost::Predictor { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, bst_tree_t tree_end, - std::vector const* tree_weights, bool approximate, int, - unsigned) const override { + std::vector const* tree_weights, bool approximate, int condition, + unsigned condition_feature) const override { xgboost_NVTX_FN_RANGE(); if (approximate) { LOG(FATAL) << "Approximated contribution is not implemented in the GPU predictor, use CPU " "instead."; } - interpretability::ShapValues(ctx_, p_fmat, out_contribs, model, tree_end, tree_weights, 0, 0); + if (shap_algorithm_ == "quadratureshap" && condition == 0 && condition_feature == 0) { + interpretability::cuda_impl::QuadratureShapValues(ctx_, p_fmat, out_contribs, model, tree_end, + tree_weights, quadrature_shap_points_); + } else { + interpretability::ShapValues(ctx_, p_fmat, out_contribs, model, tree_end, tree_weights, + condition, condition_feature); + } } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, @@ -785,8 +813,13 @@ class GPUPredictor : public xgboost::Predictor { LOG(FATAL) << "Approximated contribution is not implemented in GPU predictor, use cpu " "instead."; } - interpretability::ShapInteractionValues(ctx_, p_fmat, out_contribs, model, tree_end, - tree_weights, approximate); + if (shap_algorithm_ == "quadratureshap") { + interpretability::cuda_impl::QuadratureShapInteractionValues( + ctx_, p_fmat, out_contribs, model, tree_end, tree_weights, quadrature_shap_points_); + } else { + interpretability::ShapInteractionValues(ctx_, p_fmat, out_contribs, model, tree_end, + tree_weights, approximate); + } } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, @@ -832,6 +865,8 @@ class GPUPredictor : public xgboost::Predictor { private: ColumnSplitHelper column_split_helper_; + std::string shap_algorithm_{"treeshap"}; + std::size_t quadrature_shap_points_{8}; }; XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") diff --git a/src/predictor/interpretability/quadrature.h b/src/predictor/interpretability/quadrature.h new file mode 100644 index 000000000000..34469622372e --- /dev/null +++ b/src/predictor/interpretability/quadrature.h @@ -0,0 +1,97 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#ifndef XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_ +#define XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_ + +#include +#include +#include +#include +#include +#include + +#include "xgboost/logging.h" + +namespace xgboost::interpretability::detail { + +template +struct EndpointQuadratureRule { + std::size_t points{0}; + std::array nodes{}; + std::array weights{}; +}; + +inline double LegendrePolynomial(std::size_t n, double x) { + double p0 = 1.0; + if (n == 0) { + return p0; + } + double p1 = x; + if (n == 1) { + return p1; + } + for (std::size_t k = 2; k <= n; ++k) { + double pk = + ((2.0 * static_cast(k) - 1.0) * x * p1 - (static_cast(k) - 1.0) * p0) / + static_cast(k); + p0 = p1; + p1 = pk; + } + return p1; +} + +inline double LegendreDerivative(std::size_t n, double x, double pn) { + auto n_d = static_cast(n); + return n_d * (x * pn - LegendrePolynomial(n - 1, x)) / (x * x - 1.0); +} + +template +inline EndpointQuadratureRule MakeEndpointQuadrature(std::size_t n, + double convergence_eps) { + CHECK_GE(n, 2); + CHECK_LE(n, MaxPoints); + + EndpointQuadratureRule rule; + rule.points = n; + std::vector> nodes_weights; + nodes_weights.reserve(n); + + for (std::size_t i = 0; i < n; ++i) { + double theta = M_PI * (static_cast(i) + 0.75) / (static_cast(n) + 0.5); + double x = std::cos(theta); + for (std::size_t iter = 0; iter < 64; ++iter) { + auto pn = LegendrePolynomial(n, x); + auto dpn = LegendreDerivative(n, x, pn); + auto dx = pn / dpn; + x -= dx; + if (std::abs(dx) < convergence_eps) { + break; + } + } + + auto pn = LegendrePolynomial(n, x); + auto dpn = LegendreDerivative(n, x, pn); + auto w = 2.0 / ((1.0 - x * x) * dpn * dpn); + double s = 0.5 * (x + 1.0); + double ws = 0.5 * w; + nodes_weights.emplace_back(s * s, 2.0 * s * ws); + } + + std::sort(nodes_weights.begin(), nodes_weights.end(), + [](auto const &l, auto const &r) { return l.first < r.first; }); + for (std::size_t i = 0; i < n; ++i) { + rule.nodes[i] = nodes_weights[i].first; + rule.weights[i] = nodes_weights[i].second; + } + return rule; +} + +template +inline EndpointQuadratureRule MakeEndpointQuadrature(double convergence_eps) { + return MakeEndpointQuadrature(Points, convergence_eps); +} + +} // namespace xgboost::interpretability::detail + +#endif // XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_ diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index a6608a4133fb..999d15c7f134 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -3,7 +3,11 @@ */ #include "shap.h" -#include // for fill +#include // for copy, fill +#include // for array +#include // for abs +#include // for uint32_t +#include // for numeric_limits #include // for remove_const_t #include // for vector @@ -12,11 +16,12 @@ #include "../../tree/tree_view.h" // for ScalarTreeView #include "../data_accessor.h" // for GHistIndexMatrixView #include "../predict_fn.h" // for GetTreeLimit -#include "../treeshap.h" // for CalculateContributions #include "dmlc/omp.h" // for omp_get_thread_num -#include "xgboost/base.h" // for bst_omp_uint -#include "xgboost/logging.h" // for CHECK -#include "xgboost/tree_model.h" // for MTNotImplemented +#include "quadrature.h" +#include "xgboost/base.h" // for bst_omp_uint +#include "xgboost/logging.h" // for CHECK +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for MTNotImplemented namespace xgboost::interpretability { namespace { @@ -57,7 +62,687 @@ void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVe std::vector *mean_values, std::vector *out_contribs) { CHECK_EQ(out_contribs->size(), feats.Size() + 1); - CalculateContributionsApprox(tree, feats, mean_values, out_contribs->data()); + CHECK_GT(mean_values->size(), 0U); + bst_feature_t split_index = 0; + float node_value = (*mean_values)[0]; + out_contribs->back() += node_value; + if (tree.IsLeaf(RegTree::kRoot)) { + return; + } + + bst_node_t nidx = RegTree::kRoot; + auto const &cats = tree.GetCategoriesMatrix(); + while (!tree.IsLeaf(nidx)) { + split_index = tree.SplitIndex(nidx); + nidx = predictor::GetNextNode(tree, nidx, feats.GetFvalue(split_index), + feats.IsMissing(split_index), cats); + auto new_value = (*mean_values)[nidx]; + (*out_contribs)[split_index] += new_value - node_value; + node_value = new_value; + } + (*out_contribs)[split_index] += tree.LeafValue(nidx) - node_value; +} + +struct PathElement { + int feature_index; + float zero_fraction; + float one_fraction; + float pweight; + PathElement() = default; + PathElement(int i, float z, float o, float w) + : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {} +}; + +void ExtendPath(PathElement *unique_path, std::uint32_t unique_depth, float zero_fraction, + float one_fraction, int feature_index) { + unique_path[unique_depth].feature_index = feature_index; + unique_path[unique_depth].zero_fraction = zero_fraction; + unique_path[unique_depth].one_fraction = one_fraction; + unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f); + for (int i = static_cast(unique_depth) - 1; i >= 0; --i) { + unique_path[i + 1].pweight += + one_fraction * unique_path[i].pweight * (i + 1) / static_cast(unique_depth + 1); + unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) / + static_cast(unique_depth + 1); + } +} + +void UnwindPath(PathElement *unique_path, std::uint32_t unique_depth, std::uint32_t path_index) { + auto const one_fraction = unique_path[path_index].one_fraction; + auto const zero_fraction = unique_path[path_index].zero_fraction; + float next_one_portion = unique_path[unique_depth].pweight; + + for (int i = static_cast(unique_depth) - 1; i >= 0; --i) { + if (one_fraction != 0.0f) { + auto const tmp = unique_path[i].pweight; + unique_path[i].pweight = + next_one_portion * (unique_depth + 1) / static_cast((i + 1) * one_fraction); + next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) / + static_cast(unique_depth + 1); + } else { + unique_path[i].pweight = unique_path[i].pweight * (unique_depth + 1) / + static_cast(zero_fraction * (unique_depth - i)); + } + } + + for (auto i = path_index; i < unique_depth; ++i) { + unique_path[i].feature_index = unique_path[i + 1].feature_index; + unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction; + unique_path[i].one_fraction = unique_path[i + 1].one_fraction; + } +} + +float UnwoundPathSum(PathElement const *unique_path, std::uint32_t unique_depth, + std::uint32_t path_index) { + auto const one_fraction = unique_path[path_index].one_fraction; + auto const zero_fraction = unique_path[path_index].zero_fraction; + float next_one_portion = unique_path[unique_depth].pweight; + float total = 0.0f; + for (int i = static_cast(unique_depth) - 1; i >= 0; --i) { + if (one_fraction != 0.0f) { + auto const tmp = + next_one_portion * (unique_depth + 1) / static_cast((i + 1) * one_fraction); + total += tmp; + next_one_portion = + unique_path[i].pweight - + tmp * zero_fraction * ((unique_depth - i) / static_cast(unique_depth + 1)); + } else if (zero_fraction != 0.0f) { + total += (unique_path[i].pweight / zero_fraction) / + ((unique_depth - i) / static_cast(unique_depth + 1)); + } else { + CHECK_EQ(unique_path[i].pweight, 0.0f) << "Unique path " << i << " must have zero weight"; + } + } + return total; +} + +void TreeShap(tree::ScalarTreeView const &tree, RegTree::FVec const &feat, float *phi, + bst_node_t nidx, std::uint32_t unique_depth, PathElement *parent_unique_path, + float parent_zero_fraction, float parent_one_fraction, int parent_feature_index, + int condition, std::uint32_t condition_feature, float condition_fraction) { + if (condition_fraction == 0.0f) { + return; + } + + PathElement *unique_path = parent_unique_path + unique_depth + 1; + std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path); + if (condition == 0 || condition_feature != static_cast(parent_feature_index)) { + ExtendPath(unique_path, unique_depth, parent_zero_fraction, parent_one_fraction, + parent_feature_index); + } + + auto const split_index = tree.SplitIndex(nidx); + if (tree.IsLeaf(nidx)) { + for (std::uint32_t i = 1; i <= unique_depth; ++i) { + auto const w = UnwoundPathSum(unique_path, unique_depth, i); + auto const &el = unique_path[i]; + phi[el.feature_index] += + w * (el.one_fraction - el.zero_fraction) * tree.LeafValue(nidx) * condition_fraction; + } + return; + } + + auto const &cats = tree.GetCategoriesMatrix(); + auto hot_index = predictor::GetNextNode(tree, nidx, feat.GetFvalue(split_index), + feat.IsMissing(split_index), cats); + auto const cold_index = + (hot_index == tree.LeftChild(nidx) ? tree.RightChild(nidx) : tree.LeftChild(nidx)); + auto const w = tree.Stat(nidx).sum_hess; + auto const hot_zero_fraction = tree.Stat(hot_index).sum_hess / w; + auto const cold_zero_fraction = tree.Stat(cold_index).sum_hess / w; + float incoming_zero_fraction = 1.0f; + float incoming_one_fraction = 1.0f; + + std::uint32_t path_index = 0; + for (; path_index <= unique_depth; ++path_index) { + if (static_cast(unique_path[path_index].feature_index) == split_index) { + break; + } + } + if (path_index != unique_depth + 1) { + incoming_zero_fraction = unique_path[path_index].zero_fraction; + incoming_one_fraction = unique_path[path_index].one_fraction; + UnwindPath(unique_path, unique_depth, path_index); + unique_depth -= 1; + } + + float hot_condition_fraction = condition_fraction; + float cold_condition_fraction = condition_fraction; + if (condition > 0 && split_index == condition_feature) { + cold_condition_fraction = 0.0f; + unique_depth -= 1; + } else if (condition < 0 && split_index == condition_feature) { + hot_condition_fraction *= hot_zero_fraction; + cold_condition_fraction *= cold_zero_fraction; + unique_depth -= 1; + } + + TreeShap(tree, feat, phi, hot_index, unique_depth + 1, unique_path, + hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index, + condition, condition_feature, hot_condition_fraction); + TreeShap(tree, feat, phi, cold_index, unique_depth + 1, unique_path, + cold_zero_fraction * incoming_zero_fraction, 0.0f, split_index, condition, + condition_feature, cold_condition_fraction); +} + +void CalculateContributions(tree::ScalarTreeView const &tree, RegTree::FVec const &feat, + std::vector *mean_values, float *out_contribs, int condition, + std::uint32_t condition_feature) { + if (condition == 0) { + out_contribs[feat.Size()] += (*mean_values)[RegTree::kRoot]; + } + + auto const maxd = tree.MaxDepth() + 2; + std::vector unique_path_data((maxd * (maxd + 1)) / 2); + TreeShap(tree, feat, out_contribs, RegTree::kRoot, 0, unique_path_data.data(), 1.0f, 1.0f, -1, + condition, condition_feature, 1.0f); +} + +// The CPU additive path supports a few fixed quadrature sizes so experiments can sweep point +// counts while keeping compile-time-unrolled hot loops. +constexpr std::size_t kQuadratureShapPoints = 8; +constexpr double kQuadratureShapBuildQeps = 1e-15; +constexpr float kQuadratureShapUnseen = -999.0f; + +template +struct QuadratureRule { + std::array nodes{}; + std::array weights{}; +}; +template +using QuadratureBuffer = std::array; + +template +QuadratureRule const &GetQuadratureRule() { + static QuadratureRule const rule = [] { + auto const rule_d = detail::MakeEndpointQuadrature(kQuadratureShapBuildQeps); + QuadratureRule out; + for (std::size_t i = 0; i < Points; ++i) { + out.nodes[i] = static_cast(rule_d.nodes[i]); + out.weights[i] = static_cast(rule_d.weights[i]); + } + return out; + }(); + return rule; +} + +template +void AddInPlace(QuadratureBuffer *lhs, QuadratureBuffer const &rhs) { + for (std::size_t i = 0; i < Points; ++i) { + (*lhs)[i] += rhs[i]; + } +} + +template +float ExtractQuadratureDelta(QuadratureRule const &rule, + QuadratureBuffer const &h_vals, float p_enter, float p_exit) { + float acc = 0.0f; + if (p_enter != 1.0f) { + auto const alpha_enter = p_enter - 1.0f; + for (std::size_t i = 0; i < Points; ++i) { + acc += alpha_enter * h_vals[i] / (1.0f + alpha_enter * rule.nodes[i]); + } + } + if (p_exit != 1.0f) { + auto const alpha_exit = p_exit - 1.0f; + for (std::size_t i = 0; i < Points; ++i) { + acc -= alpha_exit * h_vals[i] / (1.0f + alpha_exit * rule.nodes[i]); + } + } + return acc; +} + +constexpr bool kQuadratureInteractionUseEdgeKernel = false; +constexpr bool kQuadratureInteractionUseLatestLiveIndex = false; + +// Off-diagonal interaction terms use the same return-edge delta as additive SHAP, but with one +// partner feature removed from the live quadrature basis. For an active partner with live ratio +// q_j, the weighted subtree return factors as +// H(t) = H_without_j(t) * (1 + (q_j - 1) t) +// after the zero-fraction terms cancel. The conditioned on/off difference is therefore the +// precomputed return-edge kernel divided by that partner factor and multiplied by (q_j - 1). +template +float ExtractQuadratureInteractionDelta(QuadratureRule const &rule, + QuadratureBuffer const &h_vals, float p_enter, + float p_exit, float q_partner) { + if (q_partner == 1.0f) { + return 0.0f; + } + + auto const alpha_partner = q_partner - 1.0f; + auto const has_enter = p_enter != 1.0f; + auto const has_exit = p_exit != 1.0f; + auto const alpha_enter = p_enter - 1.0f; + auto const alpha_exit = p_exit - 1.0f; + + float acc = 0.0f; + for (std::size_t i = 0; i < Points; ++i) { + float edge_delta = 0.0f; + if (has_enter) { + edge_delta += alpha_enter / (1.0f + alpha_enter * rule.nodes[i]); + } + if (has_exit) { + edge_delta -= alpha_exit / (1.0f + alpha_exit * rule.nodes[i]); + } + acc += alpha_partner * h_vals[i] * edge_delta / (1.0f + alpha_partner * rule.nodes[i]); + } + return acc; +} + +template +float ExtractQuadratureInteractionDelta(QuadratureRule const &rule, + QuadratureBuffer const &edge_kernel, + float q_partner) { + if (q_partner == 1.0f) { + return 0.0f; + } + + auto const alpha_partner = q_partner - 1.0f; + float acc = 0.0f; + for (std::size_t i = 0; i < Points; ++i) { + acc += alpha_partner * edge_kernel[i] / (1.0f + alpha_partner * rule.nodes[i]); + } + return acc; +} + +template +void WriteWeightedLeafReturn(tree::ScalarTreeView const &tree, QuadratureRule const &rule, + bst_node_t nidx, QuadratureBuffer const &c_vals, float w_prod, + QuadratureBuffer *out_h) { + auto const leaf_scale = w_prod * tree.LeafValue(nidx); + for (std::size_t i = 0; i < Points; ++i) { + (*out_h)[i] = c_vals[i] * leaf_scale * rule.weights[i]; + } +} + +// Dense row-local output view for additive contributions. +template +struct ContributionVectorView { + T *data; + std::size_t size; + + T &operator[](std::size_t idx) const { return data[idx]; } +}; + +// Dense row-local output view for interaction matrices. Future formulations can target this sink +// directly instead of open-coding flattened indexing arithmetic. +template +struct DenseInteractionMatrixView { + T *data; + std::size_t ncolumns; + + T &operator()(std::size_t i, std::size_t j) const { return data[i * ncolumns + j]; } +}; + +// One active split on the current root-to-node path. Traversal owns the push/pop discipline, while +// formulations can inspect the live path without duplicating duplicate-feature bookkeeping. +struct QuadraturePathElement { + bst_feature_t split_index; + float p_parent; + float p_child; + std::int32_t prev_live_index; +}; + +// Read-only formulation view of the current root-to-node path. Traversal keeps ownership of the +// stack so different contribution formulations can inspect the same live path state. +struct QuadraturePathView { + common::Span elements; + common::Span latest_live_index; + + [[nodiscard]] auto Depth() const { return elements.size(); } + [[nodiscard]] bool Empty() const { return elements.empty(); } + [[nodiscard]] auto Entries() const { return elements; } + + [[nodiscard]] auto CurrentSplit() const -> QuadraturePathElement const & { + CHECK(!elements.empty()); + return elements.back(); + } + + // Iterate the active path once per feature, newest-to-oldest. Later duplicate splits are the + // live ones for path-local partner lookups, so older duplicates are hidden from formulations. + template + void ForEachUniqueFeature(Fn &&fn) const { + if (!latest_live_index.empty()) { + for (std::size_t i = elements.size(); i != 0; --i) { + auto const idx = i - 1; + auto const split_index = elements[idx].split_index; + if (latest_live_index[split_index] == static_cast(idx)) { + fn(idx, elements[idx]); + } + } + } else { + for (std::size_t i = elements.size(); i != 0; --i) { + auto const idx = i - 1; + auto const split_index = elements[idx].split_index; + bool shadowed = false; + for (std::size_t newer = elements.size(); newer > i; --newer) { + if (elements[newer - 1].split_index == split_index) { + shadowed = true; + break; + } + } + if (!shadowed) { + fn(idx, elements[idx]); + } + } + } + } +}; + +struct EmptyQuadraturePathState { + void Reset() const {} + void Push(bst_feature_t, float, float) const {} + void Pop(bst_feature_t) const {} + [[nodiscard]] auto View() const { return QuadraturePathView{{}, {}}; } +}; + +struct LiveQuadraturePathState { + std::vector *path; + std::vector *latest_live_index; + + void Reset() const { path->clear(); } + + void Push(bst_feature_t split_index, float p_parent, float p_child) const { + if constexpr (kQuadratureInteractionUseLatestLiveIndex) { + auto prev_live = (*latest_live_index)[split_index]; + path->push_back(QuadraturePathElement{split_index, p_parent, p_child, prev_live}); + (*latest_live_index)[split_index] = static_cast(path->size() - 1); + } else { + path->push_back(QuadraturePathElement{split_index, p_parent, p_child, -1}); + } + } + + void Pop(bst_feature_t split_index) const { + if constexpr (kQuadratureInteractionUseLatestLiveIndex) { + (*latest_live_index)[split_index] = path->back().prev_live_index; + } + path->pop_back(); + } + + [[nodiscard]] auto View() const { + if constexpr (kQuadratureInteractionUseLatestLiveIndex) { + return QuadraturePathView{common::Span{*path}, + common::Span{*latest_live_index}}; + } else { + return QuadraturePathView{common::Span{*path}, {}}; + } + } +}; + +// Current additive SHAP formulation. It consumes the weighted subtree return and writes one +// feature contribution per return edge. +template +struct AdditiveContributionFormulation { + EmptyQuadraturePathState path_state; + ContributionVectorView phi; + + explicit AdditiveContributionFormulation(ContributionVectorView phi) : phi{phi} {} + + void ResetPath() const { path_state.Reset(); } + void PushPathSplit(bst_feature_t split_index, float p_parent, float p_child) const { + path_state.Push(split_index, p_parent, p_child); + } + void PopPathSplit(bst_feature_t split_index) const { path_state.Pop(split_index); } + + void HandleLeaf(tree::ScalarTreeView const &tree, QuadratureRule const &rule, + bst_node_t nidx, QuadratureBuffer const &c_vals, float w_prod, + QuadratureBuffer *out_h) const { + WriteWeightedLeafReturn(tree, rule, nidx, c_vals, w_prod, out_h); + } + + void HandleReturn(QuadratureRule const &rule, bst_feature_t split_index, + QuadratureBuffer const &h_vals, float p_enter, float p_exit) const { + phi[split_index] += ExtractQuadratureDelta(rule, h_vals, p_enter, p_exit); + } +}; + +// First path-local interaction formulation built on top of the quadrature traversal. It keeps the +// traversal and weighted subtree return shared with additive SHAP, and only changes how return +// edges are written into the dense interaction sink. +template +struct InteractionContributionFormulation { + struct EdgeEffect { + bst_feature_t split_index; + float diagonal_delta; + QuadratureBuffer edge_kernel; + }; + + LiveQuadraturePathState path_state; + ContributionVectorView phi_diag; + DenseInteractionMatrixView phi_interactions; + float scale; + + InteractionContributionFormulation(LiveQuadraturePathState path_state, + ContributionVectorView phi_diag, + DenseInteractionMatrixView phi_interactions, + float scale) + : path_state{path_state}, + phi_diag{phi_diag}, + phi_interactions{phi_interactions}, + scale{scale} {} + + void ResetPath() const { path_state.Reset(); } + void PushPathSplit(bst_feature_t split_index, float p_parent, float p_child) const { + path_state.Push(split_index, p_parent, p_child); + } + void PopPathSplit(bst_feature_t split_index) const { path_state.Pop(split_index); } + + // Traversal still needs a weighted subtree return, so the interaction path shares the additive + // leaf behavior and changes only the return-edge algebra. + void HandleLeaf(tree::ScalarTreeView const &tree, QuadratureRule const &rule, + bst_node_t nidx, QuadratureBuffer const &c_vals, float w_prod, + QuadratureBuffer *out_h) const { + WriteWeightedLeafReturn(tree, rule, nidx, c_vals, w_prod, out_h); + } + + [[nodiscard]] auto MakeEdgeEffect(QuadratureRule const &rule, bst_feature_t split_index, + QuadratureBuffer const &h_vals, float p_enter, + float p_exit) const { + QuadratureBuffer edge_kernel{}; + float diagonal_delta = 0.0f; + + if constexpr (kQuadratureInteractionUseEdgeKernel) { + auto const has_enter = p_enter != 1.0f; + auto const has_exit = p_exit != 1.0f; + auto const alpha_enter = p_enter - 1.0f; + auto const alpha_exit = p_exit - 1.0f; + + for (std::size_t i = 0; i < Points; ++i) { + float edge_delta = 0.0f; + if (has_enter) { + edge_delta += alpha_enter / (1.0f + alpha_enter * rule.nodes[i]); + } + if (has_exit) { + edge_delta -= alpha_exit / (1.0f + alpha_exit * rule.nodes[i]); + } + edge_kernel[i] = h_vals[i] * edge_delta; + diagonal_delta += edge_kernel[i]; + } + } else { + diagonal_delta = ExtractQuadratureDelta(rule, h_vals, p_enter, p_exit); + } + + return EdgeEffect{split_index, diagonal_delta, edge_kernel}; + } + + void AccumulateDiagonal(EdgeEffect const &edge) const { + phi_diag[edge.split_index] += scale * edge.diagonal_delta; + } + + // Walk the live unique path excluding the current split. A pairwise formulation can distribute + // the current edge effect across these partner features without reimplementing duplicate logic. + template + void ForEachPartner(QuadraturePathView path, Fn &&fn) const { + CHECK(!path.Empty()); + auto const current_split = path.CurrentSplit().split_index; + bool skipped_current = false; + path.ForEachUniqueFeature([&](std::size_t, QuadraturePathElement const &element) { + if (!skipped_current && element.split_index == current_split) { + skipped_current = true; + return; + } + fn(element); + }); + } + + void AccumulatePair(EdgeEffect const &edge, QuadraturePathElement const &partner, + float pair_delta) const { + auto const i = static_cast(edge.split_index); + auto const j = static_cast(partner.split_index); + phi_interactions(i, j) += scale * pair_delta; + } + + void HandleReturn(QuadratureRule const &rule, bst_feature_t split_index, + QuadratureBuffer const &h_vals, float p_enter, float p_exit) const { + auto path = path_state.View(); + auto const edge = this->MakeEdgeEffect(rule, split_index, h_vals, p_enter, p_exit); + this->AccumulateDiagonal(edge); + + this->ForEachPartner(path, [&](QuadraturePathElement const &partner) { + float pair_delta = 0.0f; + if constexpr (kQuadratureInteractionUseEdgeKernel) { + pair_delta = + ExtractQuadratureInteractionDelta(rule, edge.edge_kernel, partner.p_child); + } else { + pair_delta = ExtractQuadratureInteractionDelta(rule, h_vals, p_enter, p_exit, + partner.p_child); + } + this->AccumulatePair(edge, partner, pair_delta); + }); + } +}; + +// Tree-walk engine for quadrature formulations. It owns feature evaluation, child descent, and +// the live path-probability state, then hands leaf/return events to the selected formulation. +template +struct QuadratureShapTreeRunner { + tree::ScalarTreeView const &tree; + RegTree::FVec const &feat; + QuadratureRule const &rule; + std::vector *path_prob; + ContributionFormulation formulation; + + [[nodiscard]] bool EvaluateGoesLeft(bst_node_t nidx) const { + auto split_index = tree.SplitIndex(nidx); + auto const &cats = tree.GetCategoriesMatrix(); + auto next = predictor::GetNextNode(tree, nidx, feat.GetFvalue(split_index), + feat.IsMissing(split_index), cats); + return next == tree.LeftChild(nidx); + } + + [[nodiscard]] float ChildWeight(bst_node_t parent, bst_node_t child) const { + auto parent_cover = tree.Stat(parent).sum_hess; + CHECK_GT(parent_cover, 0.0f); + return tree.Stat(child).sum_hess / parent_cover; + } + + void VisitChild(bst_node_t split_node, bst_node_t child_node, float child_weight, bool satisfies, + QuadratureBuffer const &c_vals, float w_prod, + QuadratureBuffer *out_h) { + auto split_index = tree.SplitIndex(split_node); + auto p_old = (*path_prob)[split_index]; + + float p_e = 0.0f; + float p_up = 0.0f; + if (p_old == kQuadratureShapUnseen) { + p_e = satisfies ? 1.0f / child_weight : 0.0f; + p_up = 1.0f; + } else if (p_old == 0.0f) { + p_e = 0.0f; + p_up = 0.0f; + } else { + p_e = satisfies ? p_old / child_weight : 0.0f; + p_up = p_old; + } + + auto c_child = c_vals; + auto alpha_e = p_e - 1.0f; + for (std::size_t i = 0; i < Points; ++i) { + c_child[i] *= 1.0f + alpha_e * rule.nodes[i]; + } + + if (p_old != kQuadratureShapUnseen) { + auto alpha_old = p_old - 1.0f; + if (alpha_old != 0.0f) { + for (std::size_t i = 0; i < Points; ++i) { + c_child[i] /= 1.0f + alpha_old * rule.nodes[i]; + } + } + } + + (*path_prob)[split_index] = p_e; + formulation.PushPathSplit(split_index, p_up, p_e); + this->RunNode(child_node, c_child, w_prod * child_weight, out_h); + formulation.HandleReturn(rule, split_index, *out_h, p_e, p_up); + formulation.PopPathSplit(split_index); + (*path_prob)[split_index] = p_old; + } + + void RunNode(bst_node_t nidx, QuadratureBuffer const &c_vals, float w_prod, + QuadratureBuffer *out_h) { + if (tree.IsLeaf(nidx)) { + formulation.HandleLeaf(tree, rule, nidx, c_vals, w_prod, out_h); + return; + } + + auto left = tree.LeftChild(nidx); + auto right = tree.RightChild(nidx); + auto left_weight = this->ChildWeight(nidx, left); + auto right_weight = this->ChildWeight(nidx, right); + auto goes_left = this->EvaluateGoesLeft(nidx); + + QuadratureBuffer right_h{}; + + this->VisitChild(nidx, left, left_weight, goes_left, c_vals, w_prod, out_h); + this->VisitChild(nidx, right, right_weight, !goes_left, c_vals, w_prod, &right_h); + AddInPlace(out_h, right_h); + } + + void Run() { + formulation.ResetPath(); + if (tree.IsLeaf(RegTree::kRoot)) { + return; + } + + QuadratureBuffer c_init{}; + c_init.fill(1.0f); + QuadratureBuffer h_vals{}; + this->RunNode(RegTree::kRoot, c_init, 1.0f, &h_vals); + } +}; + +struct QuadratureShapModelData { + std::vector trees; + std::vector> trees_by_group; + std::vector weights; + std::vector group_root_mean_sums; +}; + +QuadratureShapModelData MakeQuadratureShapModelData(gbm::GBTreeModel const &model, + bst_tree_t tree_end, + std::vector const *tree_weights) { + auto const n_trees = static_cast(tree_end); + auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + auto const n_groups = model.learner_model_param->num_output_group; + + QuadratureShapModelData out; + out.trees.reserve(n_trees); + out.trees_by_group.resize(n_groups); + out.weights.resize(n_trees, 1.0f); + out.group_root_mean_sums.resize(n_groups, 0.0f); + + for (std::size_t i = 0; i < n_trees; ++i) { + out.trees.emplace_back(model.trees[i].get()); + } + for (bst_tree_t i = 0; i < tree_end; ++i) { + auto gid = h_tree_groups[i]; + auto weight = tree_weights == nullptr ? 1.0f : (*tree_weights)[i]; + out.trees_by_group[gid].push_back(i); + out.weights[i] = weight; + out.group_root_mean_sums[gid] += + static_cast(detail::FillRootMeanValue(out.trees[i], RegTree::kRoot) * weight); + } + return out; } template @@ -165,6 +850,215 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou LaunchShap(ctx, p_fmat, model, process_view); } +template +void QuadratureShapValuesImpl(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights) { + static_assert(Points == 4 || Points == 6 || Points == 8 || Points == 16); + + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + CHECK_GE(tree_end, 0); + ValidateTreeWeights(tree_weights, tree_end); + auto const n_threads = ctx->Threads(); + auto const n_groups = model.learner_model_param->num_output_group; + auto const n_features = model.learner_model_param->num_feature; + size_t const ncolumns = model.learner_model_param->num_feature + 1; + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); + std::fill(contribs.begin(), contribs.end(), 0.0f); + CHECK_NE(n_groups, 0); + auto const &rule = GetQuadratureRule(); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto model_data = MakeQuadratureShapModelData(model, tree_end, tree_weights); + std::vector feats_tloc(n_threads); + std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); + std::vector> path_prob_tloc( + n_threads, std::vector(n_features, kQuadratureShapUnseen)); + + auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto &path_prob = path_prob_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (auto j : model_data.trees_by_group[gid]) { + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0.0f); + auto formulation = + AdditiveContributionFormulation{{this_tree_contribs.data(), ncolumns}}; + auto runner = QuadratureShapTreeRunner>{ + model_data.trees[j], feats, rule, &path_prob, formulation}; + runner.Run(); + auto const weight = model_data.weights[j]; + for (size_t ci = 0; ci + 1 < ncolumns; ++ci) { + p_contribs[ci] += this_tree_contribs[ci] * weight; + } + } + p_contribs[ncolumns - 1] += model_data.group_root_mean_sums[gid]; + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + }; + + LaunchShap(ctx, p_fmat, model, process_view); +} + +void QuadratureShapValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + std::size_t quadrature_points) { + switch (quadrature_points) { + case 4: + QuadratureShapValuesImpl<4>(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + return; + case 6: + QuadratureShapValuesImpl<6>(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + return; + case 8: + QuadratureShapValuesImpl<8>(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + return; + case 16: + QuadratureShapValuesImpl<16>(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + return; + default: + LOG(FATAL) << "CPU QuadratureSHAP currently supports quadrature sizes of 4, 6, 8, or 16."; + } +} + +void QuadratureShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, + std::size_t quadrature_points) { + CHECK(!model.learner_model_param->IsVectorLeaf()) + << "Predict interaction contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " + "column-wise data split is not yet implemented."; + CHECK_EQ(quadrature_points, kQuadratureShapPoints) + << "CPU QuadratureSHAP currently uses a fixed quadrature size of " << kQuadratureShapPoints + << "."; + + MetaInfo const &info = p_fmat->Info(); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + CHECK_GE(tree_end, 0); + ValidateTreeWeights(tree_weights, tree_end); + + auto const n_threads = ctx->Threads(); + auto const n_groups = model.learner_model_param->num_output_group; + auto const n_features = model.learner_model_param->num_feature; + auto const ncolumns = n_features + 1; + auto const row_chunk = n_groups * ncolumns * ncolumns; + auto const matrix_chunk = ncolumns * ncolumns; + + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * row_chunk); + std::fill(contribs.begin(), contribs.end(), 0.0f); + + auto const &rule = GetQuadratureRule(); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto model_data = MakeQuadratureShapModelData(model, tree_end, tree_weights); + std::vector feats_tloc(n_threads); + std::vector> path_tloc(n_threads); + std::vector> path_prob_tloc( + n_threads, std::vector(n_features, kQuadratureShapUnseen)); + std::vector> latest_live_tloc( + n_threads, std::vector(n_features, -1)); + std::vector> diag_tloc(n_threads, std::vector(ncolumns)); + + auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &path = path_tloc[tid]; + auto &path_prob = path_prob_tloc[tid]; + auto &latest_live = latest_live_tloc[tid]; + auto &diag = diag_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + auto const offset = (row_idx * n_groups + gid) * matrix_chunk; + auto matrix = DenseInteractionMatrixView{contribs.data() + offset, ncolumns}; + std::fill(diag.begin(), diag.end(), 0.0f); + + for (auto j : model_data.trees_by_group[gid]) { + auto formulation = InteractionContributionFormulation{ + {&path, &latest_live}, + {diag.data(), ncolumns}, + {matrix.data, matrix.ncolumns}, + model_data.weights[j]}; + auto runner = + QuadratureShapTreeRunner>{ + model_data.trees[j], feats, rule, &path_prob, formulation}; + runner.Run(); + } + + diag[ncolumns - 1] += model_data.group_root_mean_sums[gid]; + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + diag[ncolumns - 1] += base_margin(row_idx, gid); + } else { + diag[ncolumns - 1] += base_score(gid); + } + + // The path-local return updates populate row-wise off-diagonal effects. Average the two + // directional estimates so the final matrix is explicitly symmetric. + for (size_t r = 0; r < ncolumns; ++r) { + for (size_t c = r + 1; c < ncolumns; ++c) { + auto const sym = 0.5f * (matrix(r, c) + matrix(c, r)); + matrix(r, c) = sym; + matrix(c, r) = sym; + } + } + + // Match the incumbent interaction semantics: each diagonal entry is the additive SHAP + // value minus the off-diagonal interactions in that row. + for (size_t r = 0; r < ncolumns; ++r) { + float value = diag[r]; + for (size_t c = 0; c < ncolumns; ++c) { + if (c != r) { + value -= matrix(r, c); + } + } + matrix(r, r) = value; + } + } + + feats.Drop(); + }); + }; + + LaunchShap(ctx, p_fmat, model, process_view); +} + void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, std::vector const *tree_weights) { @@ -246,9 +1140,9 @@ void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, MetaInfo const &info = p_fmat->Info(); auto const ngroup = model.learner_model_param->num_output_group; auto const ncolumns = model.learner_model_param->num_feature; - const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); - const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1); - const unsigned crow_chunk = ngroup * (ncolumns + 1); + const std::size_t row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); + const std::size_t mrow_chunk = (ncolumns + 1) * (ncolumns + 1); + const std::size_t crow_chunk = ngroup * (ncolumns + 1); // allocate space for (number of features^2) times the number of rows and tmp off/on contribs std::vector &contribs = out_contribs->HostVector(); @@ -279,16 +1173,22 @@ void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, for (size_t j = 0; j < info.num_row_; ++j) { for (std::remove_const_t l = 0; l < ngroup; ++l) { - const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1); - const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1); - contribs[o_offset + i] = 0; + const std::size_t o_offset = j * row_chunk + l * mrow_chunk; + const std::size_t c_offset = j * crow_chunk + l * (ncolumns + 1); + auto matrix = + DenseInteractionMatrixView{contribs.data() + o_offset, ncolumns + 1}; + auto diag = + ContributionVectorView{contribs_diag.data() + c_offset, ncolumns + 1}; + auto off = ContributionVectorView{contribs_off.data() + c_offset, ncolumns + 1}; + auto on = ContributionVectorView{contribs_on.data() + c_offset, ncolumns + 1}; + matrix(i, i) = 0; for (size_t k = 0; k < ncolumns + 1; ++k) { // fill in the diagonal with additive effects, and off-diagonal with the interactions if (k == i) { - contribs[o_offset + i] += contribs_diag[c_offset + k]; + matrix(i, i) += diag[k]; } else { - contribs[o_offset + k] = (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0; - contribs[o_offset + i] -= contribs[o_offset + k]; + matrix(i, k) = (on[k] - off[k]) / 2.0f; + matrix(i, i) -= matrix(i, k); } } } diff --git a/src/predictor/interpretability/shap.cu b/src/predictor/interpretability/shap.cu index 50d680f10a77..3b62b023893a 100644 --- a/src/predictor/interpretability/shap.cu +++ b/src/predictor/interpretability/shap.cu @@ -7,9 +7,13 @@ #include #include #include +#include #include #include +#include +#include +#include #include // for proclaim_return_type #include // for swap #include // for variant @@ -25,6 +29,7 @@ #include "../../common/cuda_context.cuh" // for CUDAContext #include "../../common/cuda_rt_utils.h" // for SetDevice #include "../../common/device_helpers.cuh" +#include "../../common/math.h" #include "../../common/nvtx_utils.h" #include "../../common/optional_weight.h" #include "../../data/batch_utils.h" // for StaticBatch @@ -36,6 +41,7 @@ #include "../gbtree_view.h" #include "../gpu_data_accessor.cuh" #include "../predict_fn.h" // for GetTreeLimit +#include "quadrature.h" #include "shap.h" #include "xgboost/data.h" #include "xgboost/host_device_vector.h" @@ -53,6 +59,1047 @@ using ::xgboost::cuda_impl::StaticBatch; using TreeViewVar = cuda::std::variant; +constexpr std::size_t kGpuQuadraturePoints = 8; +constexpr std::size_t kMaxGpuQuadratureDepth = 64; +constexpr std::size_t kGpuQuadratureRowsPerWarp = 4; +constexpr std::size_t kGpuQuadratureTreeBlockThreads = 64; +constexpr std::array kGpuQuadratureDepthBuckets{{16, 32, 64}}; +constexpr double kQuadratureShapQeps = 1e-15; +using QuadratureRule = detail::EndpointQuadratureRule; + +std::vector MakeGroupRootMeanSums(gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights) { + auto h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + auto n_groups = model.learner_model_param->num_output_group; + std::vector h_group_root_mean_sums(n_groups, 0.0); + for (bst_tree_t tree_idx = 0; tree_idx < tree_end; ++tree_idx) { + auto const weight = tree_weights == nullptr ? 1.0f : (*tree_weights)[tree_idx]; + auto const tree = model.trees.at(tree_idx)->HostScView(); + h_group_root_mean_sums[h_tree_groups[tree_idx]] += + detail::FillRootMeanValue(tree, RegTree::kRoot) * weight; + } + + std::vector out(h_group_root_mean_sums.size()); + std::transform(h_group_root_mean_sums.cbegin(), h_group_root_mean_sums.cend(), out.begin(), + [](double v) { return static_cast(v); }); + return out; +} + +struct CompressedNode { + bst_node_t left{RegTree::kInvalidNodeId}; + bst_node_t right{RegTree::kInvalidNodeId}; + bst_feature_t split_global{0}; + float split_cond{0}; + float leaf_value{0}; + float left_weight{0}; + float right_weight{0}; + std::uint32_t cat_begin{0}; + std::uint32_t cat_size{0}; + std::uint8_t default_left{0}; + std::uint8_t is_leaf{0}; + std::uint8_t is_categorical{0}; + std::uint8_t prev_same_offset_plus1{0}; +}; + +struct CompressedTree { + std::uint32_t node_begin{0}; + bst_target_t group{0}; +}; + +struct CompressedModel { + dh::device_vector trees; + dh::device_vector nodes; + dh::device_vector categories; +}; + +std::size_t DepthBucketIndex(std::size_t path_depth) { + for (std::size_t i = 0; i < kGpuQuadratureDepthBuckets.size(); ++i) { + if (path_depth <= kGpuQuadratureDepthBuckets[i]) { + return i; + } + } + LOG(FATAL) << "GPU QuadratureSHAP currently supports trees of depth up to " + << (kMaxGpuQuadratureDepth - 1) << "."; + return kGpuQuadratureDepthBuckets.size() - 1; +} + +CompressedModel MakeCompressedModel(Context const* ctx, gbm::GBTreeModel const& model, + std::vector const& tree_indices, + std::vector const* tree_weights) { + std::vector h_trees; + std::vector h_nodes; + std::vector h_categories; + auto h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + static_cast(ctx); + + h_trees.reserve(tree_indices.size()); + for (auto tree_idx : tree_indices) { + auto const weight = tree_weights == nullptr ? 1.0f : (*tree_weights)[tree_idx]; + auto const tree = model.trees.at(tree_idx)->HostScView(); + + auto node_begin = h_nodes.size(); + h_nodes.resize(node_begin + tree.Size()); + + for (bst_node_t nidx = 0; nidx < tree.Size(); ++nidx) { + auto& out = h_nodes[node_begin + nidx]; + if (tree.IsLeaf(nidx)) { + out.is_leaf = 1; + out.leaf_value = tree.LeafValue(nidx) * weight; + continue; + } + + auto left = tree.LeftChild(nidx); + auto right = tree.RightChild(nidx); + auto parent_cover = static_cast(tree.SumHess(nidx)); + CHECK_GT(parent_cover, 0.0); + + out.left = left; + out.right = right; + out.split_global = tree.SplitIndex(nidx); + out.split_cond = tree.SplitCond(nidx); + out.left_weight = static_cast(static_cast(tree.SumHess(left)) / parent_cover); + out.right_weight = + static_cast(static_cast(tree.SumHess(right)) / parent_cover); + if (common::IsCat(tree.cats.split_type, nidx)) { + auto node_cats = tree.NodeCats(nidx); + CHECK_LE(node_cats.size(), + static_cast(std::numeric_limits::max())); + CHECK_LE(h_categories.size(), + static_cast(std::numeric_limits::max())); + out.cat_begin = static_cast(h_categories.size()); + out.cat_size = static_cast(node_cats.size()); + out.is_categorical = 1; + h_categories.insert(h_categories.end(), node_cats.begin(), node_cats.end()); + } + out.default_left = tree.DefaultLeft(nidx); + out.is_leaf = 0; + } + + for (bst_node_t nidx = 0; nidx < tree.Size(); ++nidx) { + auto& out = h_nodes[node_begin + nidx]; + if (out.is_leaf) { + continue; + } + + std::uint8_t prev_same_offset_plus1 = 0; + std::uint16_t distance = 0; + auto ancestor = nidx; + while (!tree.IsRoot(ancestor)) { + ancestor = tree.Parent(ancestor); + ++distance; + auto const& ancestor_node = h_nodes[node_begin + ancestor]; + if (!ancestor_node.is_leaf && ancestor_node.split_global == out.split_global) { + prev_same_offset_plus1 = static_cast(distance + 1); + break; + } + } + out.prev_same_offset_plus1 = prev_same_offset_plus1; + } + + h_trees.push_back( + CompressedTree{static_cast(node_begin), h_tree_groups[tree_idx]}); + } + + CompressedModel out; + out.trees = dh::device_vector(h_trees.cbegin(), h_trees.cend()); + out.nodes = dh::device_vector(h_nodes.cbegin(), h_nodes.cend()); + out.categories = dh::device_vector(h_categories.cbegin(), h_categories.cend()); + return out; +} + +template +XGBOOST_DEVICE constexpr unsigned ActiveSubgroupMask(int row_slot) { + constexpr int kSegmentWidth = dh::WarpThreads() / RowsPerWarp; + static_assert(kSegmentWidth >= static_cast(kGpuQuadraturePoints)); + return ((1u << kGpuQuadraturePoints) - 1u) << (row_slot * kSegmentWidth); +} + +XGBOOST_DEVICE inline float PreviousPathProbability(std::uint8_t prev_same_offset_plus1, int depth, + float const* q_vals_row) { + if (prev_same_offset_plus1 == 0) { + return 1.0f; + } + auto prev_depth = depth - static_cast(prev_same_offset_plus1) + 1; + return q_vals_row[prev_depth]; +} + +XGBOOST_DEVICE inline float ExtractQuadratureEdgeDeltaLocal(float quad_node, float quad_weight, + float ret_val, float p_enter, + float q_prev) { + auto weighted_ret = ret_val * quad_weight; + float edge_delta = 0.0f; + if (p_enter != 1.0f) { + auto alpha_enter = p_enter - 1.0f; + edge_delta += alpha_enter / (1.0f + alpha_enter * quad_node); + } + if (q_prev != 1.0f) { + auto alpha_exit = q_prev - 1.0f; + edge_delta -= alpha_exit / (1.0f + alpha_exit * quad_node); + } + return weighted_ret * edge_delta; +} + +XGBOOST_DEVICE inline float ExtractQuadratureInteractionDeltaLocal(float quad_node, + float edge_delta_local, + float q_partner) { + if (q_partner == 1.0f) { + return 0.0f; + } + auto alpha_partner = q_partner - 1.0f; + return alpha_partner * edge_delta_local / (1.0f + alpha_partner * quad_node); +} + +template +struct IsSparsePageLoaderNoShared : std::false_type {}; + +template +struct IsSparsePageLoaderNoShared> : std::true_type {}; + +// Encapsulate the tail-tile versus full-tile differences so the traversal code can focus on +// probability updates instead of mask plumbing. +template +struct SubgroupOps { + static constexpr int kRowsPerWarpValue = RowsPerWarp; + static constexpr int kSegmentWidth = dh::WarpThreads() / RowsPerWarp; + static constexpr unsigned kFullMask = 0xffffffffu; + + int row_slot; + int point; + unsigned subgroup_mask; + unsigned warp_mask; + bool row_valid; + bool is_leader; + bool is_warp_leader; + + XGBOOST_DEV_INLINE SubgroupOps(int lane, bst_idx_t valid_rows_in_tail) + : row_slot{lane / kSegmentWidth}, + point{lane % kSegmentWidth}, + subgroup_mask{kFullMask}, + warp_mask{kFullMask}, + row_valid{true}, + is_leader{point == 0}, + is_warp_leader{lane == 0} { + if constexpr (kHasRowMask) { + subgroup_mask = ActiveSubgroupMask(row_slot); + warp_mask = __activemask(); + row_valid = static_cast(row_slot) < valid_rows_in_tail; + } + } + + [[nodiscard]] XGBOOST_DEV_INLINE bool Participates() const { return point < MaxPoints; } + + [[nodiscard]] XGBOOST_DEV_INLINE bool RowActive() const { + if constexpr (kHasRowMask) { + return row_valid; + } else { + return true; + } + } + + [[nodiscard]] XGBOOST_DEV_INLINE bool ShouldWrite() const { + return is_leader && this->RowActive(); + } + + template + [[nodiscard]] XGBOOST_DEV_INLINE T Broadcast(T value) const { + // Each row uses an independent kGpuQuadraturePoints-wide subgroup inside the warp. + if constexpr (kHasRowMask) { + return __shfl_sync(subgroup_mask, value, 0, MaxPoints); + } else { + return __shfl_sync(kFullMask, value, 0, MaxPoints); + } + } + + template + [[nodiscard]] XGBOOST_DEV_INLINE T Sum(T value) const { + for (int offset = MaxPoints / 2; offset > 0; offset /= 2) { + if constexpr (kHasRowMask) { + value += __shfl_down_sync(subgroup_mask, value, offset, MaxPoints); + } else { + value += __shfl_down_sync(kFullMask, value, offset, MaxPoints); + } + } + return value; + } + + XGBOOST_DEV_INLINE void Sync() const { + if constexpr (kHasRowMask) { + __syncwarp(warp_mask); + } else { + __syncwarp(); + } + } +}; + +// Wrap the shared-memory layout in semantic accessors so the task runner talks in terms of path +// state instead of raw multidimensional indexing. +template +struct QuadratureSharedState { + bst_node_t (&nodes)[kWarpsPerBlock][DepthCap]; + std::uint8_t (&stages)[kWarpsPerBlock][DepthCap]; + std::uint8_t (&goes_left)[kWarpsPerBlock][RowsPerWarp][DepthCap]; + // q_d(t): path probability at depth d for one row-slot evaluated at quadrature point t. + float (&path_prob)[kWarpsPerBlock][RowsPerWarp][DepthCap]; + // G_d(t): multiplicative basis carried down the path before the leaf value is applied. + float (&basis)[kWarpsPerBlock][RowsPerWarp][DepthCap][MaxPoints]; + float (&q_prev_cache)[kUseQPrevCache ? kWarpsPerBlock : 1][kUseQPrevCache ? RowsPerWarp : 1] + [kUseQPrevCache ? DepthCap : 1]; + + [[nodiscard]] XGBOOST_DEV_INLINE bst_node_t& Node(int warp, int depth) { + return nodes[warp][depth]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE bst_node_t const& Node(int warp, int depth) const { + return nodes[warp][depth]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE std::uint8_t& Stage(int warp, int depth) { + return stages[warp][depth]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE std::uint8_t const& Stage(int warp, int depth) const { + return stages[warp][depth]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE bool GoesLeft(int warp, int row_slot, int depth) const { + return static_cast(goes_left[warp][row_slot][depth]); + } + + XGBOOST_DEV_INLINE void SetGoesLeft(int warp, int row_slot, int depth, bool value) { + goes_left[warp][row_slot][depth] = static_cast(value); + } + + [[nodiscard]] XGBOOST_DEV_INLINE float& PathProbability(int warp, int row_slot, int depth) { + return path_prob[warp][row_slot][depth]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE float const* PathProbabilityRow(int warp, int row_slot) const { + return path_prob[warp][row_slot]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE float& Basis(int warp, int row_slot, int depth, int point) { + return basis[warp][row_slot][depth][point]; + } + + [[nodiscard]] XGBOOST_DEV_INLINE float LoadQPrev(int warp, int row_slot, int depth, + std::uint8_t prev_same_offset_plus1) const { + if constexpr (kUseQPrevCache) { + return q_prev_cache[warp][row_slot][depth]; + } else { + return PreviousPathProbability(prev_same_offset_plus1, depth, + this->PathProbabilityRow(warp, row_slot)); + } + } + + XGBOOST_DEV_INLINE void StoreQPrev(int warp, int row_slot, int depth, float q_prev) { + if constexpr (kUseQPrevCache) { + q_prev_cache[warp][row_slot][depth] = q_prev; + } + } +}; + +template +struct QuadratureShapTaskRunner { + Loader loader; + SubgroupT subgroup; + SharedT shared; + CompressedTree const* trees; + CompressedNode const* nodes; + std::uint32_t const* categories; + float* phis; + bst_idx_t base_rowid; + bst_target_t n_groups; + bst_feature_t n_columns; + std::size_t row_tile_begin; + std::size_t row_tiles; + int warp; + float quad_node; + float quad_weight; + + [[nodiscard]] XGBOOST_DEV_INLINE bool EvaluateGoesLeft(bst_idx_t ridx, + CompressedNode const& node) const { + auto fvalue = loader.GetElement(ridx, node.split_global); + if (common::CheckNAN(fvalue)) { + return static_cast(node.default_left); + } + if (node.is_categorical) { + auto cats = common::Span{categories + node.cat_begin, node.cat_size}; + return common::Decision(cats, fvalue); + } + return fvalue < node.split_cond; + } + + XGBOOST_DEV_INLINE void AddContribution(bst_idx_t row_idx, bst_target_t tree_group, + bst_feature_t split_global, float contrib) const { + if (!subgroup.ShouldWrite()) { + return; + } + auto out_row = phis + (row_idx * n_groups + tree_group) * n_columns; + atomicAdd(out_row + split_global, contrib); + } + + XGBOOST_DEV_INLINE void InitializeTask() { + if (subgroup.is_warp_leader) { + shared.Node(warp, 0) = RegTree::kRoot; + shared.Stage(warp, 0) = 0; + } + // Start each row with G_0(t) = 1 at every quadrature node. + shared.Basis(warp, subgroup.row_slot, 0, subgroup.point) = 1.0f; + subgroup.Sync(); + } + + XGBOOST_DEV_INLINE bool HandleReturn(bst_idx_t row_idx, bst_target_t tree_group, + CompressedNode const* nodes_for_tree, int* stack_size, + bool* have_return, float* ret_val) { + if (*stack_size == 0) { + return false; + } + + int parent_depth = *stack_size - 1; + auto const& node = nodes_for_tree[shared.Node(warp, parent_depth)]; + int child_idx = static_cast(shared.Stage(warp, parent_depth)) - 1; + + float p_enter = 0.0f; + float q_prev = 1.0f; + if (subgroup.is_leader && subgroup.RowActive()) { + p_enter = shared.PathProbability(warp, subgroup.row_slot, parent_depth); + q_prev = shared.LoadQPrev(warp, subgroup.row_slot, parent_depth, node.prev_same_offset_plus1); + } + p_enter = subgroup.Broadcast(p_enter); + q_prev = subgroup.Broadcast(q_prev); + + // Extraction uses + // H * w(t) * ret_val * + // [ (p_enter - 1) / (1 + (p_enter - 1) t) + // - (q_prev - 1) / (1 + (q_prev - 1) t) ]. + // The two rational terms are the "enter current feature" and "rewind to previous same + // feature" adjustments from the quadrature recurrence. + float contrib = + ExtractQuadratureEdgeDeltaLocal(quad_node, quad_weight, *ret_val, p_enter, q_prev); + contrib = subgroup.Sum(contrib); + this->AddContribution(row_idx, tree_group, node.split_global, contrib); + + if (child_idx == 0) { + auto child_weight = node.right_weight; + auto child_node = node.right; + float p_e = 0.0f; + if (subgroup.is_leader) { + if (subgroup.RowActive()) { + auto goes_left = shared.GoesLeft(warp, subgroup.row_slot, parent_depth); + p_e = goes_left ? 0.0f : q_prev / child_weight; + } + shared.PathProbability(warp, subgroup.row_slot, parent_depth) = p_e; + } + p_e = subgroup.Broadcast(p_e); + + if (subgroup.is_warp_leader) { + shared.Node(warp, *stack_size) = child_node; + shared.Stage(warp, *stack_size) = 0; + shared.Stage(warp, parent_depth) = 2; + } + // Push the sibling subtree with + // G_child(t) = G_parent(t) * child_weight * + // (1 + (p_e - 1) t) / (1 + (q_prev - 1) t). + // This preserves the basis after swapping the active feature state from q_prev to p_e. + auto alpha_e = p_e - 1.0f; + auto v = shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point) * child_weight * + (1.0f + alpha_e * quad_node); + if (q_prev != 1.0f) { + auto alpha_old = q_prev - 1.0f; + v /= 1.0f + alpha_old * quad_node; + } + shared.Basis(warp, subgroup.row_slot, *stack_size, subgroup.point) = v; + subgroup.Sync(); + shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point) = *ret_val; + (*stack_size)++; + *have_return = false; + } else { + *ret_val += shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point); + (*stack_size)--; + *have_return = true; + } + + return true; + } + + XGBOOST_DEV_INLINE void Descend(CompressedNode const* nodes_for_tree, bst_idx_t ridx, + int* stack_size, bool* have_return, float* ret_val) { + int depth = *stack_size - 1; + auto const& node = nodes_for_tree[shared.Node(warp, depth)]; + if (node.is_leaf) { + *ret_val = shared.Basis(warp, subgroup.row_slot, depth, subgroup.point) * node.leaf_value; + (*stack_size)--; + *have_return = true; + return; + } + + // stage == 0 explores the left child first. After the return path updates the parent state, + // the second visit uses the cached go-left decision to push the right child. + int child = static_cast(shared.Stage(warp, depth) != 0); + if (child == 0) { + if (subgroup.is_warp_leader) { + shared.Stage(warp, depth) = 1; + } + subgroup.Sync(); + } + + auto child_weight = child == 0 ? node.left_weight : node.right_weight; + auto child_node = child == 0 ? node.left : node.right; + float q_prev = 1.0f; + if (subgroup.is_leader) { + if (subgroup.RowActive()) { + q_prev = PreviousPathProbability(node.prev_same_offset_plus1, depth, + shared.PathProbabilityRow(warp, subgroup.row_slot)); + } + shared.StoreQPrev(warp, subgroup.row_slot, depth, q_prev); + } + q_prev = subgroup.Broadcast(q_prev); + + float p_e = 0.0f; + if (subgroup.is_leader) { + bool goes_left = false; + if (subgroup.RowActive()) { + goes_left = this->EvaluateGoesLeft(ridx, node); + // p_e is the path probability after taking the chosen child for this row. + p_e = (child == 0 ? goes_left : !goes_left) ? q_prev / child_weight : 0.0f; + } + shared.SetGoesLeft(warp, subgroup.row_slot, depth, goes_left); + shared.PathProbability(warp, subgroup.row_slot, depth) = p_e; + } + p_e = subgroup.Broadcast(p_e); + + if (subgroup.is_warp_leader) { + shared.Node(warp, *stack_size) = child_node; + shared.Stage(warp, *stack_size) = 0; + } + // Same recurrence as the sibling push above: reweight G_d(t) by the child weight and replace + // q_prev with the new path probability p_e at this depth. + auto alpha_e = p_e - 1.0f; + auto v = shared.Basis(warp, subgroup.row_slot, depth, subgroup.point) * child_weight * + (1.0f + alpha_e * quad_node); + if (q_prev != 1.0f) { + auto alpha_old = q_prev - 1.0f; + v /= 1.0f + alpha_old * quad_node; + } + shared.Basis(warp, subgroup.row_slot, *stack_size, subgroup.point) = v; + subgroup.Sync(); + (*stack_size)++; + } + + XGBOOST_DEV_INLINE void RunTask(std::size_t task) { + auto tree_idx = task / row_tiles; + auto row_tile = task % row_tiles; + auto ridx = (row_tile_begin + row_tile) * SubgroupT::kRowsPerWarpValue + subgroup.row_slot; + auto row_idx = base_rowid + static_cast(ridx); + auto tree = trees[tree_idx]; + auto nodes_for_tree = nodes + tree.node_begin; + + this->InitializeTask(); + + int stack_size = 1; + bool have_return = false; + float ret_val = 0.0f; + while (stack_size > 0 || have_return) { + if (have_return) { + if (!this->HandleReturn(row_idx, tree.group, nodes_for_tree, &stack_size, &have_return, + &ret_val)) { + break; + } + continue; + } + this->Descend(nodes_for_tree, ridx, &stack_size, &have_return, &ret_val); + } + } +}; + +template +struct QuadratureShapInteractionTaskRunner { + Loader loader; + SubgroupT subgroup; + SharedT shared; + CompressedTree const* trees; + CompressedNode const* nodes; + std::uint32_t const* categories; + float* phis; + bst_idx_t base_rowid; + bst_target_t n_groups; + bst_feature_t n_columns; + std::size_t row_tile_begin; + std::size_t row_tiles; + int warp; + float quad_node; + float quad_weight; + + [[nodiscard]] XGBOOST_DEV_INLINE bool EvaluateGoesLeft(bst_idx_t ridx, + CompressedNode const& node) const { + auto fvalue = loader.GetElement(ridx, node.split_global); + if (common::CheckNAN(fvalue)) { + return static_cast(node.default_left); + } + if (node.is_categorical) { + auto cats = common::Span{categories + node.cat_begin, node.cat_size}; + return common::Decision(cats, fvalue); + } + return fvalue < node.split_cond; + } + + XGBOOST_DEV_INLINE void AddDiagonalContribution(bst_idx_t row_idx, bst_target_t tree_group, + bst_feature_t split_global, float contrib) const { + if (!subgroup.ShouldWrite()) { + return; + } + auto out_idx = gpu_treeshap::IndexPhiInteractions(row_idx, n_groups, tree_group, n_columns - 1, + split_global, split_global); + atomicAdd(phis + out_idx, contrib); + } + + XGBOOST_DEV_INLINE void AddPairContribution(bst_idx_t row_idx, bst_target_t tree_group, + bst_feature_t split_i, bst_feature_t split_j, + float contrib) const { + if (!subgroup.ShouldWrite()) { + return; + } + auto out_idx = gpu_treeshap::IndexPhiInteractions(row_idx, n_groups, tree_group, n_columns - 1, + split_i, split_j); + atomicAdd(phis + out_idx, contrib); + } + + template + XGBOOST_DEV_INLINE void ForEachUniquePartner(CompressedNode const* nodes_for_tree, + int current_depth, bst_feature_t current_split, + Fn&& fn) const { + bool skipped_current = false; + for (int depth = current_depth; depth >= 0; --depth) { + auto const& candidate = nodes_for_tree[shared.Node(warp, depth)]; + if (candidate.is_leaf) { + continue; + } + auto split = candidate.split_global; + if (!skipped_current && split == current_split) { + skipped_current = true; + continue; + } + bool shadowed = false; + for (int newer = current_depth; newer > depth; --newer) { + auto const& newer_node = nodes_for_tree[shared.Node(warp, newer)]; + if (!newer_node.is_leaf && newer_node.split_global == split) { + shadowed = true; + break; + } + } + if (!shadowed) { + fn(depth, split); + } + } + } + + XGBOOST_DEV_INLINE void InitializeTask() { + if (subgroup.is_warp_leader) { + shared.Node(warp, 0) = RegTree::kRoot; + shared.Stage(warp, 0) = 0; + } + shared.Basis(warp, subgroup.row_slot, 0, subgroup.point) = 1.0f; + subgroup.Sync(); + } + + XGBOOST_DEV_INLINE bool HandleReturn(bst_idx_t row_idx, bst_target_t tree_group, + CompressedNode const* nodes_for_tree, int* stack_size, + bool* have_return, float* ret_val) { + if (*stack_size == 0) { + return false; + } + + int parent_depth = *stack_size - 1; + auto const& node = nodes_for_tree[shared.Node(warp, parent_depth)]; + int child_idx = static_cast(shared.Stage(warp, parent_depth)) - 1; + + float p_enter = 0.0f; + float q_prev = 1.0f; + if (subgroup.is_leader && subgroup.RowActive()) { + p_enter = shared.PathProbability(warp, subgroup.row_slot, parent_depth); + q_prev = shared.LoadQPrev(warp, subgroup.row_slot, parent_depth, node.prev_same_offset_plus1); + } + p_enter = subgroup.Broadcast(p_enter); + q_prev = subgroup.Broadcast(q_prev); + + auto edge_delta_local = + ExtractQuadratureEdgeDeltaLocal(quad_node, quad_weight, *ret_val, p_enter, q_prev); + auto diag_contrib = subgroup.Sum(edge_delta_local); + this->AddDiagonalContribution(row_idx, tree_group, node.split_global, diag_contrib); + + this->ForEachUniquePartner( + nodes_for_tree, parent_depth, node.split_global, + [&](int partner_depth, bst_feature_t partner_split) { + float q_partner = 1.0f; + if (subgroup.is_leader && subgroup.RowActive()) { + q_partner = shared.PathProbability(warp, subgroup.row_slot, partner_depth); + } + q_partner = subgroup.Broadcast(q_partner); + auto pair_delta_local = + ExtractQuadratureInteractionDeltaLocal(quad_node, edge_delta_local, q_partner); + auto pair_contrib = subgroup.Sum(pair_delta_local); + this->AddPairContribution(row_idx, tree_group, node.split_global, partner_split, + pair_contrib); + }); + + if (child_idx == 0) { + auto child_weight = node.right_weight; + auto child_node = node.right; + float p_e = 0.0f; + if (subgroup.is_leader) { + if (subgroup.RowActive()) { + auto goes_left = shared.GoesLeft(warp, subgroup.row_slot, parent_depth); + p_e = goes_left ? 0.0f : q_prev / child_weight; + } + shared.PathProbability(warp, subgroup.row_slot, parent_depth) = p_e; + } + p_e = subgroup.Broadcast(p_e); + + if (subgroup.is_warp_leader) { + shared.Node(warp, *stack_size) = child_node; + shared.Stage(warp, *stack_size) = 0; + shared.Stage(warp, parent_depth) = 2; + } + auto alpha_e = p_e - 1.0f; + auto v = shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point) * child_weight * + (1.0f + alpha_e * quad_node); + if (q_prev != 1.0f) { + auto alpha_old = q_prev - 1.0f; + v /= 1.0f + alpha_old * quad_node; + } + shared.Basis(warp, subgroup.row_slot, *stack_size, subgroup.point) = v; + subgroup.Sync(); + shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point) = *ret_val; + (*stack_size)++; + *have_return = false; + } else { + *ret_val += shared.Basis(warp, subgroup.row_slot, parent_depth, subgroup.point); + (*stack_size)--; + *have_return = true; + } + + return true; + } + + XGBOOST_DEV_INLINE void Descend(CompressedNode const* nodes_for_tree, bst_idx_t ridx, + int* stack_size, bool* have_return, float* ret_val) { + int depth = *stack_size - 1; + auto const& node = nodes_for_tree[shared.Node(warp, depth)]; + if (node.is_leaf) { + *ret_val = shared.Basis(warp, subgroup.row_slot, depth, subgroup.point) * node.leaf_value; + (*stack_size)--; + *have_return = true; + return; + } + + int child = static_cast(shared.Stage(warp, depth) != 0); + if (child == 0) { + if (subgroup.is_warp_leader) { + shared.Stage(warp, depth) = 1; + } + subgroup.Sync(); + } + + auto child_weight = child == 0 ? node.left_weight : node.right_weight; + auto child_node = child == 0 ? node.left : node.right; + float q_prev = 1.0f; + if (subgroup.is_leader) { + if (subgroup.RowActive()) { + q_prev = PreviousPathProbability(node.prev_same_offset_plus1, depth, + shared.PathProbabilityRow(warp, subgroup.row_slot)); + } + shared.StoreQPrev(warp, subgroup.row_slot, depth, q_prev); + } + q_prev = subgroup.Broadcast(q_prev); + + float p_e = 0.0f; + if (subgroup.is_leader) { + bool goes_left = false; + if (subgroup.RowActive()) { + goes_left = this->EvaluateGoesLeft(ridx, node); + p_e = (child == 0 ? goes_left : !goes_left) ? q_prev / child_weight : 0.0f; + } + shared.SetGoesLeft(warp, subgroup.row_slot, depth, goes_left); + shared.PathProbability(warp, subgroup.row_slot, depth) = p_e; + } + p_e = subgroup.Broadcast(p_e); + + if (subgroup.is_warp_leader) { + shared.Node(warp, *stack_size) = child_node; + shared.Stage(warp, *stack_size) = 0; + } + auto alpha_e = p_e - 1.0f; + auto v = shared.Basis(warp, subgroup.row_slot, depth, subgroup.point) * child_weight * + (1.0f + alpha_e * quad_node); + if (q_prev != 1.0f) { + auto alpha_old = q_prev - 1.0f; + v /= 1.0f + alpha_old * quad_node; + } + shared.Basis(warp, subgroup.row_slot, *stack_size, subgroup.point) = v; + subgroup.Sync(); + (*stack_size)++; + } + + XGBOOST_DEV_INLINE void RunTask(std::size_t task) { + auto tree_idx = task / row_tiles; + auto row_tile = task % row_tiles; + auto ridx = (row_tile_begin + row_tile) * SubgroupT::kRowsPerWarpValue + subgroup.row_slot; + auto row_idx = base_rowid + static_cast(ridx); + auto tree = trees[tree_idx]; + auto nodes_for_tree = nodes + tree.node_begin; + + this->InitializeTask(); + + int stack_size = 1; + bool have_return = false; + float ret_val = 0.0f; + while (stack_size > 0 || have_return) { + if (have_return) { + if (!this->HandleReturn(row_idx, tree.group, nodes_for_tree, &stack_size, &have_return, + &ret_val)) { + break; + } + continue; + } + this->Descend(nodes_for_tree, ridx, &stack_size, &have_return, &ret_val); + } + } +}; + +template +__global__ void __launch_bounds__(BlockThreads, 9) + QuadratureShapTaskKernel(Loader loader, bst_idx_t base_rowid, bst_target_t n_groups, + bst_feature_t n_columns, std::size_t row_tile_begin, + std::size_t row_tiles, bst_idx_t valid_rows_in_tail, + std::size_t n_trees, CompressedTree const* __restrict__ trees, + CompressedNode const* __restrict__ nodes, + std::uint32_t const* __restrict__ categories, + float const* __restrict__ quad_nodes, + float const* __restrict__ quad_weights, float* __restrict__ phis) { + static_assert(MaxPoints == kGpuQuadraturePoints); + static_assert(DepthCap <= static_cast(kMaxGpuQuadratureDepth)); + static_assert(dh::WarpThreads() % RowsPerWarp == 0); + static_assert(BlockThreads % dh::WarpThreads() == 0); + using SubgroupT = SubgroupOps; + constexpr int kSegmentWidth = SubgroupT::kSegmentWidth; + if constexpr (!kHasRowMask) { + static_assert(kSegmentWidth == MaxPoints, + "Full-tile specialization assumes every warp lane participates."); + } + constexpr int kWarpsPerBlock = BlockThreads / dh::WarpThreads(); + constexpr bool kUseQPrevCache = IsSparsePageLoaderNoShared::value; + using SharedT = + QuadratureSharedState; + + __shared__ bst_node_t s_node[kWarpsPerBlock][DepthCap]; + __shared__ std::uint8_t s_stage[kWarpsPerBlock][DepthCap]; + __shared__ std::uint8_t s_goes_left[kWarpsPerBlock][RowsPerWarp][DepthCap]; + __shared__ float s_path_p[kWarpsPerBlock][RowsPerWarp][DepthCap]; + __shared__ float s_c_vals[kWarpsPerBlock][RowsPerWarp][DepthCap][MaxPoints]; + __shared__ float s_q_prev[kUseQPrevCache ? kWarpsPerBlock : 1][kUseQPrevCache ? RowsPerWarp : 1] + [kUseQPrevCache ? DepthCap : 1]; + + int warp = static_cast(threadIdx.x) / dh::WarpThreads(); + int lane = static_cast(threadIdx.x) % dh::WarpThreads(); + auto subgroup = SubgroupT{lane, valid_rows_in_tail}; + if (!subgroup.Participates()) { + return; + } + + auto shared = SharedT{s_node, s_stage, s_goes_left, s_path_p, s_c_vals, s_q_prev}; + auto global_warp = + (static_cast(blockIdx.x) * BlockThreads + threadIdx.x) / dh::WarpThreads(); + auto warp_stride = (static_cast(gridDim.x) * BlockThreads) / dh::WarpThreads(); + auto n_tasks = n_trees * row_tiles; + + auto runner = QuadratureShapTaskRunner{loader, + subgroup, + shared, + trees, + nodes, + categories, + phis, + base_rowid, + n_groups, + n_columns, + row_tile_begin, + row_tiles, + warp, + quad_nodes[subgroup.point], + quad_weights[subgroup.point]}; + + for (std::size_t task = global_warp; task < n_tasks; task += warp_stride) { + runner.RunTask(task); + } +} + +template +void LaunchQuadratureShapTasks(Context const* ctx, Loader loader, bst_idx_t base_rowid, + bst_target_t n_groups, bst_feature_t n_columns, + std::size_t row_tile_begin, std::size_t row_tiles, + bst_idx_t valid_rows_in_tail, CompressedModel const& compressed, + common::Span quad_nodes, + common::Span quad_weights, + HostDeviceVector* out_contribs) { + static_assert(BlockThreads % dh::WarpThreads() == 0); + constexpr int kWarpsPerBlock = BlockThreads / dh::WarpThreads(); + if (compressed.trees.empty() || row_tiles == 0) { + return; + } + auto trees = thrust::raw_pointer_cast(compressed.trees.data()); + auto nodes = thrust::raw_pointer_cast(compressed.nodes.data()); + auto categories = thrust::raw_pointer_cast(compressed.categories.data()); + auto d_quad_nodes = quad_nodes.data(); + auto d_quad_weights = quad_weights.data(); + auto phis = out_contribs->DeviceSpan().data(); + auto n_tasks = compressed.trees.size() * row_tiles; + auto grids = common::DivRoundUp(n_tasks, static_cast(kWarpsPerBlock)); + QuadratureShapTaskKernel + <<(grids), static_cast(BlockThreads), 0, + ctx->CUDACtx()->Stream()>>>(loader, base_rowid, n_groups, n_columns, row_tile_begin, + row_tiles, valid_rows_in_tail, compressed.trees.size(), trees, + nodes, categories, d_quad_nodes, d_quad_weights, phis); + dh::safe_cuda(cudaGetLastError()); +} + +template +void LaunchQuadratureShapBuckets(Context const* ctx, Loader loader, bst_idx_t base_rowid, + bst_target_t n_groups, bst_feature_t n_columns, + CompressedModel const& compressed, + common::Span quad_nodes, + common::Span quad_weights, + HostDeviceVector* out_contribs) { + auto full_row_tiles = static_cast(loader.NumRows() / RowsPerWarp); + auto tail_rows = static_cast(loader.NumRows() % RowsPerWarp); + LaunchQuadratureShapTasks( + ctx, loader, base_rowid, n_groups, n_columns, /*row_tile_begin=*/0, full_row_tiles, + /*valid_rows_in_tail=*/RowsPerWarp, compressed, quad_nodes, quad_weights, out_contribs); + if (tail_rows != 0) { + LaunchQuadratureShapTasks( + ctx, loader, base_rowid, n_groups, n_columns, /*row_tile_begin=*/full_row_tiles, + /*row_tiles=*/1, tail_rows, compressed, quad_nodes, quad_weights, out_contribs); + } +} + +template +__global__ void __launch_bounds__(BlockThreads, 9) QuadratureShapInteractionTaskKernel( + Loader loader, bst_idx_t base_rowid, bst_target_t n_groups, bst_feature_t n_columns, + std::size_t row_tile_begin, std::size_t row_tiles, bst_idx_t valid_rows_in_tail, + std::size_t n_trees, CompressedTree const* __restrict__ trees, + CompressedNode const* __restrict__ nodes, std::uint32_t const* __restrict__ categories, + float const* __restrict__ quad_nodes, float const* __restrict__ quad_weights, + float* __restrict__ phis) { + static_assert(MaxPoints == kGpuQuadraturePoints); + static_assert(DepthCap <= static_cast(kMaxGpuQuadratureDepth)); + static_assert(dh::WarpThreads() % RowsPerWarp == 0); + static_assert(BlockThreads % dh::WarpThreads() == 0); + using SubgroupT = SubgroupOps; + constexpr int kSegmentWidth = SubgroupT::kSegmentWidth; + if constexpr (!kHasRowMask) { + static_assert(kSegmentWidth == MaxPoints, + "Full-tile specialization assumes every warp lane participates."); + } + constexpr int kWarpsPerBlock = BlockThreads / dh::WarpThreads(); + constexpr bool kUseQPrevCache = IsSparsePageLoaderNoShared::value; + using SharedT = + QuadratureSharedState; + + __shared__ bst_node_t s_node[kWarpsPerBlock][DepthCap]; + __shared__ std::uint8_t s_stage[kWarpsPerBlock][DepthCap]; + __shared__ std::uint8_t s_goes_left[kWarpsPerBlock][RowsPerWarp][DepthCap]; + __shared__ float s_path_p[kWarpsPerBlock][RowsPerWarp][DepthCap]; + __shared__ float s_c_vals[kWarpsPerBlock][RowsPerWarp][DepthCap][MaxPoints]; + __shared__ float s_q_prev[kUseQPrevCache ? kWarpsPerBlock : 1][kUseQPrevCache ? RowsPerWarp : 1] + [kUseQPrevCache ? DepthCap : 1]; + + int warp = static_cast(threadIdx.x) / dh::WarpThreads(); + int lane = static_cast(threadIdx.x) % dh::WarpThreads(); + auto subgroup = SubgroupT{lane, valid_rows_in_tail}; + if (!subgroup.Participates()) { + return; + } + + auto shared = SharedT{s_node, s_stage, s_goes_left, s_path_p, s_c_vals, s_q_prev}; + auto global_warp = + (static_cast(blockIdx.x) * BlockThreads + threadIdx.x) / dh::WarpThreads(); + auto warp_stride = (static_cast(gridDim.x) * BlockThreads) / dh::WarpThreads(); + auto n_tasks = n_trees * row_tiles; + + auto runner = + QuadratureShapInteractionTaskRunner{loader, + subgroup, + shared, + trees, + nodes, + categories, + phis, + base_rowid, + n_groups, + n_columns, + row_tile_begin, + row_tiles, + warp, + quad_nodes[subgroup.point], + quad_weights[subgroup.point]}; + + for (std::size_t task = global_warp; task < n_tasks; task += warp_stride) { + runner.RunTask(task); + } +} + +template +void LaunchQuadratureShapInteractionTasks(Context const* ctx, Loader loader, bst_idx_t base_rowid, + bst_target_t n_groups, bst_feature_t n_columns, + std::size_t row_tile_begin, std::size_t row_tiles, + bst_idx_t valid_rows_in_tail, + CompressedModel const& compressed, + common::Span quad_nodes, + common::Span quad_weights, + HostDeviceVector* out_contribs) { + static_assert(BlockThreads % dh::WarpThreads() == 0); + constexpr int kWarpsPerBlock = BlockThreads / dh::WarpThreads(); + if (compressed.trees.empty() || row_tiles == 0) { + return; + } + auto trees = thrust::raw_pointer_cast(compressed.trees.data()); + auto nodes = thrust::raw_pointer_cast(compressed.nodes.data()); + auto categories = thrust::raw_pointer_cast(compressed.categories.data()); + auto d_quad_nodes = quad_nodes.data(); + auto d_quad_weights = quad_weights.data(); + auto phis = out_contribs->DeviceSpan().data(); + auto n_tasks = compressed.trees.size() * row_tiles; + auto grids = common::DivRoundUp(n_tasks, static_cast(kWarpsPerBlock)); + QuadratureShapInteractionTaskKernel + <<(grids), static_cast(BlockThreads), 0, + ctx->CUDACtx()->Stream()>>>(loader, base_rowid, n_groups, n_columns, row_tile_begin, + row_tiles, valid_rows_in_tail, compressed.trees.size(), trees, + nodes, categories, d_quad_nodes, d_quad_weights, phis); + dh::safe_cuda(cudaGetLastError()); +} + +template +void LaunchQuadratureShapInteractionBuckets(Context const* ctx, Loader loader, bst_idx_t base_rowid, + bst_target_t n_groups, bst_feature_t n_columns, + CompressedModel const& compressed, + common::Span quad_nodes, + common::Span quad_weights, + HostDeviceVector* out_contribs) { + auto full_row_tiles = static_cast(loader.NumRows() / RowsPerWarp); + auto tail_rows = static_cast(loader.NumRows() % RowsPerWarp); + LaunchQuadratureShapInteractionTasks( + ctx, loader, base_rowid, n_groups, n_columns, /*row_tile_begin=*/0, full_row_tiles, + /*valid_rows_in_tail=*/RowsPerWarp, compressed, quad_nodes, quad_weights, out_contribs); + if (tail_rows != 0) { + LaunchQuadratureShapInteractionTasks( + ctx, loader, base_rowid, n_groups, n_columns, /*row_tile_begin=*/full_row_tiles, + /*row_tiles=*/1, tail_rows, compressed, quad_nodes, quad_weights, out_contribs); + } +} + struct CopyViews { Context const* ctx; explicit CopyViews(Context const* ctx) : ctx{ctx} {} @@ -370,6 +1417,204 @@ void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* ou }); } +void QuadratureShapValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, + std::size_t quadrature_points) { + xgboost_NVTX_FN_RANGE(); + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + CHECK_EQ(quadrature_points, kGpuQuadraturePoints) + << "GPU QuadratureSHAP currently uses a fixed quadrature size of " << kGpuQuadraturePoints + << "."; + + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + auto const ngroup = model.learner_model_param->num_output_group; + CHECK_NE(ngroup, 0); + auto const ncolumns = model.learner_model_param->num_feature + 1; + auto dim_size = ncolumns * ngroup; + out_contribs->SetDevice(ctx->Device()); + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); + out_contribs->Fill(0.0f); + + bst_node_t max_depth = 0; + std::array, kGpuQuadratureDepthBuckets.size()> tree_buckets; + for (bst_tree_t tree_idx = 0; tree_idx < tree_end; ++tree_idx) { + CHECK(!model.trees[tree_idx]->IsMultiTarget()) << "Predict contribution" << MTNotImplemented(); + auto tree_depth = model.trees[tree_idx]->MaxDepth(); + max_depth = std::max(max_depth, tree_depth); + auto path_depth = static_cast(tree_depth) + 1; + auto bucket_idx = DepthBucketIndex(path_depth); + tree_buckets[bucket_idx].push_back(tree_idx); + } + CHECK_LE(max_depth + 1, static_cast(kMaxGpuQuadratureDepth)) + << "GPU QuadratureSHAP currently supports trees of depth up to " + << (kMaxGpuQuadratureDepth - 1) << "."; + auto h_group_root_mean_sums = MakeGroupRootMeanSums(model, tree_end, tree_weights); + + auto rule = detail::MakeEndpointQuadrature(kQuadratureShapQeps); + std::array h_quad_nodes{}; + std::array h_quad_weights{}; + for (std::size_t i = 0; i < kGpuQuadraturePoints; ++i) { + h_quad_nodes[i] = static_cast(rule.nodes[i]); + h_quad_weights[i] = static_cast(rule.weights[i]); + } + dh::device_vector d_quad_nodes(h_quad_nodes.cbegin(), h_quad_nodes.cend()); + dh::device_vector d_quad_weights(h_quad_weights.cbegin(), h_quad_weights.cend()); + dh::device_vector d_group_root_mean_sums(h_group_root_mean_sums.cbegin(), + h_group_root_mean_sums.cend()); + auto compressed_16 = MakeCompressedModel(ctx, model, tree_buckets[0], tree_weights); + auto compressed_32 = MakeCompressedModel(ctx, model, tree_buckets[1], tree_weights); + auto compressed_64 = MakeCompressedModel(ctx, model, tree_buckets[2], tree_weights); + + auto new_enc = + p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{}; + auto quad_nodes = + common::Span{thrust::raw_pointer_cast(d_quad_nodes.data()), d_quad_nodes.size()}; + auto quad_weights = common::Span{thrust::raw_pointer_cast(d_quad_weights.data()), + d_quad_weights.size()}; + auto group_root_mean_sums = common::Span{ + thrust::raw_pointer_cast(d_group_root_mean_sums.data()), d_group_root_mean_sums.size()}; + + LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) { + LaunchQuadratureShapBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_16, quad_nodes, quad_weights, + out_contribs); + LaunchQuadratureShapBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_32, quad_nodes, quad_weights, + out_contribs); + LaunchQuadratureShapBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_64, quad_nodes, quad_weights, + out_contribs); + }); + + p_fmat->Info().base_margin_.SetDevice(ctx->Device()); + auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); + auto base_score = model.learner_model_param->BaseScore(ctx); + auto phis = out_contribs->DeviceSpan(); + auto n_samples = p_fmat->Info().num_row_; + dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) { + auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup); + phis[(idx + 1) * ncolumns - 1] += + group_root_mean_sums[gid] + (margin.empty() ? base_score(gid) : margin[idx]); + }); +} + +void QuadratureShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, + std::size_t quadrature_points) { + xgboost_NVTX_FN_RANGE(); + CHECK(!model.learner_model_param->IsVectorLeaf()) + << "Predict interaction contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " + "column-wise data split is not yet implemented."; + CHECK_EQ(quadrature_points, kGpuQuadraturePoints) + << "GPU QuadratureSHAP currently uses a fixed quadrature size of " << kGpuQuadraturePoints + << "."; + + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + auto const ngroup = model.learner_model_param->num_output_group; + CHECK_NE(ngroup, 0); + auto const ncolumns = model.learner_model_param->num_feature + 1; + auto const n_features = model.learner_model_param->num_feature; + auto dim_size = ncolumns * ncolumns * ngroup; + out_contribs->SetDevice(ctx->Device()); + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); + out_contribs->Fill(0.0f); + + bst_node_t max_depth = 0; + std::array, kGpuQuadratureDepthBuckets.size()> tree_buckets; + for (bst_tree_t tree_idx = 0; tree_idx < tree_end; ++tree_idx) { + CHECK(!model.trees[tree_idx]->IsMultiTarget()) + << "Predict interaction contribution" << MTNotImplemented(); + auto tree_depth = model.trees[tree_idx]->MaxDepth(); + max_depth = std::max(max_depth, tree_depth); + auto path_depth = static_cast(tree_depth) + 1; + auto bucket_idx = DepthBucketIndex(path_depth); + tree_buckets[bucket_idx].push_back(tree_idx); + } + CHECK_LE(max_depth + 1, static_cast(kMaxGpuQuadratureDepth)) + << "GPU QuadratureSHAP currently supports trees of depth up to " + << (kMaxGpuQuadratureDepth - 1) << "."; + + auto h_group_root_mean_sums = MakeGroupRootMeanSums(model, tree_end, tree_weights); + auto rule = detail::MakeEndpointQuadrature(kQuadratureShapQeps); + std::array h_quad_nodes{}; + std::array h_quad_weights{}; + for (std::size_t i = 0; i < kGpuQuadraturePoints; ++i) { + h_quad_nodes[i] = static_cast(rule.nodes[i]); + h_quad_weights[i] = static_cast(rule.weights[i]); + } + dh::device_vector d_quad_nodes(h_quad_nodes.cbegin(), h_quad_nodes.cend()); + dh::device_vector d_quad_weights(h_quad_weights.cbegin(), h_quad_weights.cend()); + dh::device_vector d_group_root_mean_sums(h_group_root_mean_sums.cbegin(), + h_group_root_mean_sums.cend()); + auto compressed_16 = MakeCompressedModel(ctx, model, tree_buckets[0], tree_weights); + auto compressed_32 = MakeCompressedModel(ctx, model, tree_buckets[1], tree_weights); + auto compressed_64 = MakeCompressedModel(ctx, model, tree_buckets[2], tree_weights); + + auto new_enc = + p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{}; + auto quad_nodes = + common::Span{thrust::raw_pointer_cast(d_quad_nodes.data()), d_quad_nodes.size()}; + auto quad_weights = common::Span{thrust::raw_pointer_cast(d_quad_weights.data()), + d_quad_weights.size()}; + auto group_root_mean_sums = common::Span{ + thrust::raw_pointer_cast(d_group_root_mean_sums.data()), d_group_root_mean_sums.size()}; + + LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) { + LaunchQuadratureShapInteractionBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_16, quad_nodes, quad_weights, + out_contribs); + LaunchQuadratureShapInteractionBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_32, quad_nodes, quad_weights, + out_contribs); + LaunchQuadratureShapInteractionBuckets( + ctx, loader, base_rowid, ngroup, ncolumns, compressed_64, quad_nodes, quad_weights, + out_contribs); + }); + + p_fmat->Info().base_margin_.SetDevice(ctx->Device()); + auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); + auto base_score = model.learner_model_param->BaseScore(ctx); + auto phis = out_contribs->DeviceSpan(); + auto n_samples = p_fmat->Info().num_row_; + dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) { + auto [ridx, gid] = linalg::UnravelIndex(idx, n_samples, ngroup); + auto bias_idx = + gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gid, n_features, n_features, n_features); + phis[bias_idx] += group_root_mean_sums[gid] + (margin.empty() ? base_score(gid) : margin[idx]); + + auto matrix_offset = gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gid, n_features, 0, 0); + auto matrix = phis.subspan(matrix_offset, ncolumns * ncolumns); + for (bst_feature_t r = 0; r < ncolumns; ++r) { + for (bst_feature_t c = r + 1; c < ncolumns; ++c) { + auto sym = 0.5f * (matrix[r * ncolumns + c] + matrix[c * ncolumns + r]); + matrix[r * ncolumns + c] = sym; + matrix[c * ncolumns + r] = sym; + } + } + for (bst_feature_t r = 0; r < ncolumns; ++r) { + float value = matrix[r * ncolumns + r]; + for (bst_feature_t c = 0; c < ncolumns; ++c) { + if (c != r) { + value -= matrix[r * ncolumns + c]; + } + } + matrix[r * ncolumns + r] = value; + } + }); +} + void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, std::vector const* tree_weights, diff --git a/src/predictor/interpretability/shap.h b/src/predictor/interpretability/shap.h index 2c32d1d84554..7dad71d7e91f 100644 --- a/src/predictor/interpretability/shap.h +++ b/src/predictor/interpretability/shap.h @@ -14,11 +14,37 @@ struct GBTreeModel; } // namespace xgboost::gbm namespace xgboost::interpretability { +namespace detail { +template +inline double FillRootMeanValue(TreeView const& tree, NodeIndex nidx) { + if (tree.IsLeaf(nidx)) { + return tree.LeafValue(nidx); + } + auto left = tree.LeftChild(nidx); + auto right = tree.RightChild(nidx); + double result = FillRootMeanValue(tree, left) * tree.SumHess(left); + result += FillRootMeanValue(tree, right) * tree.SumHess(right); + result /= tree.SumHess(nidx); + return result; +} +} // namespace detail + namespace cpu_impl { void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, std::vector const* tree_weights, int condition, unsigned condition_feature); +void QuadratureShapValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, + std::size_t quadrature_points); + +void QuadratureShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, + std::size_t quadrature_points); + void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, std::vector const* tree_weights); @@ -34,6 +60,15 @@ namespace cuda_impl { void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, std::vector const* tree_weights, int condition, unsigned condition_feature); +void QuadratureShapValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, + std::size_t quadrature_points); +void QuadratureShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, + std::size_t quadrature_points); void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, std::vector const* tree_weights); diff --git a/src/predictor/treeshap.cc b/src/predictor/treeshap.cc deleted file mode 100644 index bae297c973a8..000000000000 --- a/src/predictor/treeshap.cc +++ /dev/null @@ -1,232 +0,0 @@ -/** - * Copyright 2017-2025, XGBoost Contributors - */ -#include "treeshap.h" - -#include // copy -#include // std::uint32_t - -#include "../tree/tree_view.h" // for ScalarTreeView -#include "predict_fn.h" // GetNextNode -#include "xgboost/base.h" // bst_node_t -#include "xgboost/logging.h" -#include "xgboost/tree_model.h" // RegTree - -namespace xgboost { -void CalculateContributionsApprox(tree::ScalarTreeView const& tree, const RegTree::FVec& feat, - std::vector* mean_values, float* out_contribs) { - CHECK_GT(mean_values->size(), 0U); - bst_feature_t split_index = 0; - // update bias value - float node_value = (*mean_values)[0]; - out_contribs[feat.Size()] += node_value; - if (tree.IsLeaf(RegTree::kRoot)) { - // nothing to do anymore - return; - } - - bst_node_t nidx = 0; - auto const& cats = tree.GetCategoriesMatrix(); - - while (!tree.IsLeaf(nidx)) { - split_index = tree.SplitIndex(nidx); - nidx = predictor::GetNextNode(tree, nidx, feat.GetFvalue(split_index), - feat.IsMissing(split_index), cats); - bst_float new_value = (*mean_values)[nidx]; - // update feature weight - out_contribs[split_index] += new_value - node_value; - node_value = new_value; - } - float leaf_value = tree.LeafValue(nidx); - // update leaf feature weight - out_contribs[split_index] += leaf_value - node_value; -} - -// Used by TreeShap -// data we keep about our decision path -// note that pweight is included for convenience and is not tied with the other attributes -// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them -struct PathElement { - int feature_index; - float zero_fraction; - float one_fraction; - float pweight; - PathElement() = default; - PathElement(int i, float z, float o, float w) - : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {} -}; - -// extend our decision path with a fraction of one and zero extensions -void ExtendPath(PathElement* unique_path, std::uint32_t unique_depth, float zero_fraction, - float one_fraction, int feature_index) { - unique_path[unique_depth].feature_index = feature_index; - unique_path[unique_depth].zero_fraction = zero_fraction; - unique_path[unique_depth].one_fraction = one_fraction; - unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f); - for (int i = unique_depth - 1; i >= 0; i--) { - unique_path[i + 1].pweight += - one_fraction * unique_path[i].pweight * (i + 1) / static_cast(unique_depth + 1); - unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) / - static_cast(unique_depth + 1); - } -} - -// undo a previous extension of the decision path -void UnwindPath(PathElement* unique_path, std::uint32_t unique_depth, std::uint32_t path_index) { - const float one_fraction = unique_path[path_index].one_fraction; - const float zero_fraction = unique_path[path_index].zero_fraction; - float next_one_portion = unique_path[unique_depth].pweight; - - for (int i = unique_depth - 1; i >= 0; --i) { - if (one_fraction != 0) { - const float tmp = unique_path[i].pweight; - unique_path[i].pweight = - next_one_portion * (unique_depth + 1) / static_cast((i + 1) * one_fraction); - next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) / - static_cast(unique_depth + 1); - } else { - unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1)) / - static_cast(zero_fraction * (unique_depth - i)); - } - } - - for (auto i = path_index; i < unique_depth; ++i) { - unique_path[i].feature_index = unique_path[i + 1].feature_index; - unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction; - unique_path[i].one_fraction = unique_path[i + 1].one_fraction; - } -} - -// determine what the total permutation weight would be if -// we unwound a previous extension in the decision path -float UnwoundPathSum(const PathElement* unique_path, std::uint32_t unique_depth, - std::uint32_t path_index) { - const float one_fraction = unique_path[path_index].one_fraction; - const float zero_fraction = unique_path[path_index].zero_fraction; - float next_one_portion = unique_path[unique_depth].pweight; - float total = 0; - for (int i = unique_depth - 1; i >= 0; --i) { - if (one_fraction != 0) { - const float tmp = - next_one_portion * (unique_depth + 1) / static_cast((i + 1) * one_fraction); - total += tmp; - next_one_portion = - unique_path[i].pweight - - tmp * zero_fraction * ((unique_depth - i) / static_cast(unique_depth + 1)); - } else if (zero_fraction != 0) { - total += (unique_path[i].pweight / zero_fraction) / - ((unique_depth - i) / static_cast(unique_depth + 1)); - } else { - CHECK_EQ(unique_path[i].pweight, 0) << "Unique path " << i << " must have zero weight"; - } - } - return total; -} - -/** - * \brief Recursive function that computes the feature attributions for a single tree. - * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \param phi dense output vector of feature attributions - * \param node_index the index of the current node in the tree - * \param unique_depth how many unique features are above the current node in the tree - * \param parent_unique_path a vector of statistics about our current path through the tree - * \param parent_zero_fraction what fraction of the parent path weight is coming as 0 (integrated) - * \param parent_one_fraction what fraction of the parent path weight is coming as 1 (fixed) - * \param parent_feature_index what feature the parent node used to split - * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) - * \param condition_feature the index of the feature to fix - * \param condition_fraction what fraction of the current weight matches our conditioning feature - */ -void TreeShap(tree::ScalarTreeView const& tree, const RegTree::FVec& feat, float* phi, - bst_node_t nidx, std::uint32_t unique_depth, PathElement* parent_unique_path, - float parent_zero_fraction, float parent_one_fraction, int parent_feature_index, - int condition, std::uint32_t condition_feature, float condition_fraction) { - // stop if we have no weight coming down to us - if (condition_fraction == 0) return; - - // extend the unique path - PathElement* unique_path = parent_unique_path + unique_depth + 1; - std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path); - - if (condition == 0 || condition_feature != static_cast(parent_feature_index)) { - ExtendPath(unique_path, unique_depth, parent_zero_fraction, parent_one_fraction, - parent_feature_index); - } - const std::uint32_t split_index = tree.SplitIndex(nidx); - - // leaf node - if (tree.IsLeaf(nidx)) { - for (std::uint32_t i = 1; i <= unique_depth; ++i) { - const float w = UnwoundPathSum(unique_path, unique_depth, i); - const PathElement& el = unique_path[i]; - phi[el.feature_index] += - w * (el.one_fraction - el.zero_fraction) * tree.LeafValue(nidx) * condition_fraction; - } - - // internal node - } else { - // find which branch is "hot" (meaning x would follow it) - auto const& cats = tree.GetCategoriesMatrix(); - bst_node_t hot_index = predictor::GetNextNode( - tree, nidx, feat.GetFvalue(split_index), feat.IsMissing(split_index), cats); - - const auto cold_index = - (hot_index == tree.LeftChild(nidx) ? tree.RightChild(nidx) : tree.LeftChild(nidx)); - const float w = tree.Stat(nidx).sum_hess; - const float hot_zero_fraction = tree.Stat(hot_index).sum_hess / w; - const float cold_zero_fraction = tree.Stat(cold_index).sum_hess / w; - float incoming_zero_fraction = 1; - float incoming_one_fraction = 1; - - // see if we have already split on this feature, - // if so we undo that split so we can redo it for this node - std::uint32_t path_index = 0; - for (; path_index <= unique_depth; ++path_index) { - if (static_cast(unique_path[path_index].feature_index) == split_index) break; - } - if (path_index != unique_depth + 1) { - incoming_zero_fraction = unique_path[path_index].zero_fraction; - incoming_one_fraction = unique_path[path_index].one_fraction; - UnwindPath(unique_path, unique_depth, path_index); - unique_depth -= 1; - } - - // divide up the condition_fraction among the recursive calls - float hot_condition_fraction = condition_fraction; - float cold_condition_fraction = condition_fraction; - if (condition > 0 && split_index == condition_feature) { - cold_condition_fraction = 0; - unique_depth -= 1; - } else if (condition < 0 && split_index == condition_feature) { - hot_condition_fraction *= hot_zero_fraction; - cold_condition_fraction *= cold_zero_fraction; - unique_depth -= 1; - } - - TreeShap(tree, feat, phi, hot_index, unique_depth + 1, unique_path, - hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index, - condition, condition_feature, hot_condition_fraction); - - TreeShap(tree, feat, phi, cold_index, unique_depth + 1, unique_path, - cold_zero_fraction * incoming_zero_fraction, 0, split_index, condition, - condition_feature, cold_condition_fraction); - } -} - -void CalculateContributions(tree::ScalarTreeView const& tree, const RegTree::FVec& feat, - std::vector* mean_values, float* out_contribs, int condition, - std::uint32_t condition_feature) { - // find the expected value of the tree's predictions - if (condition == 0) { - float node_value = (*mean_values)[0]; - out_contribs[feat.Size()] += node_value; - } - - // Preallocate space for the unique path data - bst_node_t const maxd = tree.MaxDepth() + 2; - std::vector unique_path_data((maxd * (maxd + 1)) / 2); - - TreeShap(tree, feat, out_contribs, 0, 0, unique_path_data.data(), 1, 1, -1, condition, - condition_feature, 1); -} -} // namespace xgboost diff --git a/src/predictor/treeshap.h b/src/predictor/treeshap.h deleted file mode 100644 index 69423dd9d4bd..000000000000 --- a/src/predictor/treeshap.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2017-2025, XGBoost Contributors - */ -#pragma once - -#include // for vector - -#include "xgboost/tree_model.h" // for RegTree - -namespace xgboost { -/** - * @brief calculate the approximate feature contributions for the given root - * - * This follows the idea of http://blog.datadive.net/interpreting-random-forests/ - * - * @param feat dense feature vector, if the feature is missing the field is set to NaN - * @param out_contribs output vector to hold the contributions - */ -void CalculateContributionsApprox(tree::ScalarTreeView const& tree, const RegTree::FVec& feat, - std::vector* mean_values, float* out_contribs); - -/** - * @brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree - * - * @param feat dense feature vector, if the feature is missing the field is set to NaN - * @param out_contribs output vector to hold the contributions - * @param condition fix one feature to either off (-1) on (1) or not fixed (0 default) - * @param condition_feature the index of the feature to fix - */ -void CalculateContributions(tree::ScalarTreeView const& tree, const RegTree::FVec& feat, - std::vector* mean_values, float* out_contribs, int condition, - unsigned condition_feature); -} // namespace xgboost diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index b86cb498733f..5292224b4d79 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -20,10 +20,58 @@ #include "../../../src/common/param_array.h" #include "../../../src/gbm/gbtree_model.h" #include "../../../src/predictor/interpretability/shap.h" +#include "../../../src/tree/tree_view.h" #include "../helpers.h" namespace xgboost { namespace { +struct CoverStats { + double min_child_weight{1.0}; + double min_path_weight{1.0}; + std::size_t internal_nodes{0}; + std::size_t leaves{0}; + std::size_t count_lt_1e2{0}; + std::size_t count_lt_1e3{0}; + std::size_t count_lt_1e4{0}; + std::size_t count_lt_1e5{0}; +}; + +void AccumulateCoverStats(tree::ScalarTreeView const& tree, bst_node_t nidx, double path_min_weight, + CoverStats* stats) { + if (tree.IsLeaf(nidx)) { + ++stats->leaves; + stats->min_path_weight = std::min(stats->min_path_weight, path_min_weight); + return; + } + + ++stats->internal_nodes; + auto left = tree.LeftChild(nidx); + auto right = tree.RightChild(nidx); + auto parent_cover = static_cast(tree.Stat(nidx).sum_hess); + CHECK_GT(parent_cover, 0.0); + + auto visit_child = [&](bst_node_t child) { + auto child_weight = static_cast(tree.Stat(child).sum_hess) / parent_cover; + stats->min_child_weight = std::min(stats->min_child_weight, child_weight); + if (child_weight < 1e-2) { + ++stats->count_lt_1e2; + } + if (child_weight < 1e-3) { + ++stats->count_lt_1e3; + } + if (child_weight < 1e-4) { + ++stats->count_lt_1e4; + } + if (child_weight < 1e-5) { + ++stats->count_lt_1e5; + } + AccumulateCoverStats(tree, child, std::min(path_min_weight, child_weight), stats); + }; + + visit_child(left); + visit_child(right); +} + void SetLabels(DMatrix* dmat, bst_target_t n_classes) { size_t const rows = dmat->Info().num_row_; dmat->Info().labels.Reshape(rows, 1); @@ -291,6 +339,126 @@ TEST(Predictor, DartShapOutputCPU) { CheckDartShapOutput(&ctx); } +TEST(Predictor, QuadratureShapPrototypeMatchesTreeShapCPU) { + Context ctx; + size_t constexpr kRows = 256; + size_t constexpr kCols = 1; + + auto dmat = RandomDataGenerator(kRows, kCols, 0.0).Device(ctx.Device()).GenerateDMatrix(); + SetLabels(dmat.get(), 1); + + auto args = BaseParams(&ctx, "binary:logistic", "6"); + args.emplace_back("tree_method", "exact"); + + std::shared_ptr p_dmat{dmat.get(), [](DMatrix*) {}}; + std::unique_ptr learner{Learner::Create({p_dmat})}; + learner->SetParams(args); + learner->Configure(); + for (size_t i = 0; i < 3; ++i) { + learner->UpdateOneIter(i, p_dmat); + } + + HostDeviceVector margin_predt; + learner->Predict(p_dmat, true, &margin_predt, 0, 0, false, false, false, false, false); + + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), args, &mparam); + + HostDeviceVector treeshap; + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &treeshap, *gbtree, 0, nullptr, 0, 0); + + HostDeviceVector quadrature_shap; + interpretability::cpu_impl::QuadratureShapValues(dmat->Ctx(), p_dmat.get(), &quadrature_shap, + *gbtree, 0, nullptr, 16); + + auto const& h_treeshap = treeshap.ConstHostVector(); + auto const& h_quadrature = quadrature_shap.ConstHostVector(); + ASSERT_EQ(h_treeshap.size(), h_quadrature.size()); + for (size_t i = 0; i < h_treeshap.size(); ++i) { + EXPECT_NEAR(h_treeshap[i], h_quadrature[i], 1e-4f); + } + + CheckShapAdditivity(kRows, kCols, quadrature_shap, margin_predt); +} + +TEST(Predictor, QuadratureShapSelectorMatchesTreeShapCPU) { + Context ctx; + size_t constexpr kRows = 256; + size_t constexpr kCols = 1; + + auto dmat = RandomDataGenerator(kRows, kCols, 0.0).Device(ctx.Device()).GenerateDMatrix(); + SetLabels(dmat.get(), 1); + + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParams(BaseParams(&ctx, "binary:logistic", "6")); + learner->SetParam("tree_method", "exact"); + learner->Configure(); + for (size_t i = 0; i < 3; ++i) { + learner->UpdateOneIter(i, dmat); + } + + HostDeviceVector margin_predt; + learner->Predict(dmat, true, &margin_predt, 0, 0, false, false, false, false, false); + + HostDeviceVector treeshap; + learner->Predict(dmat, false, &treeshap, 0, 0, false, false, true, false, false); + + learner->SetParam("shap_algorithm", "quadratureshap"); + learner->SetParam("quadratureshap_points", "8"); + learner->Configure(); + + HostDeviceVector quadrature_shap; + learner->Predict(dmat, false, &quadrature_shap, 0, 0, false, false, true, false, false); + + auto const& h_treeshap = treeshap.ConstHostVector(); + auto const& h_quadrature = quadrature_shap.ConstHostVector(); + ASSERT_EQ(h_treeshap.size(), h_quadrature.size()); + for (size_t i = 0; i < h_treeshap.size(); ++i) { + EXPECT_NEAR(h_treeshap[i], h_quadrature[i], 1e-4f); + } + + CheckShapAdditivity(kRows, kCols, quadrature_shap, margin_predt); +} + +TEST(Predictor, QuadratureShapExactCoverStatsCPU) { + Context ctx; + size_t constexpr kRows = 256; + size_t constexpr kCols = 1; + + auto dmat = RandomDataGenerator(kRows, kCols, 0.0).Device(ctx.Device()).GenerateDMatrix(); + SetLabels(dmat.get(), 1); + + auto args = BaseParams(&ctx, "binary:logistic", "6"); + args.emplace_back("tree_method", "exact"); + + std::shared_ptr p_dmat{dmat.get(), [](DMatrix*) {}}; + std::unique_ptr learner{Learner::Create({p_dmat})}; + learner->SetParams(args); + learner->Configure(); + for (size_t i = 0; i < 3; ++i) { + learner->UpdateOneIter(i, p_dmat); + } + + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), args, &mparam); + + CoverStats stats; + for (auto const& tree : gbtree->trees) { + AccumulateCoverStats(tree->HostScView(), RegTree::kRoot, 1.0, &stats); + } + + std::cout << "QuadratureShap exact cover stats: internal_nodes=" << stats.internal_nodes + << " leaves=" << stats.leaves << " min_child_weight=" << stats.min_child_weight + << " max_first_occurrence_p=" << (1.0 / stats.min_child_weight) + << " min_path_weight=" << stats.min_path_weight + << " max_path_first_occurrence_p=" << (1.0 / stats.min_path_weight) + << " count_lt_1e-2=" << stats.count_lt_1e2 << " count_lt_1e-3=" << stats.count_lt_1e3 + << " count_lt_1e-4=" << stats.count_lt_1e4 << " count_lt_1e-5=" << stats.count_lt_1e5 + << std::endl; + + EXPECT_GT(stats.internal_nodes, 0); +} + TEST(Predictor, ApproxContribsBasic) { Context ctx; size_t constexpr kRows = 64; diff --git a/tests/cpp/predictor/test_shap.cu b/tests/cpp/predictor/test_shap.cu index ea9039409be7..5f3119bb02b1 100644 --- a/tests/cpp/predictor/test_shap.cu +++ b/tests/cpp/predictor/test_shap.cu @@ -85,6 +85,66 @@ TEST(GPUPredictor, ShapOutputCasesGPU) { } } +TEST(GPUPredictor, CompareCPUQuadratureShap) { + auto ctx = MakeCUDACtx(0); + Context cpu_ctx; + bst_feature_t constexpr kCols{10}; + bst_idx_t constexpr kRows{256}; + std::size_t constexpr kIters{8}; + + HostDeviceVector predictions; + HostDeviceVector cpu_predictions; + + auto dmat = RandomDataGenerator(kRows, kCols, 0.0).Device(ctx.Device()).GenerateDMatrix(); + dmat->Info().labels.Reshape(kRows, 1); + auto& h_labels = dmat->Info().labels.Data()->HostVector(); + for (size_t i = 0; i < kRows; ++i) { + h_labels[i] = i % 2; + } + + std::unique_ptr learner{Learner::Create({dmat})}; + learner->SetParams(Args{{"objective", "binary:logistic"}, + {"tree_method", "hist"}, + {"max_depth", "8"}, + {"min_split_loss", "0"}, + {"min_child_weight", "0"}, + {"reg_lambda", "0"}, + {"reg_alpha", "0"}, + {"subsample", "1"}, + {"colsample_bytree", "1"}, + {"device", ctx.DeviceName()}}); + learner->Configure(); + for (std::size_t i = 0; i < kIters; ++i) { + learner->UpdateOneIter(i, dmat); + } + + Json model{Object{}}; + learner->SaveModel(&model); + + std::unique_ptr learner_gpu{Learner::Create({})}; + learner_gpu->LoadModel(model); + learner_gpu->SetParam("device", ctx.DeviceName()); + learner_gpu->SetParam("shap_algorithm", "quadratureshap"); + learner_gpu->SetParam("quadratureshap_points", "8"); + learner_gpu->Configure(); + + std::unique_ptr learner_cpu{Learner::Create({})}; + learner_cpu->LoadModel(model); + learner_cpu->SetParam("device", cpu_ctx.DeviceName()); + learner_cpu->SetParam("shap_algorithm", "quadratureshap"); + learner_cpu->SetParam("quadratureshap_points", "8"); + learner_cpu->Configure(); + + learner_gpu->Predict(dmat, false, &predictions, 0, 0, false, false, true, false, false); + learner_cpu->Predict(dmat, false, &cpu_predictions, 0, 0, false, false, true, false, false); + auto& phis = predictions.HostVector(); + auto& cpu_phis = cpu_predictions.HostVector(); + ASSERT_EQ(cpu_phis.size(), phis.size()); + for (auto i = 0ull; i < phis.size(); ++i) { + EXPECT_NEAR(cpu_phis[i], phis[i], 1e-4); + } +} + TEST(GPUPredictor, DartShapOutputGPU) { auto ctx = MakeCUDACtx(0); CheckDartShapOutput(&ctx);